├── keras_aug ├── _src │ ├── layers │ │ ├── __init__.py │ │ ├── base │ │ │ └── __init__.py │ │ ├── vision │ │ │ ├── __init__.py │ │ │ ├── identity.py │ │ │ ├── random_invert.py │ │ │ ├── to_dtype.py │ │ │ ├── random_solarize.py │ │ │ ├── random_auto_contrast.py │ │ │ ├── rescale_test.py │ │ │ ├── random_grayscale.py │ │ │ ├── random_posterize.py │ │ │ ├── rescale.py │ │ │ ├── random_auto_contrast_test.py │ │ │ ├── gaussian_noise.py │ │ │ ├── random_invert_test.py │ │ │ ├── random_channel_permutation.py │ │ │ ├── gaussian_blur_test.py │ │ │ ├── random_sharpen.py │ │ │ ├── gaussian_noise_test.py │ │ │ ├── rand_augment_test.py │ │ │ ├── identity_test.py │ │ │ ├── trivial_augment_test.py │ │ │ ├── to_dtype_test.py │ │ │ ├── normalize.py │ │ │ ├── mix_up_test.py │ │ │ ├── normalize_test.py │ │ │ ├── max_bounding_box.py │ │ │ ├── cut_mix_test.py │ │ │ ├── random_equalize_test.py │ │ │ ├── random_posterize_test.py │ │ │ ├── random_hsv_test.py │ │ │ ├── random_grayscale_test.py │ │ │ ├── random_solarize_test.py │ │ │ ├── random_channel_permutation_test.py │ │ │ ├── gaussian_blur.py │ │ │ ├── random_sharpen_test.py │ │ │ ├── mix_up.py │ │ │ ├── max_bounding_box_test.py │ │ │ ├── random_equalize.py │ │ │ └── random_erasing_test.py │ │ └── composition │ │ │ ├── __init__.py │ │ │ ├── random_apply_test.py │ │ │ ├── random_choice_test.py │ │ │ ├── random_order_test.py │ │ │ ├── random_order.py │ │ │ ├── random_apply.py │ │ │ └── random_choice.py │ ├── utils │ │ ├── __init__.py │ │ ├── test_utils.py │ │ └── argument_validation.py │ ├── backend │ │ ├── __init__.py │ │ ├── bounding_box_test.py │ │ └── dynamic_backend.py │ ├── visualization │ │ ├── __init__.py │ │ ├── draw_segmentation_masks.py │ │ └── draw_bounding_boxes.py │ ├── ops │ │ ├── __init__.py │ │ ├── bounding_box.py │ │ └── image.py │ ├── version.py │ ├── testing │ │ └── test_case.py │ └── keras_aug_export.py ├── ops │ ├── __init__.py │ ├── bounding_box │ │ └── __init__.py │ └── image │ │ └── __init__.py ├── layers │ ├── 
base │ │ └── __init__.py │ ├── __init__.py │ ├── composition │ │ └── __init__.py │ └── vision │ │ └── __init__.py ├── __init__.py └── visualization │ └── __init__.py ├── requirements.txt ├── requirements_ci.txt ├── shell ├── format.sh ├── run_guides.sh ├── api_gen.sh └── lint.sh ├── api_gen.py ├── .gitignore ├── .pre-commit-config.yaml ├── .github ├── dependabot.yml └── workflows │ ├── release.yml │ └── actions.yml ├── conftest.py ├── guides ├── quick_start.py ├── oxford_yolov8_aug.py └── voc_yolov8_aug.py ├── docs ├── generate_semantic_segmentation_gif.py └── generate_object_detection_gif.py └── pyproject.toml /keras_aug/_src/layers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/backend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /keras_aug/_src/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from keras_aug._src.ops import bounding_box 2 | from keras_aug._src.ops import image 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # For Gradio app 2 | --extra-index-url https://download.pytorch.org/whl/cpu 3 | torch 4 | torchvision 5 | 6 | keras 7 | -------------------------------------------------------------------------------- /requirements_ci.txt: -------------------------------------------------------------------------------- 1 | # For CI 2 | tensorflow-cpu 3 | 4 | --extra-index-url https://download.pytorch.org/whl/cpu 5 | torch 6 | torchvision 7 | 8 | jax[cpu] 9 | 10 | keras 11 | -------------------------------------------------------------------------------- /keras_aug/_src/version.py: -------------------------------------------------------------------------------- 1 | from keras_aug._src.keras_aug_export import keras_aug_export 2 | 3 | __version__ = "1.1.1" 4 | 5 | 6 | @keras_aug_export("keras_aug") 7 | def version(): 8 | return __version__ 9 | -------------------------------------------------------------------------------- /keras_aug/ops/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten.
5 | """ 6 | 7 | from keras_aug.ops import bounding_box 8 | from keras_aug.ops import image 9 | -------------------------------------------------------------------------------- /shell/format.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -Eeuo pipefail 3 | 4 | base_dir=$(dirname "$(dirname "$0")") 5 | isort --sp "${base_dir}/pyproject.toml" . 6 | black --config "${base_dir}/pyproject.toml" . 7 | ruff check --config "${base_dir}/pyproject.toml" --fix . 8 | -------------------------------------------------------------------------------- /keras_aug/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 8 | -------------------------------------------------------------------------------- /shell/run_guides.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Exit on the first failing guide so "All passed!" only prints on success. 3 | set -Eeuo pipefail 4 | 5 | export KERAS_BACKEND=tensorflow 6 | export TF_CPP_MIN_LOG_LEVEL=3 7 | python3 -m guides.voc_yolov8_aug 8 | echo "Finished guides.voc_yolov8_aug" 9 | python3 -m guides.oxford_yolov8_aug 10 | echo "Finished guides.oxford_yolov8_aug" 11 | rm output_* 12 | echo "All passed!" 13 | -------------------------------------------------------------------------------- /keras_aug/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten.
5 | """ 6 | 7 | from keras_aug.layers import base 8 | from keras_aug.layers import composition 9 | from keras_aug.layers import vision 10 | -------------------------------------------------------------------------------- /shell/api_gen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -Eeuo pipefail 3 | 4 | base_dir=$(dirname "$(dirname "$0")") 5 | 6 | echo "Generating api directory with public APIs..." 7 | python3 "${base_dir}"/api_gen.py 8 | 9 | echo "Formatting api directory..." 10 | bash "${base_dir}"/shell/format.sh 11 | 12 | echo -e "\nAPI generation finish!" 13 | -------------------------------------------------------------------------------- /keras_aug/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from keras_aug import layers 8 | from keras_aug import ops 9 | from keras_aug import visualization 10 | from keras_aug._src.version import version 11 | 12 | __version__ = "1.1.1" 13 | -------------------------------------------------------------------------------- /keras_aug/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten.
5 | """ 6 | 7 | from keras_aug._src.visualization.draw_bounding_boxes import draw_bounding_boxes 8 | from keras_aug._src.visualization.draw_segmentation_masks import ( 9 | draw_segmentation_masks, 10 | ) 11 | -------------------------------------------------------------------------------- /api_gen.py: -------------------------------------------------------------------------------- 1 | import namex 2 | 3 | from keras_aug._src.version import __version__ 4 | 5 | namex.generate_api_files(package="keras_aug", code_directory="_src") 6 | 7 | # Add version string 8 | 9 | with open("keras_aug/__init__.py", "r") as f: 10 | contents = f.read() 11 | with open("keras_aug/__init__.py", "w") as f: 12 | contents += f'__version__ = "{__version__}"\n' 13 | f.write(contents) 14 | -------------------------------------------------------------------------------- /keras_aug/layers/composition/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from keras_aug._src.layers.composition.random_apply import RandomApply 8 | from keras_aug._src.layers.composition.random_choice import RandomChoice 9 | from keras_aug._src.layers.composition.random_order import RandomOrder 10 | -------------------------------------------------------------------------------- /keras_aug/ops/bounding_box/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten.
5 | """ 6 | 7 | from keras_aug._src.ops.bounding_box import affine 8 | from keras_aug._src.ops.bounding_box import clip_to_images 9 | from keras_aug._src.ops.bounding_box import convert_format 10 | from keras_aug._src.ops.bounding_box import crop 11 | from keras_aug._src.ops.bounding_box import pad 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | logs/ 3 | dist/ 4 | .DS_Store 5 | build/ 6 | *.swp 7 | .idea 8 | *.pyc 9 | .pytest_cache 10 | *.egg-info 11 | __pycache__/ 12 | *.so 13 | 14 | # VS Code files and container 15 | .vscode/ 16 | .devcontainer/ 17 | 18 | # pytest 19 | .coverage* 20 | htmlcov/ 21 | coverage.xml 22 | 23 | # docs 24 | docs/build/ 25 | 26 | # venv (anchored to the repo root; gitignore never matches a leading "./") 27 | /venv/ 28 | /.venv/ 29 | 30 | # ruff 31 | .ruff_cache 32 | 33 | # Gradio 34 | flagged 35 | 36 | # outputs 37 | *.jpg 38 | *.png 39 | -------------------------------------------------------------------------------- /shell/lint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Usage: # lint.sh can be used without arguments to lint the entire project: 3 | # 4 | # ./lint.sh 5 | # 6 | # or with arguments to lint a subset of files 7 | # 8 | # ./lint.sh examples/* 9 | 10 | files="." 11 | if [ $# -ne 0 ] 12 | then 13 | files=$@ 14 | fi 15 | 16 | ruff check $files 17 | if ! [ $? -eq 0 ] 18 | then 19 | echo "Please fix the code style issue." 20 | exit 1 21 | fi 22 | [ $# -eq 0 ] && echo "no issues with ruff" 23 | 24 | black --check $files 25 | if ! [ $? -eq 0 ] 26 | then 27 | echo "Please run \"bash shell/format.sh\" to format the code." 28 | exit 1 29 | fi 30 | [ $# -eq 0 ] && echo "no issues with black" 31 | echo "linting success!"
32 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-merge-conflict 7 | - id: check-toml 8 | - id: check-yaml 9 | - id: end-of-file-fixer 10 | files: \.py$ 11 | - id: debug-statements 12 | files: \.py$ 13 | - id: trailing-whitespace 14 | files: \.py$ 15 | 16 | - repo: https://github.com/pycqa/isort 17 | rev: 5.13.2 18 | hooks: 19 | - id: isort 20 | 21 | - repo: https://github.com/psf/black-pre-commit-mirror 22 | rev: 24.4.2 23 | hooks: 24 | - id: black 25 | 26 | - repo: https://github.com/astral-sh/ruff-pre-commit 27 | rev: v0.4.4 28 | hooks: 29 | - id: ruff 30 | args: 31 | - --fix 32 | - id: ruff-format 33 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "github-actions" 9 | directory: "/" 10 | schedule: 11 | interval: "monthly" 12 | groups: 13 | github-actions: 14 | patterns: 15 | - "*" 16 | - package-ecosystem: "pip" 17 | directory: "/" 18 | schedule: 19 | interval: "monthly" 20 | groups: 21 | python: 22 | patterns: 23 | - "*" 24 | -------------------------------------------------------------------------------- /keras_aug/ops/image/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT.
2 | 3 | This file was autogenerated. Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from keras_aug._src.ops.image import adjust_brightness 8 | from keras_aug._src.ops.image import adjust_contrast 9 | from keras_aug._src.ops.image import adjust_hue 10 | from keras_aug._src.ops.image import adjust_saturation 11 | from keras_aug._src.ops.image import affine 12 | from keras_aug._src.ops.image import auto_contrast 13 | from keras_aug._src.ops.image import blend 14 | from keras_aug._src.ops.image import crop 15 | from keras_aug._src.ops.image import equalize 16 | from keras_aug._src.ops.image import guassian_blur 17 | from keras_aug._src.ops.image import invert 18 | from keras_aug._src.ops.image import pad 19 | from keras_aug._src.ops.image import posterize 20 | from keras_aug._src.ops.image import rgb_to_grayscale 21 | from keras_aug._src.ops.image import sharpen 22 | from keras_aug._src.ops.image import solarize 23 | from keras_aug._src.ops.image import transform_dtype 24 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/identity.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class Identity(VisionRandomLayer): 10 | """Applies nothing to the inputs.""" 11 | 12 | def __init__(self, **kwargs): 13 | super().__init__(has_generator=False, **kwargs) 14 | 15 | def compute_output_shape(self, input_shape): 16 | return input_shape 17 | 18 | def augment_images(self, images, transformations, **kwargs): 19 | return images 20 | 21 | def augment_labels(self, labels, transformations, **kwargs): 22 | return labels 23 | 24 | def
augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 25 | return bounding_boxes 26 | 27 | def augment_segmentation_masks( 28 | self, segmentation_masks, transformations, **kwargs 29 | ): 30 | return segmentation_masks 31 | 32 | def augment_keypoints(self, keypoints, transformations, **kwargs): 33 | return keypoints 34 | 35 | def get_config(self): 36 | return super().get_config() 37 | -------------------------------------------------------------------------------- /keras_aug/_src/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import ml_dtypes 2 | import numpy as np 3 | from keras import distribution 4 | 5 | 6 | def get_images(dtype, data_format="channels_first", size=(32, 32)): 7 | # Build a (2, 3, *size) channels_first batch; transposed below if needed. 8 | if dtype == "float32": 9 | x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype) 10 | elif dtype in ("mixed_bfloat16", "bfloat16"): 11 | x = np.random.uniform(0, 1, (2, 3, *size)).astype(ml_dtypes.bfloat16) 12 | elif dtype == "float16": 13 | x = np.random.uniform(0, 1, (2, 3, *size)).astype(dtype) 14 | elif dtype == "uint8": 15 | x = np.random.uniform(0, 255, (2, 3, *size)).astype(dtype) 16 | elif dtype == "int8": 17 | x = np.random.uniform(-128, 127, (2, 3, *size)).astype(dtype) 18 | elif dtype == "int16": 19 | x = np.random.uniform(-32768, 32767, (2, 3, *size)).astype(dtype) 20 | elif dtype == "int32": 21 | x = np.random.uniform(-2147483648, 2147483647, (2, 3, *size)).astype( 22 | dtype 23 | ) 24 | else: 25 | raise ValueError(f"Unsupported dtype: {dtype}") 26 | if data_format == "channels_last": 27 | x = np.transpose(x, [0, 2, 3, 1]) 28 | return x 29 | 30 | 31 | def uses_gpu(): 32 | # Condition used to skip tests when using the GPU 33 | devices = distribution.list_devices() 34 | if any(d.startswith("gpu") for d in devices): 35 | return True 36 | return False 37 | --------------------------------------------------------------------------------
/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | from keras import backend 5 | 6 | 7 | def pytest_addoption(parser): 8 | parser.addoption( 9 | "--run_serialization", 10 | action="store_true", 11 | default=False, 12 | help="run serialization tests", 13 | ) 14 | 15 | 16 | def pytest_configure(config): 17 | import tensorflow as tf 18 | 19 | # disable tensorflow gpu memory preallocation 20 | physical_devices = tf.config.list_physical_devices("GPU") 21 | for device in physical_devices: 22 | tf.config.experimental.set_memory_growth(device, True) 23 | 24 | # disable jax gpu memory preallocation 25 | # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html 26 | os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" 27 | 28 | config.addinivalue_line( 29 | "markers", "serialization: mark test as a serialization test" 30 | ) 31 | config.addinivalue_line( 32 | "markers", 33 | "requires_trainable_backend: mark test for trainable backend only", 34 | ) 35 | 36 | 37 | def pytest_collection_modifyitems(config, items): 38 | run_serialization_tests = config.getoption("--run_serialization") 39 | skip_serialization = pytest.mark.skipif( 40 | not run_serialization_tests, 41 | reason="need --run_serialization option to run", 42 | ) 43 | requires_trainable_backend = pytest.mark.skipif( 44 | backend.backend() == "numpy", reason="require trainable backend" 45 | ) 46 | for item in items: 47 | if "requires_trainable_backend" in item.keywords: 48 | item.add_marker(requires_trainable_backend) 49 | if "serialization" in item.name: 50 | item.add_marker(skip_serialization) 51 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: Upload release to PyPI 10 |
runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/p/keras-aug 14 | permissions: 15 | id-token: write 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.9 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: '3.9' 22 | - name: Lint 23 | uses: pre-commit/action@v3.0.1 24 | - name: Get pip cache dir 25 | id: pip-cache 26 | run: | 27 | python -m pip install --upgrade pip setuptools 28 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 29 | - name: Cache pip 30 | uses: actions/cache@v4 31 | with: 32 | path: ${{ steps.pip-cache.outputs.dir }} 33 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements_ci.txt') }} 34 | - name: Install dependencies 35 | run: | 36 | pip install -r requirements_ci.txt --progress-bar off --upgrade 37 | pip install -e ".[tests]" --progress-bar off --upgrade 38 | - name: Check for API changes 39 | run: | 40 | bash shell/api_gen.sh 41 | git status 42 | if [ -n "$(git status --porcelain)" ]; then 43 | echo "Please run shell/api_gen.sh to generate API."
44 | exit 1 45 | fi 46 | - name: Build wheels 47 | shell: bash 48 | run: | 49 | pip install --upgrade pip setuptools wheel twine build 50 | python -m build 51 | - name: Publish package distributions to PyPI 52 | uses: pypa/gh-action-pypi-publish@release/v1 53 | with: 54 | verbose: true 55 | -------------------------------------------------------------------------------- /keras_aug/_src/testing/test_case.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from absl.testing import parameterized 3 | from keras import backend 4 | from keras import ops 5 | from keras.src import testing 6 | 7 | 8 | class TestCase(testing.TestCase, parameterized.TestCase): 9 | def setUp(self): 10 | # Save the global image data format (restored in tearDown), then force channels_last 11 | self.data_format = backend.image_data_format() 12 | backend.set_image_data_format("channels_last") 13 | return super().setUp() 14 | 15 | def tearDown(self) -> None: 16 | backend.set_image_data_format(self.data_format) 17 | return super().tearDown() 18 | 19 | def convert_to_numpy(self, inputs): 20 | import torch 21 | from keras.src.backend.torch import convert_to_numpy 22 | 23 | if isinstance(inputs, torch.Tensor): 24 | inputs = convert_to_numpy(inputs) 25 | if not isinstance(inputs, np.ndarray): 26 | inputs = ops.convert_to_numpy(inputs) 27 | return inputs 28 | 29 | def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): 30 | x1 = self.convert_to_numpy(x1) 31 | x2 = self.convert_to_numpy(x2) 32 | if backend.standardize_dtype(x1.dtype) == "bfloat16": 33 | x1 = x1.astype("float32") 34 | if backend.standardize_dtype(x2.dtype) == "bfloat16": 35 | x2 = x2.astype("float32") 36 | super().assertAllClose(x1, x2, atol, rtol, msg) 37 | 38 | def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): 39 | x1 = self.convert_to_numpy(x1) 40 | x2 = self.convert_to_numpy(x2) 41 | if backend.standardize_dtype(x1.dtype) == "bfloat16": 42 | x1 = x1.astype("float32") 43 | if
backend.standardize_dtype(x2.dtype) == "bfloat16": 44 | x2 = x2.astype("float32") 45 | super().assertNotAllClose(x1, x2, atol, rtol, msg) 46 | 47 | def assertDType(self, x, dtype, msg=None): 48 | dtype = dtype.replace("mixed_", "") 49 | return super().assertDType(x, dtype, msg) 50 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_invert.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class RandomInvert(VisionRandomLayer): 10 | """Inverts the colors of the given images. 11 | 12 | The equation of the inversion: `y = value_range[1] - x`. 13 | 14 | Args: 15 | p: A float specifying the probability. Defaults to `0.5`.
16 | """ 17 | 18 | def __init__(self, p: float = 0.5, **kwargs): 19 | super().__init__(**kwargs) 20 | self.p = float(p) 21 | 22 | def compute_output_shape(self, input_shape): 23 | return input_shape 24 | 25 | def get_params(self, batch_size, images=None, **kwargs): 26 | ops = self.backend 27 | random_generator = self.random_generator 28 | p = ops.random.uniform([batch_size], seed=random_generator) 29 | return p 30 | 31 | def augment_images(self, images, transformations, **kwargs): 32 | ops = self.backend 33 | p = transformations 34 | 35 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 36 | images = ops.numpy.where( 37 | prob, self.image_backend.invert(images), images 38 | ) 39 | return images 40 | 41 | def augment_labels(self, labels, transformations, **kwargs): 42 | return labels 43 | 44 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 45 | return bounding_boxes 46 | 47 | def augment_segmentation_masks( 48 | self, segmentation_masks, transformations, **kwargs 49 | ): 50 | return segmentation_masks 51 | 52 | def augment_keypoints(self, keypoints, transformations, **kwargs): 53 | return keypoints 54 | 55 | def get_config(self): 56 | config = super().get_config() 57 | config.update({"p": self.p}) 58 | return config 59 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/to_dtype.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras import backend 3 | 4 | from keras_aug._src.keras_aug_export import keras_aug_export 5 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 6 | 7 | 8 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 9 | @keras.saving.register_keras_serializable(package="keras_aug") 10 | class ToDType(VisionRandomLayer): 11 | """Converts the input to a specific dtype, optionally scaling the values.
12 | 13 | If `scale` is `True`, the value range will be changed as follows: 14 | - `"uint8"`: `[0, 255]` 15 | - `"int16"`: `[-32768, 32767]` 16 | - `"int32"`: `[-2147483648, 2147483647]` 17 | - float: `[0.0, 1.0]` 18 | 19 | Args: 20 | to_dtype: A string specifying the target dtype. 21 | scale: Whether to scale the values. Defaults to `False`. 22 | """ 23 | 24 | def __init__(self, to_dtype, scale=False, **kwargs): 25 | to_dtype = backend.standardize_dtype(to_dtype) 26 | self.scale = bool(scale) 27 | if "dtype" in kwargs: 28 | kwargs.pop("dtype") 29 | super().__init__(has_generator=False, dtype=to_dtype, **kwargs) 30 | self.to_dtype = to_dtype 31 | self.transform_dtype_scale = self.scale 32 | 33 | def compute_output_shape(self, input_shape): 34 | return input_shape 35 | 36 | def augment_images(self, images, transformations, **kwargs): 37 | return images 38 | 39 | def augment_labels(self, labels, transformations, **kwargs): 40 | return labels 41 | 42 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 43 | return bounding_boxes 44 | 45 | def augment_segmentation_masks( 46 | self, segmentation_masks, transformations, **kwargs 47 | ): 48 | return segmentation_masks 49 | 50 | def augment_keypoints(self, keypoints, transformations, **kwargs): 51 | return keypoints 52 | 53 | def get_config(self): 54 | config = super().get_config() 55 | config.update({"to_dtype": self.to_dtype, "scale": self.scale}) 56 | return config 57 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_solarize.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class
RandomSolarize(VisionRandomLayer): 10 | """Solarize the input images with a given probability. 11 | 12 | Solarization inverts all pixel values equal to or above a threshold. 13 | 14 | Args: 15 | threshold: All pixels equal to or above this value are inverted. 16 | p: A float specifying the probability. Defaults to `0.5`. 17 | """ 18 | 19 | def __init__(self, threshold: float, p: float = 0.5, **kwargs): 20 | super().__init__(**kwargs) 21 | self.threshold = float(threshold) 22 | self.p = float(p) 23 | 24 | def compute_output_shape(self, input_shape): 25 | return input_shape 26 | 27 | def get_params(self, batch_size, images=None, **kwargs): 28 | ops = self.backend 29 | random_generator = self.random_generator 30 | p = ops.random.uniform([batch_size], seed=random_generator) 31 | return p 32 | 33 | def augment_images(self, images, transformations=None, **kwargs): 34 | ops = self.backend 35 | p = transformations 36 | 37 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 38 | images = ops.numpy.where( 39 | prob, 40 | self.image_backend.solarize(images, self.threshold), 41 | images, 42 | ) 43 | return images 44 | 45 | def augment_labels(self, labels, transformations, **kwargs): 46 | return labels 47 | 48 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 49 | return bounding_boxes 50 | 51 | def augment_segmentation_masks( 52 | self, segmentation_masks, transformations, **kwargs 53 | ): 54 | return segmentation_masks 55 | 56 | def augment_keypoints(self, keypoints, transformations, **kwargs): 57 | return keypoints 58 | 59 | def get_config(self): 60 | config = super().get_config() 61 | config.update({"threshold": self.threshold, "p": self.p}) 62 | return config 63 | -------------------------------------------------------------------------------- /keras_aug/layers/vision/__init__.py: -------------------------------------------------------------------------------- 1 | """DO NOT EDIT. 2 | 3 | This file was autogenerated.
Do not edit it by hand, 4 | since your modifications would be overwritten. 5 | """ 6 | 7 | from keras_aug._src.layers.vision.center_crop import CenterCrop 8 | from keras_aug._src.layers.vision.color_jitter import ColorJitter 9 | from keras_aug._src.layers.vision.cut_mix import CutMix 10 | from keras_aug._src.layers.vision.gaussian_blur import GaussianBlur 11 | from keras_aug._src.layers.vision.gaussian_noise import GaussianNoise 12 | from keras_aug._src.layers.vision.identity import Identity 13 | from keras_aug._src.layers.vision.max_bounding_box import MaxBoundingBox 14 | from keras_aug._src.layers.vision.mix_up import MixUp 15 | from keras_aug._src.layers.vision.mosaic import Mosaic 16 | from keras_aug._src.layers.vision.normalize import Normalize 17 | from keras_aug._src.layers.vision.pad import Pad 18 | from keras_aug._src.layers.vision.rand_augment import RandAugment 19 | from keras_aug._src.layers.vision.random_affine import RandomAffine 20 | from keras_aug._src.layers.vision.random_auto_contrast import RandomAutoContrast 21 | from keras_aug._src.layers.vision.random_channel_permutation import ( 22 | RandomChannelPermutation, 23 | ) 24 | from keras_aug._src.layers.vision.random_crop import RandomCrop 25 | from keras_aug._src.layers.vision.random_equalize import RandomEqualize 26 | from keras_aug._src.layers.vision.random_erasing import RandomErasing 27 | from keras_aug._src.layers.vision.random_flip import RandomFlip 28 | from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale 29 | from keras_aug._src.layers.vision.random_hsv import RandomHSV 30 | from keras_aug._src.layers.vision.random_invert import RandomInvert 31 | from keras_aug._src.layers.vision.random_posterize import RandomPosterize 32 | from keras_aug._src.layers.vision.random_resized_crop import RandomResizedCrop 33 | from keras_aug._src.layers.vision.random_rotation import RandomRotation 34 | from keras_aug._src.layers.vision.random_sharpen import RandomSharpen 35 | from
keras_aug._src.layers.vision.random_solarize import RandomSolarize 36 | from keras_aug._src.layers.vision.rescale import Rescale 37 | from keras_aug._src.layers.vision.resize import Resize 38 | from keras_aug._src.layers.vision.to_dtype import ToDType 39 | from keras_aug._src.layers.vision.trivial_augment import TrivialAugmentWide 40 | -------------------------------------------------------------------------------- /guides/quick_start.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import tensorflow as tf 3 | import tensorflow_datasets as tfds 4 | 5 | from keras_aug import layers as ka_layers 6 | 7 | BATCH_SIZE = 64 8 | NUM_CLASSES = 3 9 | INPUT_SIZE = (128, 128) 10 | 11 | # Create a `tf.data.Dataset`-compatible preprocessing pipeline. 12 | # Note that this example works with all backends. 13 | train_dataset, validation_dataset = tfds.load( 14 | "rock_paper_scissors", as_supervised=True, split=["train", "test"] 15 | ) 16 | train_dataset = ( 17 | train_dataset.shuffle(1024) # shuffle examples before batching so batch composition changes every epoch 18 | .batch(BATCH_SIZE) 19 | .map( 20 | lambda images, labels: { 21 | "images": tf.cast(images, "float32") / 255.0, 22 | "labels": tf.one_hot(labels, NUM_CLASSES), 23 | } 24 | ) 25 | .map(ka_layers.vision.Resize(INPUT_SIZE)) 26 | .map(ka_layers.vision.RandAugment()) 27 | .map(ka_layers.vision.CutMix(num_classes=NUM_CLASSES)) 28 | .map(ka_layers.vision.Rescale(scale=2.0, offset=-1)) # [0, 1] to [-1, 1] 29 | .map(lambda data: (data["images"], data["labels"])) 30 | .prefetch(tf.data.AUTOTUNE) 31 | ) 32 | validation_dataset = ( 33 | validation_dataset.batch(BATCH_SIZE) 34 | .map( 35 | lambda images, labels: { 36 | "images": tf.cast(images, "float32") / 255.0, 37 | "labels": tf.one_hot(labels, NUM_CLASSES), 38 | } 39 | ) 40 | .map(ka_layers.vision.Resize(INPUT_SIZE)) 41 | .map(ka_layers.vision.Rescale(scale=2.0, offset=-1)) # [0, 1] to [-1, 1] 42 | .map(lambda data: (data["images"], data["labels"])) 43 | .prefetch(tf.data.AUTOTUNE) 44 | ) 45 |
46 | # Create a model using MobileNetV2 as the backbone. 47 | backbone = keras.applications.MobileNetV2( 48 | input_shape=(*INPUT_SIZE, 3), include_top=False 49 | ) 50 | backbone.trainable = False 51 | inputs = keras.Input((*INPUT_SIZE, 3)) 52 | x = backbone(inputs) 53 | x = keras.layers.GlobalAveragePooling2D()(x) 54 | outputs = keras.layers.Dense(NUM_CLASSES, activation="softmax")(x) 55 | model = keras.Model(inputs, outputs) 56 | model.summary() 57 | model.compile( 58 | loss="categorical_crossentropy", 59 | optimizer=keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9), 60 | metrics=["accuracy"], 61 | ) 62 | 63 | # Train and evaluate your model 64 | model.fit(train_dataset, validation_data=validation_dataset, epochs=8) 65 | model.evaluate(validation_dataset) 66 | -------------------------------------------------------------------------------- /keras_aug/_src/ops/bounding_box.py: -------------------------------------------------------------------------------- 1 | from keras.src.utils.backend_utils import in_tf_graph 2 | 3 | from keras_aug._src.backend.bounding_box import BoundingBoxBackend 4 | from keras_aug._src.keras_aug_export import keras_aug_export 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.ops.bounding_box"]) 8 | def convert_format( 9 | boxes, source: str, target: str, height=None, width=None, dtype="float32" 10 | ): 11 | backend = "tensorflow" if in_tf_graph() else None 12 | return BoundingBoxBackend(backend).convert_format( 13 | boxes, source, target, height=height, width=width, dtype=dtype 14 | ) 15 | 16 | 17 | @keras_aug_export(parent_path=["keras_aug.ops.bounding_box"]) 18 | def clip_to_images(bounding_boxes, height=None, width=None, format="xyxy"): 19 | backend = "tensorflow" if in_tf_graph() else None 20 | return BoundingBoxBackend(backend).clip_to_images( 21 | bounding_boxes, height=height, width=width, format=format 22 | ) 23 | 24 | 25 | @keras_aug_export(parent_path=["keras_aug.ops.bounding_box"]) 26 | def affine( 27 | boxes, 28 |
angle, 29 | translate_x, 30 | translate_y, 31 | scale, 32 | shear_x, 33 | shear_y, 34 | height, 35 | width, 36 | center_x=None, 37 | center_y=None, 38 | format="xyxy", 39 | ): 40 | if format != "xyxy": 41 | raise NotImplementedError 42 | backend = "tensorflow" if in_tf_graph() else None 43 | return BoundingBoxBackend(backend).affine( 44 | boxes, 45 | angle, 46 | translate_x, 47 | translate_y, 48 | scale, 49 | shear_x, 50 | shear_y, 51 | height, 52 | width, 53 | center_x=center_x, 54 | center_y=center_y, 55 | ) 56 | 57 | 58 | @keras_aug_export(parent_path=["keras_aug.ops.bounding_box"]) 59 | def crop(boxes, top, left, height, width, format="xyxy"): 60 | if format != "xyxy": 61 | raise NotImplementedError 62 | backend = "tensorflow" if in_tf_graph() else None 63 | return BoundingBoxBackend(backend).crop(boxes, top, left, height, width) 64 | 65 | 66 | @keras_aug_export(parent_path=["keras_aug.ops.bounding_box"]) 67 | def pad(boxes, top, left, format="xyxy"): 68 | if format != "xyxy": 69 | raise NotImplementedError 70 | backend = "tensorflow" if in_tf_graph() else None 71 | return BoundingBoxBackend(backend).pad(boxes, top, left) 72 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_auto_contrast.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | from keras_aug._src.utils.argument_validation import standardize_data_format 6 | 7 | 8 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 9 | @keras.saving.register_keras_serializable(package="keras_aug") 10 | class RandomAutoContrast(VisionRandomLayer): 11 | """Autocontrast the images randomly with a given probability. 12 | 13 | Auto contrast stretches the values of an image across the entire available 14 | value range. 
This makes differences between pixels more obvious. 15 | 16 | Args: 17 | p: A float specifying the probability. Defaults to `0.5`. 18 | """ 19 | 20 | def __init__(self, p: float = 0.5, data_format=None, **kwargs): 21 | super().__init__(**kwargs) 22 | self.p = float(p) 23 | self.data_format = standardize_data_format(data_format) 24 | 25 | def compute_output_shape(self, input_shape): 26 | return input_shape 27 | 28 | def get_params(self, batch_size, images=None, **kwargs): 29 | ops = self.backend 30 | random_generator = self.random_generator 31 | p = ops.random.uniform([batch_size], seed=random_generator) 32 | return p 33 | 34 | def augment_images(self, images, transformations, **kwargs): 35 | ops = self.backend 36 | p = transformations 37 | 38 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 39 | images = ops.numpy.where( 40 | prob, 41 | self.image_backend.auto_contrast(images, self.data_format), 42 | images, 43 | ) 44 | return images 45 | 46 | def augment_labels(self, labels, transformations, **kwargs): 47 | return labels 48 | 49 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 50 | return bounding_boxes 51 | 52 | def augment_segmentation_masks( 53 | self, segmentation_masks, transformations, **kwargs 54 | ): 55 | return segmentation_masks 56 | 57 | def augment_keypoints(self, keypoints, transformations, **kwargs): 58 | return keypoints 59 | 60 | def get_config(self): 61 | config = super().get_config() 62 | config.update({"p": self.p}) 63 | return config 64 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/rescale_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras.src.testing.test_utils import named_product 5 | 6 | from keras_aug._src.layers.vision.rescale import Rescale 7 | from keras_aug._src.testing.test_case import TestCase 8 
| from keras_aug._src.utils.test_utils import get_images 9 | 10 | 11 | class RescaleTest(TestCase): 12 | @parameterized.named_parameters( 13 | named_product(dtype=["float32", "mixed_bfloat16", "bfloat16"]) 14 | ) 15 | def test_correctness(self, dtype): 16 | if "bfloat16" in dtype: 17 | atol = 1e-2 18 | else: 19 | atol = 1e-6 20 | np.random.seed(42) 21 | 22 | x = get_images(dtype, "channels_last") 23 | layer = Rescale(scale=0.5, offset=0.1, dtype=dtype) 24 | y = layer(x) 25 | 26 | ref_y = x * 0.5 + 0.1 27 | self.assertDType(y, dtype) 28 | self.assertAllClose(y, ref_y, atol=atol) 29 | 30 | def test_shape(self): 31 | # Test dynamic shape 32 | x = keras.KerasTensor((None, None, None, 3)) 33 | y = Rescale(scale=2, offset=0.5)(x) 34 | self.assertEqual(y.shape, (None, None, None, 3)) 35 | 36 | # Test static shape 37 | x = keras.KerasTensor((None, 32, 32, 3)) 38 | y = Rescale(scale=2, offset=0.5)(x) 39 | self.assertEqual(y.shape, (None, 32, 32, 3)) 40 | 41 | def test_model(self): 42 | layer = Rescale(scale=2, offset=0.5) 43 | inputs = keras.layers.Input(shape=[None, None, 5]) 44 | outputs = layer(inputs) 45 | model = keras.models.Model(inputs, outputs) 46 | self.assertEqual(model.output_shape, (None, None, None, 5)) 47 | 48 | def test_config(self): 49 | x = get_images("float32", "channels_last") 50 | layer = Rescale(scale=2, offset=0.5) 51 | y = layer(x) 52 | 53 | layer = Rescale.from_config(layer.get_config()) 54 | y2 = layer(x) 55 | self.assertAllClose(y, y2) 56 | 57 | def test_tf_data_compatibility(self): 58 | import tensorflow as tf 59 | 60 | layer = Rescale(scale=2, offset=0.5) 61 | x = get_images("float32", "channels_last") 62 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 63 | for output in ds.take(1): 64 | self.assertIsInstance(output, tf.Tensor) 65 | self.assertEqual(output.shape, (2, 32, 32, 3)) 66 | -------------------------------------------------------------------------------- /keras_aug/_src/backend/bounding_box_test.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from absl.testing import parameterized 3 | from keras.src.testing.test_utils import named_product 4 | 5 | from keras_aug._src.backend.bounding_box import BoundingBoxBackend 6 | from keras_aug._src.testing.test_case import TestCase 7 | 8 | 9 | class BoundingBoxBackendTest(TestCase): 10 | size = 1000.0 11 | xyxy_box = np.array([[[10, 20, 110, 120], [20, 30, 120, 130]]], "float32") 12 | yxyx_box = np.array([[[20, 10, 120, 110], [30, 20, 130, 120]]], "float32") 13 | xywh_box = np.array([[[10, 20, 100, 100], [20, 30, 100, 100]]], "float32") 14 | center_xywh_box = np.array( 15 | [[[60, 70, 100, 100], [70, 80, 100, 100]]], "float32" 16 | ) 17 | 18 | def get_box(self, name): 19 | box_dict = { 20 | "xyxy": self.xyxy_box, 21 | "yxyx": self.yxyx_box, 22 | "xywh": self.xywh_box, 23 | "center_xywh": self.center_xywh_box, 24 | "rel_xyxy": self.xyxy_box / self.size, 25 | "rel_yxyx": self.yxyx_box / self.size, 26 | "rel_xywh": self.xywh_box / self.size, 27 | "rel_center_xywh": self.center_xywh_box / self.size, 28 | } 29 | return box_dict[name] 30 | 31 | @parameterized.named_parameters( 32 | named_product( 33 | source=[ 34 | "xyxy", 35 | "yxyx", 36 | "xywh", 37 | "center_xywh", 38 | "rel_xyxy", 39 | "rel_yxyx", 40 | "rel_xywh", 41 | "rel_center_xywh", 42 | ], 43 | target=[ 44 | "xyxy", 45 | "yxyx", 46 | "xywh", 47 | "center_xywh", 48 | "rel_xyxy", 49 | "rel_yxyx", 50 | "rel_xywh", 51 | "rel_center_xywh", 52 | ], 53 | ) 54 | ) 55 | def test_convert_format(self, source, target): 56 | bbox_backend = BoundingBoxBackend() 57 | boxes = self.get_box(source) 58 | ref_boxes = self.get_box(target) 59 | 60 | # Test batched 61 | result = bbox_backend.convert_format(boxes, source, target, 1000, 1000) 62 | self.assertAllClose(result, ref_boxes) 63 | 64 | # Test unbatched 65 | boxes = boxes[0] 66 | ref_boxes = ref_boxes[0] 67 | result = bbox_backend.convert_format(boxes, source, target, 
1000, 1000) 68 | self.assertAllClose(result, ref_boxes) 69 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_grayscale.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class RandomGrayscale(VisionRandomLayer): 10 | """Randomly convert the images to grayscale. 11 | 12 | The input images must be 3 channels. 13 | 14 | Args: 15 | p: A float specifying the probability. Defaults to `0.5`. 16 | data_format: A string specifying the data format of the input images. 17 | It can be either `"channels_last"` or `"channels_first"`. 18 | If not specified, the value will be interpreted by 19 | `keras.config.image_data_format`. Defaults to `None`. 
20 | """ 21 | 22 | def __init__(self, p: float = 0.5, data_format=None, **kwargs): 23 | super().__init__(**kwargs) 24 | self.p = float(p) 25 | self.data_format = data_format or keras.config.image_data_format() 26 | 27 | def get_params(self, batch_size, images=None, **kwargs): 28 | ops = self.backend 29 | random_generator = self.random_generator 30 | p = ops.random.uniform([batch_size], seed=random_generator) 31 | return p 32 | 33 | def compute_output_shape(self, input_shape): 34 | return input_shape 35 | 36 | def augment_images(self, images, transformations=None, **kwargs): 37 | ops = self.backend 38 | p = transformations 39 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 40 | images = ops.numpy.where( 41 | prob, 42 | self.image_backend.rgb_to_grayscale( 43 | images, data_format=self.data_format 44 | ), 45 | images, 46 | ) 47 | return images 48 | 49 | def augment_labels(self, labels, transformations, **kwargs): 50 | return labels 51 | 52 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 53 | return bounding_boxes 54 | 55 | def augment_segmentation_masks( 56 | self, segmentation_masks, transformations, **kwargs 57 | ): 58 | return segmentation_masks 59 | 60 | def augment_keypoints(self, keypoints, transformations, **kwargs): 61 | return keypoints 62 | 63 | def get_config(self): 64 | config = super().get_config() 65 | config.update({"p": self.p}) 66 | return config 67 | -------------------------------------------------------------------------------- /keras_aug/_src/backend/dynamic_backend.py: -------------------------------------------------------------------------------- 1 | from keras import backend 2 | from keras import random 3 | 4 | 5 | class DynamicBackend: 6 | def __init__(self, name=None): 7 | if name is not None and not isinstance(name, str): 8 | raise TypeError 9 | self._name = name 10 | 11 | # Variable 12 | self._backend = None 13 | 14 | # Init 15 | self.set_backend(self._name, force=True) 16 | 17 | @property 18 | 
def name(self): 19 | return self._name 20 | 21 | @property 22 | def backend(self): 23 | return self._backend 24 | 25 | def set_backend(self, name=None, force=False): 26 | name = name or backend.backend() 27 | self._backend = get_backend(name) 28 | self._name = name 29 | 30 | def reset(self): 31 | self.set_backend() 32 | 33 | 34 | class DynamicRandomGenerator: 35 | def __init__(self, name=None, seed=None): 36 | if name is not None and not isinstance(name, str): 37 | raise TypeError 38 | self._name = name 39 | self._seed = seed 40 | 41 | # Variable 42 | self._cached_random_generator = {} 43 | 44 | # Init 45 | self.set_generator(self._name) 46 | 47 | @property 48 | def name(self): 49 | return self._name 50 | 51 | @property 52 | def random_generator(self): 53 | return self._cached_random_generator[self._name] 54 | 55 | def set_generator(self, name=None): 56 | name = name or backend.backend() 57 | if name in self._cached_random_generator: 58 | return 59 | self._cached_random_generator[name] = random.SeedGenerator( 60 | seed=self._seed, backend=get_backend(name) 61 | ) 62 | 63 | def reset(self): 64 | self.set_generator() 65 | 66 | 67 | def get_backend(name=None): 68 | name = name or backend.backend() 69 | if name == "tensorflow": 70 | import keras.src.backend.tensorflow as module 71 | elif name == "jax": 72 | import keras.src.backend.jax as module 73 | elif name == "torch": 74 | import keras.src.backend.torch as module 75 | elif name == "numpy": 76 | if backend.backend() == "numpy": 77 | import keras.src.backend as module 78 | else: 79 | raise NotImplementedError( 80 | "Currently, we cannot dynamically import the numpy backend " 81 | "because it would disrupt the namespace of the import." 
82 | ) 83 | else: 84 | raise NotImplementedError 85 | return module 86 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_posterize.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class RandomPosterize(VisionRandomLayer): 10 | """Posterize the input images with a given probability. 11 | 12 | Posterization reduces the number of bits for each color channel. 13 | 14 | Args: 15 | bits: The number of bits to keep for each channel (0-8). 16 | p: A float specifying the probability. Defaults to `0.5`. 17 | data_format: A string specifying the data format of the input images. 18 | It can be either `"channels_last"` or `"channels_first"`. 19 | If not specified, the value will be interpreted by 20 | `keras.config.image_data_format`. Defaults to `None`. 
21 | """ 22 | 23 | def __init__(self, bits: int, p: float = 0.5, data_format=None, **kwargs): 24 | super().__init__(**kwargs) 25 | self.bits = int(bits) 26 | self.p = float(p) 27 | self.data_format = data_format or keras.config.image_data_format() 28 | 29 | def compute_output_shape(self, input_shape): 30 | return input_shape 31 | 32 | def get_params(self, batch_size, images=None, **kwargs): 33 | ops = self.backend 34 | random_generator = self.random_generator 35 | p = ops.random.uniform([batch_size], seed=random_generator) 36 | return p 37 | 38 | def augment_images(self, images, transformations=None, **kwargs): 39 | ops = self.backend 40 | p = transformations 41 | 42 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 43 | images = ops.numpy.where( 44 | prob, self.image_backend.posterize(images, self.bits), images 45 | ) 46 | return images 47 | 48 | def augment_labels(self, labels, transformations, **kwargs): 49 | return labels 50 | 51 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 52 | return bounding_boxes 53 | 54 | def augment_segmentation_masks( 55 | self, segmentation_masks, transformations, **kwargs 56 | ): 57 | return segmentation_masks 58 | 59 | def augment_keypoints(self, keypoints, transformations, **kwargs): 60 | return keypoints 61 | 62 | def get_config(self): 63 | config = super().get_config() 64 | config.update({"bits": self.bits, "p": self.p}) 65 | return config 66 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/rescale.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras import backend 3 | 4 | from keras_aug._src.keras_aug_export import keras_aug_export 5 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 6 | 7 | 8 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 9 | @keras.saving.register_keras_serializable(package="keras_aug") 10 | class 
Rescale(VisionRandomLayer): 11 | """Rescales the values of the images to a new range 12 | 13 | The rescaling equation: `y = x * scale + offset`. 14 | 15 | Args: 16 | scale: The scale to apply to the images. 17 | offset: The offset to apply to the images. Defaults to `0.0` 18 | """ 19 | 20 | def __init__(self, scale: float, offset: float = 0.0, **kwargs): 21 | super().__init__(has_generator=False, **kwargs) 22 | self.scale = float(scale) 23 | self.offset = float(offset) 24 | 25 | if not backend.is_float_dtype(self.compute_dtype): 26 | dtype = self.dtype_policy 27 | raise ValueError( 28 | f"The `dtype` of '{self.__class__.__name__}' must be float. " 29 | f"Received: dtype={dtype}" 30 | ) 31 | 32 | def compute_output_shape(self, input_shape): 33 | return input_shape 34 | 35 | def augment_images(self, images, transformations, **kwargs): 36 | ops = self.backend 37 | original_dtype = backend.standardize_dtype(images.dtype) 38 | images = self.image_backend.transform_dtype( 39 | images, images.dtype, backend.result_type(images.dtype, float) 40 | ) 41 | scale = ops.convert_to_tensor(self.scale, images.dtype) 42 | offset = ops.convert_to_tensor(self.offset, images.dtype) 43 | images = ops.numpy.add(ops.numpy.multiply(images, scale), offset) 44 | images = self.image_backend.transform_dtype( 45 | images, images.dtype, original_dtype 46 | ) 47 | return images 48 | 49 | def augment_labels(self, labels, transformations, **kwargs): 50 | return labels 51 | 52 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 53 | return bounding_boxes 54 | 55 | def augment_segmentation_masks( 56 | self, segmentation_masks, transformations, **kwargs 57 | ): 58 | return segmentation_masks 59 | 60 | def augment_keypoints(self, keypoints, transformations, **kwargs): 61 | return keypoints 62 | 63 | def get_config(self): 64 | config = super().get_config() 65 | config.update({"scale": self.scale, "offset": self.offset}) 66 | return config 67 | 
-------------------------------------------------------------------------------- /docs/generate_semantic_segmentation_gif.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | from PIL import Image 4 | 5 | from keras_aug import layers as ka_layers 6 | from keras_aug import visualization 7 | 8 | size = (320, 320) 9 | mosaic_size = (640, 640) 10 | 11 | 12 | def load_oxford(name, split, shuffle, batch_size, position): 13 | def unpack_oxford_inputs(x): 14 | segmentation_masks = tf.cast(x["segmentation_mask"], "int8") 15 | segmentation_masks = tf.where( 16 | tf.equal(segmentation_masks, 2), # Background index 17 | tf.constant(-1, dtype=segmentation_masks.dtype), 18 | segmentation_masks, 19 | ) 20 | return { 21 | "images": x["image"], 22 | "segmentation_masks": segmentation_masks, 23 | } 24 | 25 | ds = tfds.load(name, split=split, with_info=False, shuffle_files=shuffle) 26 | ds: tf.data.Dataset = ds.map(lambda x: unpack_oxford_inputs(x)) 27 | ds = ds.shuffle(128, reshuffle_each_iteration=True) 28 | ds = ds.map( 29 | ka_layers.vision.Resize(size[0], along_long_edge=True, dtype="uint8") 30 | ) 31 | ds = ds.map( 32 | ka_layers.vision.Pad( 33 | size, padding_position=position, padding_value=114, dtype="uint8" 34 | ) 35 | ) 36 | ds = ds.batch(batch_size) 37 | return ds 38 | 39 | 40 | args = dict(name="oxford_iiit_pet", split="train", shuffle=True, batch_size=16) 41 | ds_tl = load_oxford(**args, position="top_left") 42 | ds_tr = load_oxford(**args, position="top_right") 43 | ds_bl = load_oxford(**args, position="bottom_left") 44 | ds_br = load_oxford(**args, position="bottom_right") 45 | ds = tf.data.Dataset.zip(ds_tl, ds_tr, ds_bl, ds_br) 46 | 47 | # Augment 48 | ds = ds.map( 49 | ka_layers.vision.Mosaic( 50 | mosaic_size, offset=(0.25, 0.75), padding_value=114, dtype="uint8" 51 | ) 52 | ) 53 | ds = ds.map( 54 | ka_layers.vision.RandomAffine( 55 | translate=0.05, scale=0.25, 
padding_value=114, dtype="uint8" 56 | ) 57 | ) 58 | ds = ds.map(ka_layers.vision.CenterCrop(size, dtype="uint8")) 59 | ds = ds.map(ka_layers.vision.RandomGrayscale(p=0.01)) 60 | ds = ds.map(ka_layers.vision.RandomHSV(hue=0.015, saturation=0.7, value=0.4)) 61 | ds = ds.map(ka_layers.vision.RandomFlip(mode="horizontal")) 62 | 63 | # Make gif 64 | images = [] 65 | for x in ds.take(1): 66 | drawed_images = visualization.draw_segmentation_masks( 67 | x["images"], x["segmentation_masks"], num_classes=2, alpha=0.5 68 | ) 69 | for i in range(drawed_images.shape[0]): 70 | images.append(Image.fromarray(drawed_images[i])) 71 | images[0].save( 72 | "output.gif", 73 | save_all=True, 74 | append_images=images[1:10], 75 | optimize=False, 76 | duration=1000, 77 | loop=0, 78 | ) 79 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_auto_contrast_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras.src.testing.test_utils import named_product 5 | 6 | from keras_aug._src.layers.vision.random_auto_contrast import RandomAutoContrast 7 | from keras_aug._src.testing.test_case import TestCase 8 | from keras_aug._src.utils.test_utils import get_images 9 | 10 | 11 | class RandomAutoContrastTest(TestCase): 12 | @parameterized.named_parameters( 13 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 14 | ) 15 | def test_correctness(self, dtype): 16 | import torchvision.transforms.v2.functional as TF 17 | from keras.src.backend.torch import convert_to_tensor 18 | 19 | if dtype == "uint8": 20 | atol = 1 21 | else: 22 | atol = 1e-6 23 | np.random.seed(42) 24 | 25 | x = get_images(dtype, "channels_first") 26 | layer = RandomAutoContrast( 27 | p=1.0, dtype=dtype, data_format="channels_first" 28 | ) 29 | y = layer(x) 30 | 31 | ref_y = TF.autocontrast(convert_to_tensor(x)) 32 | 
self.assertDType(y, dtype) 33 | self.assertAllClose(y, ref_y, atol=atol) 34 | 35 | # Test p=0.0 36 | x = get_images(dtype, "channels_last") 37 | layer = RandomAutoContrast(p=0.0, dtype=dtype) 38 | y = layer(x) 39 | self.assertDType(y, dtype) 40 | self.assertAllClose(y, x) 41 | 42 | def test_shape(self): 43 | # Test dynamic shape 44 | x = keras.KerasTensor((None, None, None, 3)) 45 | y = RandomAutoContrast()(x) 46 | self.assertEqual(y.shape, (None, None, None, 3)) 47 | 48 | # Test static shape 49 | x = keras.KerasTensor((None, 32, 32, 3)) 50 | y = RandomAutoContrast()(x) 51 | self.assertEqual(y.shape, (None, 32, 32, 3)) 52 | 53 | def test_model(self): 54 | layer = RandomAutoContrast() 55 | inputs = keras.layers.Input(shape=[None, None, 5]) 56 | outputs = layer(inputs) 57 | model = keras.models.Model(inputs, outputs) 58 | self.assertEqual(model.output_shape, (None, None, None, 5)) 59 | 60 | def test_config(self): 61 | x = get_images("float32", "channels_last") 62 | layer = RandomAutoContrast(p=1.0) 63 | y = layer(x) 64 | 65 | layer = RandomAutoContrast.from_config(layer.get_config()) 66 | y2 = layer(x) 67 | self.assertAllClose(y, y2) 68 | 69 | def test_tf_data_compatibility(self): 70 | import tensorflow as tf 71 | 72 | layer = RandomAutoContrast() 73 | x = get_images("float32", "channels_last") 74 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 75 | for output in ds.take(1): 76 | self.assertIsInstance(output, tf.Tensor) 77 | self.assertEqual(output.shape, (2, 32, 32, 3)) 78 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/gaussian_noise.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras import backend 3 | 4 | from keras_aug._src.keras_aug_export import keras_aug_export 5 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 6 | 7 | 8 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 9 | 
@keras.saving.register_keras_serializable(package="keras_aug") 10 | class GaussianNoise(VisionRandomLayer): 11 | """Add gaussian noise to the input images. 12 | 13 | Args: 14 | mean: Mean of the sampled normal distribution. Defaults to `0.0`. 15 | sigma: Standard deviation of the sampled normal distribution. Defaults 16 | to `0.1`. 17 | clip: Whether to clip the values in `[0, 1]`. Defaults to `True`. 18 | """ 19 | 20 | def __init__( 21 | self, mean: float = 0.0, sigma: float = 0.1, clip: bool = True, **kwargs 22 | ): 23 | super().__init__(**kwargs) 24 | self.mean = float(mean) 25 | self.sigma = float(sigma) 26 | self.clip = bool(clip) 27 | 28 | def compute_output_shape(self, input_shape): 29 | return input_shape 30 | 31 | def get_params(self, batch_size, images=None, **kwargs): 32 | ops = self.backend 33 | random_generator = self.random_generator 34 | 35 | dtype = backend.result_type(images.dtype, float) 36 | noise = ( 37 | self.mean 38 | + ops.random.normal( 39 | ops.shape(images), dtype=dtype, seed=random_generator 40 | ) 41 | * self.sigma 42 | ) 43 | return noise 44 | 45 | def augment_images(self, images, transformations, **kwargs): 46 | ops = self.backend 47 | original_dtype = backend.standardize_dtype(images.dtype) 48 | noise = transformations 49 | 50 | images = self.image_backend.transform_dtype( 51 | images, images.dtype, backend.result_type(images.dtype, float) 52 | ) 53 | images = ops.numpy.add(images, noise) 54 | if self.clip: 55 | images = ops.numpy.clip(images, 0, 1) 56 | images = self.image_backend.transform_dtype( 57 | images, images.dtype, original_dtype 58 | ) 59 | return images 60 | 61 | def augment_labels(self, labels, transformations, **kwargs): 62 | return labels 63 | 64 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 65 | return bounding_boxes 66 | 67 | def augment_segmentation_masks( 68 | self, segmentation_masks, transformations, **kwargs 69 | ): 70 | return segmentation_masks 71 | 72 | def 
augment_keypoints(self, keypoints, transformations, **kwargs): 73 | return keypoints 74 | 75 | def get_config(self): 76 | config = super().get_config() 77 | config.update( 78 | {"mean": self.mean, "sigma": self.sigma, "clip": self.clip} 79 | ) 80 | return config 81 | -------------------------------------------------------------------------------- /keras_aug/_src/visualization/draw_segmentation_masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend 3 | from keras import ops 4 | 5 | from keras_aug._src import ops as ka_ops 6 | from keras_aug._src.keras_aug_export import keras_aug_export 7 | 8 | 9 | @keras_aug_export(parent_path=["keras_aug.visualization"]) 10 | def draw_segmentation_masks( 11 | images, 12 | segmentation_masks, 13 | num_classes=None, 14 | color_mapping=None, 15 | alpha=0.8, 16 | ignore_index=-1, 17 | data_format=None, 18 | ): 19 | data_format = data_format or backend.image_data_format() 20 | images_shape = ops.shape(images) 21 | if len(images_shape) != 4: 22 | raise ValueError( 23 | "`images` must be batched 4D tensor. " 24 | f"Received: images.shape={images_shape}" 25 | ) 26 | images = ops.convert_to_tensor(images) 27 | images = ka_ops.image.transform_dtype(images, images.dtype, "float32") 28 | segmentation_masks = ops.convert_to_tensor(segmentation_masks) 29 | 30 | if not backend.is_int_dtype(segmentation_masks.dtype): 31 | dtype = backend.standardize_dtype(segmentation_masks.dtype) 32 | raise TypeError( 33 | "`segmentation_masks` must be in integer dtype. 
" 34 | f"Received: segmentation_masks.dtype={dtype}" 35 | ) 36 | 37 | # Infer num_classes 38 | if num_classes is None: 39 | num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks))) 40 | if color_mapping is None: 41 | colors = _generate_color_palette(num_classes) 42 | else: 43 | colors = [color_mapping[i] for i in range(num_classes)] 44 | valid_masks = ops.not_equal(segmentation_masks, ignore_index) 45 | valid_masks = ops.squeeze(valid_masks, axis=-1) 46 | segmentation_masks = ops.nn.one_hot(segmentation_masks, num_classes) 47 | segmentation_masks = segmentation_masks[..., 0, :] 48 | segmentation_masks = ops.convert_to_numpy(segmentation_masks) 49 | 50 | # Replace class with color 51 | masks = segmentation_masks 52 | masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool") 53 | images_to_draw = ops.convert_to_numpy(images).copy() 54 | for mask, color in zip(masks, colors): 55 | color = np.array(color, dtype=images_to_draw.dtype) 56 | images_to_draw[mask, ...] = color[None, :] 57 | images_to_draw = ops.convert_to_tensor(images_to_draw) 58 | images_to_draw = ka_ops.image.transform_dtype( 59 | images_to_draw, "uint8", "float32" 60 | ) 61 | 62 | # Apply blending 63 | outputs = images * (1 - alpha) + images_to_draw * alpha 64 | outputs = ops.where(valid_masks[..., None], outputs, images) 65 | outputs = ka_ops.image.transform_dtype(outputs, "float32", "uint8") 66 | outputs = ops.convert_to_numpy(outputs) 67 | return outputs 68 | 69 | 70 | def _generate_color_palette(num_classes: int): 71 | palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) 72 | return [((i * palette) % 255).tolist() for i in range(num_classes)] 73 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_invert_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras.src.testing.test_utils import 
named_product 5 | 6 | from keras_aug._src.layers.vision.random_invert import RandomInvert 7 | from keras_aug._src.testing.test_case import TestCase 8 | from keras_aug._src.utils.test_utils import get_images 9 | 10 | 11 | class RandomInvertTest(TestCase): 12 | @parameterized.named_parameters( 13 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 14 | ) 15 | def test_correctness(self, dtype): 16 | import torch 17 | import torchvision.transforms.v2.functional as TF 18 | from keras.src.backend.torch import convert_to_tensor 19 | 20 | np.random.seed(42) 21 | 22 | # Test channels_last 23 | x = get_images(dtype, "channels_last") 24 | layer = RandomInvert(p=1.0, dtype=dtype) 25 | y = layer(x) 26 | 27 | ref_y = TF.invert(convert_to_tensor(np.transpose(x, [0, 3, 1, 2]))) 28 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 29 | self.assertDType(y, dtype) 30 | self.assertAllClose(y, ref_y) 31 | 32 | # Test channels_first 33 | x = get_images(dtype, "channels_first") 34 | layer = RandomInvert(p=1.0, dtype=dtype) 35 | y = layer(x) 36 | 37 | ref_y = TF.invert(convert_to_tensor(x)) 38 | self.assertDType(y, dtype) 39 | self.assertAllClose(y, ref_y) 40 | 41 | # Test p=0.0 42 | x = get_images(dtype, "channels_last") 43 | layer = RandomInvert(p=0.0, dtype=dtype) 44 | y = layer(x) 45 | 46 | self.assertDType(y, dtype) 47 | self.assertAllClose(y, x) 48 | 49 | def test_shape(self): 50 | # Test dynamic shape 51 | x = keras.KerasTensor((None, None, None, 3)) 52 | y = RandomInvert()(x) 53 | self.assertEqual(y.shape, (None, None, None, 3)) 54 | 55 | # Test static shape 56 | x = keras.KerasTensor((None, 32, 32, 3)) 57 | y = RandomInvert()(x) 58 | self.assertEqual(y.shape, (None, 32, 32, 3)) 59 | 60 | def test_model(self): 61 | layer = RandomInvert() 62 | inputs = keras.layers.Input(shape=(None, None, 3)) 63 | outputs = layer(inputs) 64 | model = keras.models.Model(inputs, outputs) 65 | self.assertEqual(model.output_shape, (None, None, None, 3)) 66 | 67 | def test_config(self): 68 | x = 
get_images("float32", "channels_last") 69 | layer = RandomInvert(p=1.0) 70 | y = layer(x) 71 | 72 | layer = RandomInvert.from_config(layer.get_config()) 73 | y2 = layer(x) 74 | self.assertAllClose(y, y2) 75 | 76 | def test_tf_data_compatibility(self): 77 | import tensorflow as tf 78 | 79 | layer = RandomInvert() 80 | x = get_images("float32", "channels_last") 81 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 82 | for output in ds.take(1): 83 | self.assertIsInstance(output, tf.Tensor) 84 | self.assertEqual(output.shape, (2, 32, 32, 3)) 85 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_channel_permutation.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class RandomChannelPermutation(VisionRandomLayer): 10 | """Randomly permute the channels of the input images. 11 | 12 | Args: 13 | num_channels: The number of channels to permute. 14 | data_format: A string specifying the data format of the input images. 15 | It can be either `"channels_last"` or `"channels_first"`. 16 | If not specified, the value will be interpreted by 17 | `keras.config.image_data_format`. Defaults to `None`. 
18 | """ 19 | 20 | def __init__(self, num_channels: int, data_format=None, **kwargs): 21 | super().__init__(**kwargs) 22 | self.num_channels = int(num_channels) 23 | self.data_format = data_format or keras.config.image_data_format() 24 | 25 | self.channels_axis = -1 if self.data_format == "channels_last" else -3 26 | 27 | def get_params(self, batch_size, images=None, **kwargs): 28 | ops = self.backend 29 | random_generator = self.random_generator 30 | perm = ops.random.uniform( 31 | [batch_size, self.num_channels], seed=random_generator 32 | ) 33 | perm = ops.numpy.argsort(perm, axis=-1) 34 | return perm 35 | 36 | def compute_output_shape(self, input_shape): 37 | images_shape, _ = self._get_shape_or_spec(input_shape) 38 | if images_shape[self.channels_axis] != self.num_channels: 39 | raise ValueError( 40 | "`num_channels` must match the channels of the input images. " 41 | f"Received: images.shape={images_shape}, " 42 | f"num_channels={self.num_channels}" 43 | ) 44 | return input_shape 45 | 46 | def augment_images(self, images, transformations=None, **kwargs): 47 | ops = self.backend 48 | perm = transformations 49 | if self.data_format == "channels_last": 50 | perm = ops.numpy.expand_dims(perm, axis=[1, 2]) 51 | else: 52 | perm = ops.numpy.expand_dims(perm, axis=[2, 3]) 53 | images = ops.numpy.take_along_axis( 54 | images, perm, axis=self.channels_axis 55 | ) 56 | return images 57 | 58 | def augment_labels(self, labels, transformations, **kwargs): 59 | return labels 60 | 61 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 62 | return bounding_boxes 63 | 64 | def augment_segmentation_masks( 65 | self, segmentation_masks, transformations, **kwargs 66 | ): 67 | return segmentation_masks 68 | 69 | def augment_keypoints(self, keypoints, transformations, **kwargs): 70 | return keypoints 71 | 72 | def get_config(self): 73 | config = super().get_config() 74 | config.update({"num_channels": self.num_channels}) 75 | return config 76 | 
-------------------------------------------------------------------------------- /guides/oxford_yolov8_aug.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import keras 3 | import tensorflow as tf 4 | import tensorflow_datasets as tfds 5 | 6 | from keras_aug import layers as ka_layers 7 | from keras_aug import visualization 8 | 9 | 10 | def load_oxford(name, split, shuffle, batch_size, position): 11 | def unpack_oxford_inputs(x): 12 | segmentation_masks = tf.cast(x["segmentation_mask"], "int8") 13 | segmentation_masks = tf.where( 14 | tf.equal(segmentation_masks, 2), # Background index 15 | tf.constant(-1, dtype=segmentation_masks.dtype), 16 | segmentation_masks, 17 | ) 18 | return { 19 | "images": x["image"], 20 | "segmentation_masks": segmentation_masks, 21 | } 22 | 23 | ds = tfds.load(name, split=split, with_info=False, shuffle_files=shuffle) 24 | ds: tf.data.Dataset = ds.map(lambda x: unpack_oxford_inputs(x)) 25 | ds = ds.shuffle(128, reshuffle_each_iteration=True) 26 | 27 | # You can utilize KerasAug's layers in `tf.data` pipeline. 28 | # The layer will automatically switch to the TensorFlow backend to be 29 | # compatible with `tf.data`. 
30 | ds = ds.map( 31 | ka_layers.vision.Resize( 32 | 640, along_long_edge=True, bounding_box_format="xyxy", dtype="uint8" 33 | ) 34 | ) 35 | ds = ds.map( 36 | ka_layers.vision.Pad( 37 | (640, 640), 38 | padding_position=position, 39 | padding_value=114, 40 | bounding_box_format="xyxy", 41 | dtype="uint8", 42 | ) 43 | ) 44 | ds = ds.batch(batch_size) 45 | return ds 46 | 47 | 48 | args = dict(name="oxford_iiit_pet", split="train", shuffle=True, batch_size=16) 49 | ds_tl = load_oxford(**args, position="top_left") 50 | ds_tr = load_oxford(**args, position="top_right") 51 | ds_bl = load_oxford(**args, position="bottom_left") 52 | ds_br = load_oxford(**args, position="bottom_right") 53 | ds = tf.data.Dataset.zip(ds_tl, ds_tr, ds_bl, ds_br) 54 | ds = ds.map( 55 | ka_layers.vision.Mosaic( 56 | (1280, 1280), offset=(0.25, 0.75), padding_value=114, dtype="uint8" 57 | ) 58 | ) 59 | 60 | # You can also utilize KerasAug's layers in a typical Keras manner. 61 | # `augmenter`` will be called just like a regular Keras model, benefiting from 62 | # accelerator (such as GPU & TPU) and compilation. 
63 | augmenter = keras.Sequential( 64 | [ 65 | ka_layers.vision.RandomAffine( 66 | translate=0.05, scale=0.25, padding_value=114, dtype="uint8" 67 | ), 68 | ka_layers.vision.CenterCrop((640, 640), dtype="uint8"), 69 | ka_layers.vision.RandomGrayscale(p=0.01), 70 | ka_layers.vision.RandomHSV(hue=0.015, saturation=0.7, value=0.4), 71 | ka_layers.vision.RandomFlip(mode="horizontal"), 72 | ] 73 | ) 74 | 75 | for x in ds.take(1): 76 | x = augmenter(x) 77 | drawed_images = visualization.draw_segmentation_masks( 78 | x["images"], x["segmentation_masks"], num_classes=2 79 | ) 80 | for i_d in range(drawed_images.shape[0]): 81 | output_path = f"output_{i_d}.jpg" 82 | output_image = cv2.cvtColor(drawed_images[i_d], cv2.COLOR_RGB2BGR) 83 | cv2.imwrite(output_path, output_image) 84 | -------------------------------------------------------------------------------- /keras_aug/_src/keras_aug_export.py: -------------------------------------------------------------------------------- 1 | try: 2 | import namex 3 | except ImportError: 4 | namex = None 5 | 6 | # These dicts reference "canonical names" only 7 | # (i.e. the first name an object was registered with). 
8 | REGISTERED_NAMES_TO_OBJS = {} 9 | REGISTERED_OBJS_TO_NAMES = {} 10 | 11 | 12 | def register_internal_serializable(path, symbol): 13 | global REGISTERED_NAMES_TO_OBJS 14 | if isinstance(path, (list, tuple)): 15 | name = path[0] 16 | else: 17 | name = path 18 | REGISTERED_NAMES_TO_OBJS[name] = symbol 19 | REGISTERED_OBJS_TO_NAMES[symbol] = name 20 | 21 | 22 | def get_symbol_from_name(name): 23 | return REGISTERED_NAMES_TO_OBJS.get(name, None) 24 | 25 | 26 | def get_name_from_symbol(symbol): 27 | return REGISTERED_OBJS_TO_NAMES.get(symbol, None) 28 | 29 | 30 | if namex: 31 | 32 | class keras_aug_export: 33 | def __init__(self, parent_path): 34 | package = "keras_aug" 35 | 36 | if isinstance(parent_path, str): 37 | export_paths = [parent_path] 38 | elif isinstance(parent_path, list): 39 | export_paths = parent_path 40 | else: 41 | raise ValueError( 42 | f"Invalid type for `parent_path` argument: " 43 | f"Received '{parent_path}' " 44 | f"of type {type(parent_path)}" 45 | ) 46 | for p in export_paths: 47 | if not p.startswith(package): 48 | raise ValueError( 49 | "All `export_path` values should start with " 50 | f"'{package}.'. Received: parent_path={parent_path}" 51 | ) 52 | self.package = package 53 | self.parent_path = parent_path 54 | 55 | def __call__(self, symbol): 56 | if hasattr(symbol, "_api_export_path") and ( 57 | symbol._api_export_symbol_id == id(symbol) 58 | ): 59 | raise ValueError( 60 | f"Symbol {symbol} is already exported as " 61 | f"'{symbol._api_export_path}'. " 62 | f"Cannot also export it to '{self.parent_path}'." 
63 | ) 64 | if isinstance(self.parent_path, list): 65 | path = [p + f".{symbol.__name__}" for p in self.parent_path] 66 | elif isinstance(self.parent_path, str): 67 | path = self.parent_path + f".{symbol.__name__}" 68 | symbol._api_export_path = path 69 | symbol._api_export_symbol_id = id(symbol) 70 | 71 | register_internal_serializable(path, symbol) 72 | return symbol 73 | 74 | else: 75 | 76 | class keras_aug_export: 77 | def __init__(self, parent_path): 78 | self.parent_path = parent_path 79 | 80 | def __call__(self, symbol): 81 | if isinstance(self.parent_path, list): 82 | path = [p + f".{symbol.__name__}" for p in self.parent_path] 83 | elif isinstance(self.parent_path, str): 84 | path = self.parent_path + f".{symbol.__name__}" 85 | 86 | register_internal_serializable(path, symbol) 87 | return symbol 88 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/gaussian_blur_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras.src.testing.test_utils import named_product 5 | 6 | from keras_aug._src.layers.vision.gaussian_blur import GaussianBlur 7 | from keras_aug._src.testing.test_case import TestCase 8 | from keras_aug._src.utils.test_utils import get_images 9 | 10 | 11 | class FixedGaussianBlur(GaussianBlur): 12 | def get_params(self, batch_size, images=None, **kwargs): 13 | ops = self.backend 14 | compute_dtype = keras.backend.result_type(self.compute_dtype, float) 15 | sigma = ops.numpy.ones((), dtype=compute_dtype) * 0.1 16 | return sigma 17 | 18 | 19 | class GaussianBlurTest(TestCase): 20 | @parameterized.named_parameters( 21 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 22 | ) 23 | def test_correctness(self, dtype): 24 | import torch 25 | import torchvision.transforms.v2.functional as TF 26 | from keras.src.backend.torch import convert_to_tensor 27 | 28 
| # Test channels_last 29 | x = get_images(dtype, "channels_last") 30 | layer = FixedGaussianBlur(3, dtype=dtype) 31 | y = layer(x) 32 | 33 | ref_y = TF.gaussian_blur( 34 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), (3, 3), (0.1, 0.1) 35 | ) 36 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 37 | self.assertDType(y, dtype) 38 | self.assertAllClose(y, ref_y) 39 | 40 | # Test channels_first 41 | x = get_images(dtype, "channels_first") 42 | layer = FixedGaussianBlur(3, dtype=dtype) 43 | y = layer(x) 44 | 45 | ref_y = TF.gaussian_blur(convert_to_tensor(x), (3, 3), (0.1, 0.1)) 46 | self.assertDType(y, dtype) 47 | self.assertAllClose(y, ref_y) 48 | 49 | def test_shape(self): 50 | # Test dynamic shape 51 | x = keras.KerasTensor((None, None, None, 3)) 52 | y = GaussianBlur(3)(x) 53 | self.assertEqual(y.shape, (None, None, None, 3)) 54 | 55 | # Test static shape 56 | x = keras.KerasTensor((None, 32, 32, 3)) 57 | y = GaussianBlur(3)(x) 58 | self.assertEqual(y.shape, (None, 32, 32, 3)) 59 | 60 | def test_model(self): 61 | layer = GaussianBlur(3) 62 | inputs = keras.layers.Input(shape=(None, None, 3)) 63 | outputs = layer(inputs) 64 | model = keras.models.Model(inputs, outputs) 65 | self.assertEqual(model.output_shape, (None, None, None, 3)) 66 | 67 | def test_config(self): 68 | x = get_images("float32", "channels_last") 69 | layer = FixedGaussianBlur(3) 70 | y = layer(x) 71 | 72 | layer = FixedGaussianBlur.from_config(layer.get_config()) 73 | y2 = layer(x) 74 | self.assertAllClose(y, y2) 75 | 76 | def test_tf_data_compatibility(self): 77 | import tensorflow as tf 78 | 79 | layer = GaussianBlur(3) 80 | x = get_images("float32", "channels_last") 81 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 82 | for output in ds.take(1): 83 | self.assertIsInstance(output, tf.Tensor) 84 | self.assertEqual(output.shape, (2, 32, 32, 3)) 85 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_sharpen.py: 
-------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class RandomSharpen(VisionRandomLayer): 10 | """Adjust the sharpness of the input images with a given probability. 11 | 12 | Args: 13 | sharpness_factor: How much to adjust the sharpness. Can be any 14 | non-negative number. 0 gives a blurred image, 1 gives the 15 | original image while 2 increases the sharpness by a factor of 2. 16 | p: A float specifying the probability. Defaults to `0.5`. 17 | data_format: A string specifying the data format of the input images. 18 | It can be either `"channels_last"` or `"channels_first"`. 19 | If not specified, the value will be interpreted by 20 | `keras.config.image_data_format`. Defaults to `None`. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | sharpness_factor: float, 26 | p: float = 0.5, 27 | data_format=None, 28 | **kwargs, 29 | ): 30 | super().__init__(**kwargs) 31 | self.sharpness_factor = float(sharpness_factor) 32 | self.p = float(p) 33 | self.data_format = data_format or keras.config.image_data_format() 34 | 35 | if self.sharpness_factor < 0: 36 | raise ValueError( 37 | "`sharpness_factor` should be a non-negative number. 
" 38 | f"Received: sharpness_factor={sharpness_factor}" 39 | ) 40 | 41 | def compute_output_shape(self, input_shape): 42 | return input_shape 43 | 44 | def get_params(self, batch_size, images=None, **kwargs): 45 | ops = self.backend 46 | random_generator = self.random_generator 47 | p = ops.random.uniform([batch_size], seed=random_generator) 48 | return p 49 | 50 | def augment_images(self, images, transformations=None, **kwargs): 51 | ops = self.backend 52 | p = transformations 53 | 54 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 55 | images = ops.numpy.where( 56 | prob, 57 | self.image_backend.sharpen( 58 | images, 59 | ops.convert_to_tensor([self.sharpness_factor]), 60 | self.data_format, 61 | ), 62 | images, 63 | ) 64 | return images 65 | 66 | def augment_labels(self, labels, transformations, **kwargs): 67 | return labels 68 | 69 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 70 | return bounding_boxes 71 | 72 | def augment_segmentation_masks( 73 | self, segmentation_masks, transformations, **kwargs 74 | ): 75 | return segmentation_masks 76 | 77 | def augment_keypoints(self, keypoints, transformations, **kwargs): 78 | return keypoints 79 | 80 | def get_config(self): 81 | config = super().get_config() 82 | config.update({"sharpness_factor": self.sharpness_factor, "p": self.p}) 83 | return config 84 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_apply_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | 4 | from keras_aug._src.layers.composition.random_apply import RandomApply 5 | from keras_aug._src.layers.vision.rand_augment import RandAugment 6 | from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale 7 | from keras_aug._src.layers.vision.resize import Resize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from 
keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomApplyTest(TestCase): 13 | def test_correctness(self): 14 | import torch 15 | import torchvision.transforms.v2.functional as TF 16 | from keras.src.backend.torch import convert_to_tensor 17 | 18 | layer = RandomApply(transforms=[RandomGrayscale(p=1.0)], p=1.0) 19 | 20 | x = get_images("float32", "channels_last") 21 | y = layer(x) 22 | 23 | ref_y = TF.rgb_to_grayscale( 24 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 25 | num_output_channels=3, 26 | ) 27 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 28 | self.assertAllClose(y, ref_y) 29 | 30 | # Test p=0.0 31 | layer = RandomApply(transforms=[RandomGrayscale(p=1.0)], p=0.0) 32 | y = layer(x) 33 | 34 | self.assertAllClose(y, x) 35 | 36 | def test_shape(self): 37 | layer = RandomApply(transforms=RandomGrayscale(p=1.0)) 38 | 39 | # Test dynamic shape 40 | x = keras.KerasTensor((None, None, None, 3)) 41 | y = layer(x) 42 | self.assertEqual(y.shape, (None, None, None, 3)) 43 | 44 | # Test static shape 45 | x = keras.KerasTensor((None, 32, 32, 3)) 46 | y = layer(x) 47 | self.assertEqual(y.shape, (None, 32, 32, 3)) 48 | 49 | # Test deterministic shape 50 | transform = Resize((16, 16)) 51 | layer = RandomApply(transforms=transform) 52 | x = keras.KerasTensor((None, 16, 16, 3)) 53 | y = layer(x) 54 | self.assertEqual(y.shape, (None, 16, 16, 3)) 55 | 56 | def test_model(self): 57 | layer = RandomApply(transforms=RandomGrayscale(p=1.0)) 58 | inputs = keras.layers.Input(shape=[None, None, 3]) 59 | outputs = layer(inputs) 60 | model = keras.models.Model(inputs, outputs) 61 | self.assertEqual(model.output_shape, (None, None, None, 3)) 62 | 63 | def test_config(self): 64 | x = get_images("float32", "channels_last") 65 | layer = RandomApply(transforms=RandomGrayscale(p=1.0), p=1.0) 66 | y = layer(x) 67 | 68 | layer = RandomApply.from_config(layer.get_config()) 69 | y2 = layer(x) 70 | self.assertAllClose(y, y2) 71 | 72 | def 
test_tf_data_compatibility(self): 73 | import tensorflow as tf 74 | 75 | def to_dict(x): 76 | return {"images": x, "labels": tf.convert_to_tensor([0, 1])} 77 | 78 | layer = RandomApply(transforms=[RandAugment()], p=0.5) 79 | x = get_images("float32", "channels_last") 80 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2) 81 | ds = ds.map(to_dict).map(layer) 82 | for output in ds.take(1): 83 | output = output["images"] 84 | self.assertIsInstance(output, tf.Tensor) 85 | self.assertEqual(output.shape, (2, 32, 32, 3)) 86 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/gaussian_noise_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras.src.testing.test_utils import named_product 5 | 6 | from keras_aug._src.layers.vision.gaussian_noise import GaussianNoise 7 | from keras_aug._src.testing.test_case import TestCase 8 | from keras_aug._src.utils.test_utils import get_images 9 | 10 | 11 | class FixedGaussianNoise(GaussianNoise): 12 | def get_params(self, batch_size, images=None, **kwargs): 13 | ops = self.backend 14 | compute_dtype = keras.backend.result_type(self.compute_dtype, float) 15 | noise = ops.numpy.ones(ops.shape(images), dtype=compute_dtype) * 0.5 16 | return noise 17 | 18 | 19 | class GaussianNoiseTest(TestCase): 20 | @parameterized.named_parameters( 21 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 22 | ) 23 | def test_correctness(self, dtype): 24 | atol = 1 if dtype == "uint8" else 1e-2 25 | rtol = 1 if dtype == "uint8" else 1e-2 26 | 27 | # Test channels_last 28 | x = get_images(dtype, "channels_last") 29 | layer = FixedGaussianNoise(dtype=dtype) 30 | y = layer(x) 31 | 32 | if dtype == "uint8": 33 | ref_y = np.clip(x.astype("float32") + 255 * 0.5, 0, 255) 34 | else: 35 | ref_y = np.clip(x + 0.5, 0, 1) 36 | self.assertDType(y, dtype) 37 | 
self.assertAllClose(y, ref_y, atol=atol, rtol=rtol) 38 | 39 | # Test channels_first 40 | x = get_images(dtype, "channels_first") 41 | layer = FixedGaussianNoise(dtype=dtype) 42 | y = layer(x) 43 | 44 | if dtype == "uint8": 45 | ref_y = np.clip(x.astype("float32") + 255 * 0.5, 0, 255) 46 | else: 47 | ref_y = np.clip(x + 0.5, 0, 1) 48 | self.assertDType(y, dtype) 49 | self.assertAllClose(y, ref_y, atol=atol, rtol=rtol) 50 | 51 | def test_shape(self): 52 | # Test dynamic shape 53 | x = keras.KerasTensor((None, None, None, 3)) 54 | y = GaussianNoise()(x) 55 | self.assertEqual(y.shape, (None, None, None, 3)) 56 | 57 | # Test static shape 58 | x = keras.KerasTensor((None, 32, 32, 3)) 59 | y = GaussianNoise()(x) 60 | self.assertEqual(y.shape, (None, 32, 32, 3)) 61 | 62 | def test_model(self): 63 | layer = GaussianNoise() 64 | inputs = keras.layers.Input(shape=(None, None, 3)) 65 | outputs = layer(inputs) 66 | model = keras.models.Model(inputs, outputs) 67 | self.assertEqual(model.output_shape, (None, None, None, 3)) 68 | 69 | def test_config(self): 70 | x = get_images("float32", "channels_last") 71 | layer = FixedGaussianNoise() 72 | y = layer(x) 73 | 74 | layer = FixedGaussianNoise.from_config(layer.get_config()) 75 | y2 = layer(x) 76 | self.assertAllClose(y, y2) 77 | 78 | def test_tf_data_compatibility(self): 79 | import tensorflow as tf 80 | 81 | layer = GaussianNoise() 82 | x = get_images("float32", "channels_last") 83 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 84 | for output in ds.take(1): 85 | self.assertIsInstance(output, tf.Tensor) 86 | self.assertEqual(output.shape, (2, 32, 32, 3)) 87 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/rand_augment_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from 
keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.rand_augment import RandAugment 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandAugmentTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | # TODO: Add assertAllClose test 18 | 19 | np.random.seed(42) 20 | 21 | # Test channels_last 22 | x = get_images(dtype, "channels_last") 23 | layer = RandAugment(dtype=dtype) 24 | y = layer(x) 25 | 26 | self.assertDType(y, dtype) 27 | 28 | # Test channels_first 29 | if backend.backend() == "tensorflow": 30 | # Some ops not supported by tensorflow CPU 31 | return 32 | backend.set_image_data_format("channels_first") 33 | x = get_images(dtype, "channels_first") 34 | layer = RandAugment(dtype=dtype) 35 | y = layer(x) 36 | 37 | self.assertDType(y, dtype) 38 | 39 | def test_shape(self): 40 | # Test dynamic shape 41 | x = keras.KerasTensor((None, None, None, 3)) 42 | y = RandAugment()(x) 43 | self.assertEqual(y.shape, (None, None, None, 3)) 44 | 45 | # Test static shape 46 | x = keras.KerasTensor((None, 32, 32, 3)) 47 | y = RandAugment()(x) 48 | self.assertEqual(y.shape, (None, 32, 32, 3)) 49 | 50 | def test_model(self): 51 | # Test dynamic shape 52 | layer = RandAugment() 53 | inputs = keras.layers.Input(shape=[None, None, 3]) 54 | outputs = layer(inputs) 55 | model = keras.models.Model(inputs, outputs) 56 | self.assertEqual(model.output_shape, (None, None, None, 3)) 57 | 58 | # Test static shape 59 | layer = RandAugment() 60 | inputs = keras.layers.Input(shape=[32, 32, 3]) 61 | outputs = layer(inputs) 62 | model = keras.models.Model(inputs, outputs) 63 | self.assertEqual(model.output_shape, (None, 32, 32, 3)) 64 | 65 | def test_config(self): 66 | x = get_images("float32", "channels_last") 67 | layer = RandAugment() 68 | y = layer(x) 69 
| 70 | layer = RandAugment.from_config(layer.get_config()) 71 | y2 = layer(x) 72 | self.assertEqual(y.shape, y2.shape) 73 | 74 | # Test `p=0.0` 75 | layer = RandAugment(p=0.0) 76 | y = layer(x) 77 | 78 | layer = RandAugment.from_config(layer.get_config()) 79 | y2 = layer(x) 80 | self.assertAllClose(y, x) 81 | self.assertAllClose(y2, x) 82 | self.assertEqual(y.shape, y2.shape) 83 | 84 | def test_tf_data_compatibility(self): 85 | import tensorflow as tf 86 | 87 | layer = RandAugment() 88 | x = get_images("float32", "channels_last") 89 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 90 | for output in ds.take(1): 91 | self.assertIsInstance(output, tf.Tensor) 92 | self.assertEqual(output.shape, (2, 32, 32, 3)) 93 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/identity_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import ml_dtypes 3 | import numpy as np 4 | from absl.testing import parameterized 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.identity import Identity 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class IdentityTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | bbox_dtype = ml_dtypes.bfloat16 if dtype == "mixed_bfloat16" else dtype 18 | x = get_images(dtype, "channels_last") 19 | layer = Identity(dtype=dtype) 20 | y = layer(x) 21 | self.assertAllClose(y, x) 22 | 23 | x = { 24 | "images": get_images(dtype, "channels_last"), 25 | "bounding_boxes": { 26 | "boxes": np.random.uniform(0, 1, (2, 10, 4)).astype(bbox_dtype), 27 | "classes": np.random.uniform(0, 1, (2, 10, 5)).astype( 28 | bbox_dtype 29 | ), 30 | }, 31 | "segmentation_masks": np.random.uniform( 32 | 
0, 9, (2, 32, 32, 1) 33 | ).astype("int32"), 34 | "keypoints": np.random.uniform(0, 1, (2, 10, 17)).astype( 35 | bbox_dtype 36 | ), 37 | } 38 | y = layer(x) 39 | self.assertDType(y["images"], dtype) 40 | self.assertAllClose(y["images"], x["images"]) 41 | self.assertAllClose( 42 | y["bounding_boxes"]["boxes"], x["bounding_boxes"]["boxes"] 43 | ) 44 | self.assertAllClose( 45 | y["bounding_boxes"]["classes"], x["bounding_boxes"]["classes"] 46 | ) 47 | self.assertAllClose(y["segmentation_masks"], x["segmentation_masks"]) 48 | self.assertAllClose(y["keypoints"], x["keypoints"]) 49 | 50 | def test_shape(self): 51 | # Test dynamic shape (unknown batch, height and width) 52 | x = keras.KerasTensor((None, None, None, 3)) 53 | y = Identity()(x) 54 | self.assertEqual(y.shape, (None, None, None, 3)) 55 | 56 | # Test static shape 57 | x = keras.KerasTensor((None, 32, 32, 3)) 58 | y = Identity()(x) 59 | self.assertEqual(y.shape, (None, 32, 32, 3)) 60 | 61 | def test_model(self): 62 | layer = Identity() 63 | inputs = keras.layers.Input(shape=(None, None, 3)) 64 | outputs = layer(inputs) 65 | model = keras.models.Model(inputs, outputs) 66 | self.assertEqual(model.output_shape, (None, None, None, 3)) 67 | 68 | def test_config(self): 69 | x = get_images("float32", "channels_last") 70 | layer = Identity() 71 | y = layer(x) 72 | 73 | layer = Identity.from_config(layer.get_config()) 74 | y2 = layer(x) 75 | self.assertAllClose(y, y2) 76 | 77 | def test_tf_data_compatibility(self): 78 | import tensorflow as tf 79 | 80 | layer = Identity() 81 | x = get_images("float32", "channels_last") 82 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 83 | for output in ds.take(1): 84 | self.assertIsInstance(output, tf.Tensor) 85 | self.assertEqual(output.shape, (2, 32, 32, 3)) 86 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 |
build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "keras-aug" 7 | description = "A library that includes Keras 3 preprocessing and augmentation layers" 8 | keywords = [ 9 | "deep-learning", 10 | "preprocessing", 11 | "augmentation", 12 | "keras", 13 | "jax", 14 | "tensorflow", 15 | "torch", 16 | ] 17 | authors = [{ name = "Hong-Yu Chiu", email = "james77777778@gmail.com" }] 18 | maintainers = [{ name = "Hong-Yu Chiu", email = "james77777778@gmail.com" }] 19 | readme = "README.md" 20 | requires-python = ">=3.9" 21 | license = { text = "Apache License 2.0" } 22 | classifiers = [ 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.9", 26 | "Programming Language :: Python :: 3.10", 27 | "Programming Language :: Python :: 3.11", 28 | "Programming Language :: Python :: 3.12", 29 | "Programming Language :: Python :: 3 :: Only", 30 | "Operating System :: Unix", 31 | "Operating System :: MacOS", 32 | "Intended Audience :: Science/Research", 33 | "Topic :: Scientific/Engineering", 34 | "Topic :: Software Development", 35 | ] 36 | dynamic = ["version"] 37 | dependencies = ["keras"] 38 | 39 | [project.urls] 40 | Homepage = "https://github.com/james77777778/keras-aug" 41 | Documentation = "https://github.com/james77777778/keras-aug" 42 | Repository = "https://github.com/james77777778/keras-aug.git" 43 | Issues = "https://github.com/james77777778/keras-aug/issues" 44 | 45 | [project.optional-dependencies] 46 | tests = [ 47 | # linters, formatters and test tools 48 | "isort", 49 | "ruff", 50 | "black", 51 | "pytest", 52 | "pytest-cov", 53 | "coverage", 54 | # development tools 55 | "pre-commit", 56 | "namex", 57 | ] 58 | 59 | [tool.setuptools.packages] 60 | find = { include = ["keras_aug*"] } 61 | 62 | [tool.setuptools.dynamic] 63 | version = { attr = "keras_aug.__version__" } 64 | 65 | [tool.black] 66 | line-length = 80 67 | 68 | [tool.ruff] 69 | line-length = 80 70 | lint.select = ["E", "W", "F"] 71 |
lint.isort.force-single-line = true 72 | exclude = [ 73 | ".venv", 74 | ".vscode", 75 | ".github", 76 | ".devcontainer", 77 | "venv", 78 | "__pycache__", 79 | ] 80 | 81 | [tool.ruff.lint.per-file-ignores] 82 | "**/__init__.py" = ["F401"] 83 | "app.py" = ["E402"] 84 | 85 | [tool.isort] 86 | profile = "black" 87 | force_single_line = true 88 | known_first_party = ["keras_aug"] 89 | line_length = 80 90 | 91 | [tool.pytest.ini_options] 92 | addopts = "-vv --durations 10 --cov --cov-report html --cov-report term:skip-covered --cov-report xml" 93 | testpaths = ["keras_aug"] 94 | filterwarnings = [ 95 | "error", # turn warnings into errors, except the categories ignored below 96 | "ignore::UserWarning", 97 | "ignore::DeprecationWarning", 98 | "ignore::ImportWarning", 99 | "ignore::RuntimeWarning", 100 | "ignore::PendingDeprecationWarning", 101 | "ignore::FutureWarning", 102 | ] 103 | 104 | [tool.coverage.run] 105 | source = ["keras_aug"] 106 | omit = ["**/__init__.py", "*test*"] 107 | 108 | [tool.coverage.report] 109 | exclude_lines = [ 110 | "pragma: no cover", 111 | "@abstract", 112 | "raise NotImplementedError", 113 | "raise ValueError", 114 | ] 115 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_choice_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | 4 | from keras_aug._src.layers.composition.random_choice import RandomChoice 5 | from keras_aug._src.layers.vision.identity import Identity 6 | from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale 7 | from keras_aug._src.layers.vision.resize import Resize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomChoiceTest(TestCase): 13 | def test_correctness(self): 14 | import torch 15 | import torchvision.transforms.v2.functional as TF 16 | from keras.src.backend.torch 
import convert_to_tensor 17 | 18 | layer =
RandomChoice( 19 | transforms=[RandomGrayscale(p=1.0), Identity()], p=[1.0, 0.0] 20 | ) 21 | 22 | x = get_images("float32", "channels_last") 23 | y = layer(x) 24 | 25 | ref_y = TF.rgb_to_grayscale( 26 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 27 | num_output_channels=3, 28 | ) 29 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 30 | self.assertAllClose(y, ref_y) 31 | 32 | # Test p=[0.0, 1.0]: Identity is always chosen, so y == x 33 | layer = RandomChoice( 34 | transforms=[RandomGrayscale(p=1.0), Identity()], p=[0.0, 1.0] 35 | ) 36 | y = layer(x) 37 | 38 | self.assertAllClose(y, x) 39 | 40 | def test_shape(self): 41 | layer = RandomChoice(transforms=[RandomGrayscale(p=1.0), Identity()]) 42 | 43 | # Test dynamic shape 44 | x = keras.KerasTensor((None, None, None, 3)) 45 | y = layer(x) 46 | self.assertEqual(y.shape, (None, None, None, 3)) 47 | 48 | # Test static shape 49 | x = keras.KerasTensor((None, 32, 32, 3)) 50 | y = layer(x) 51 | self.assertEqual(y.shape, (None, 32, 32, 3)) 52 | 53 | # Test deterministic shape 54 | layer = RandomChoice(transforms=[Resize((16, 16)), Resize((16, 16))]) 55 | x = keras.KerasTensor((None, 16, 16, 3)) 56 | y = layer(x) 57 | self.assertEqual(y.shape, (None, 16, 16, 3)) 58 | 59 | def test_model(self): 60 | layer = RandomChoice(transforms=[RandomGrayscale(p=1.0), Identity()]) 61 | inputs = keras.layers.Input(shape=[None, None, 3]) 62 | outputs = layer(inputs) 63 | model = keras.models.Model(inputs, outputs) 64 | self.assertEqual(model.output_shape, (None, None, None, 3)) 65 | 66 | def test_config(self): 67 | x = get_images("float32", "channels_last") 68 | layer = RandomChoice( 69 | transforms=[RandomGrayscale(p=1.0), Identity()], p=[1.0, 0.0] 70 | ) 71 | y = layer(x) 72 | 73 | layer = RandomChoice.from_config(layer.get_config()) 74 | y2 = layer(x) 75 | self.assertAllClose(y, y2) 76 | 77 | def test_tf_data_compatibility(self): 78 | import tensorflow as tf 79 | 80 | layer = RandomChoice( 81 | 
transforms=[RandomGrayscale(p=1.0), RandomGrayscale(p=1.0)] 82 | ) 83 | x =
get_images("float32", "channels_last") 84 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 85 | for output in ds.take(1): 86 | self.assertIsInstance(output, tf.Tensor) 87 | self.assertEqual(output.shape, (2, 32, 32, 3)) 88 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/trivial_augment_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.trivial_augment import TrivialAugmentWide 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class TrivialAugmentWideTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | # TODO: Add assertAllClose test 18 | 19 | np.random.seed(42) 20 | 21 | # Test channels_last 22 | x = get_images(dtype, "channels_last") 23 | layer = TrivialAugmentWide(dtype=dtype) 24 | y = layer(x) 25 | 26 | self.assertDType(y, dtype) 27 | 28 | # Test channels_first 29 | if backend.backend() == "tensorflow": 30 | # Some ops are not supported by TensorFlow on CPU; skip channels_first 31 | return 32 | backend.set_image_data_format("channels_first") 33 | x = get_images(dtype, "channels_first") 34 | layer = TrivialAugmentWide(dtype=dtype) 35 | y = layer(x) 36 | 37 | self.assertDType(y, dtype) 38 | 39 | def test_shape(self): 40 | # Test dynamic shape 41 | x = keras.KerasTensor((None, None, None, 3)) 42 | y = TrivialAugmentWide()(x) 43 | self.assertEqual(y.shape, (None, None, None, 3)) 44 | 45 | # Test static shape 46 | x = keras.KerasTensor((None, 32, 32, 3)) 47 | y = TrivialAugmentWide()(x) 48 | self.assertEqual(y.shape, (None, 32, 32, 3)) 49 | 50 | def test_model(self): 
51
| # Test dynamic shape 52 | layer = TrivialAugmentWide() 53 | inputs = keras.layers.Input(shape=[None, None, 3]) 54 | outputs = layer(inputs) 55 | model = keras.models.Model(inputs, outputs) 56 | self.assertEqual(model.output_shape, (None, None, None, 3)) 57 | 58 | # Test static shape 59 | layer = TrivialAugmentWide() 60 | inputs = keras.layers.Input(shape=[32, 32, 3]) 61 | outputs = layer(inputs) 62 | model = keras.models.Model(inputs, outputs) 63 | self.assertEqual(model.output_shape, (None, 32, 32, 3)) 64 | 65 | def test_config(self): 66 | x = get_images("float32", "channels_last") 67 | layer = TrivialAugmentWide() 68 | y = layer(x) 69 | 70 | layer = TrivialAugmentWide.from_config(layer.get_config()) 71 | y2 = layer(x) 72 | self.assertEqual(y.shape, y2.shape) 73 | 74 | # Test `p=0.0`: the layer must return its input unchanged 75 | layer = TrivialAugmentWide(p=0.0) 76 | y = layer(x) 77 | 78 | layer = TrivialAugmentWide.from_config(layer.get_config()) 79 | y2 = layer(x) 80 | self.assertAllClose(y, x) 81 | self.assertAllClose(y2, x) 82 | self.assertEqual(y.shape, y2.shape) 83 | 84 | def test_tf_data_compatibility(self): 85 | import tensorflow as tf 86 | 87 | layer = TrivialAugmentWide() 88 | x = get_images("float32", "channels_last") 89 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 90 | for output in ds.take(1): 91 | self.assertIsInstance(output, tf.Tensor) 92 | self.assertEqual(output.shape, (2, 32, 32, 3)) 93 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_order_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | 4 | from keras_aug._src.layers.composition.random_order import RandomOrder 5 | from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale 6 | from keras_aug._src.layers.vision.random_invert import RandomInvert 7 | from keras_aug._src.layers.vision.resize import Resize 
8 | from
keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class FixedRandomOrder(RandomOrder): 13 | def get_params(self): 14 | ops = self.backend 15 | fn_idx = ops.convert_to_tensor([1, 0], dtype="int32") 16 | return fn_idx # fixed order: transforms[1] first, then transforms[0] 17 | 18 | 19 | class RandomOrderTest(TestCase): 20 | def test_correctness(self): 21 | import torch 22 | import torchvision.transforms.v2.functional as TF 23 | from keras.src.backend.torch import convert_to_tensor 24 | 25 | layer = FixedRandomOrder( 26 | transforms=[RandomGrayscale(p=1.0), RandomInvert(p=1.0)] 27 | ) 28 | 29 | x = get_images("float32", "channels_last") 30 | y = layer(x) 31 | 32 | ref_y = TF.rgb_to_grayscale( 33 | TF.invert(convert_to_tensor(np.transpose(x, [0, 3, 1, 2]))), 34 | num_output_channels=3, 35 | ) 36 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 37 | self.assertAllClose(y, ref_y) 38 | 39 | def test_shape(self): 40 | layer = RandomOrder( 41 | transforms=[RandomGrayscale(p=1.0), RandomInvert(p=1.0)] 42 | ) 43 | 44 | # Test dynamic shape 45 | x = keras.KerasTensor((None, None, None, 3)) 46 | y = layer(x) 47 | self.assertEqual(y.shape, (None, None, None, 3)) 48 | 49 | # Test static shape 50 | x = keras.KerasTensor((None, 32, 32, 3)) 51 | y = layer(x) 52 | self.assertEqual(y.shape, (None, 32, 32, 3)) 53 | 54 | # Test deterministic shape 55 | layer = RandomOrder( 56 | transforms=[Resize((16, 16)), RandomGrayscale(p=1.0)] 57 | ) 58 | x = keras.KerasTensor((None, 16, 16, 3)) 59 | y = layer(x) 60 | self.assertEqual(y.shape, (None, 16, 16, 3)) 61 | 62 | def test_model(self): 63 | layer = RandomOrder( 64 | transforms=[RandomGrayscale(p=1.0), RandomInvert(p=1.0)] 65 | ) 66 | inputs = keras.layers.Input(shape=[None, None, 3]) 67 | outputs = layer(inputs) 68 | model = keras.models.Model(inputs, outputs) 69 | self.assertEqual(model.output_shape, (None, None, None, 3)) 70 | 71 | def test_config(self): 72 | x = 
get_images("float32", "channels_last") 73 | layer =
FixedRandomOrder( 74 | transforms=[RandomGrayscale(p=1.0), RandomInvert(p=1.0)] 75 | ) 76 | y = layer(x) 77 | 78 | layer = FixedRandomOrder.from_config(layer.get_config()) 79 | y2 = layer(x) 80 | self.assertAllClose(y, y2) 81 | 82 | def test_tf_data_compatibility(self): 83 | import tensorflow as tf 84 | 85 | layer = RandomOrder( 86 | transforms=[RandomGrayscale(p=1.0), RandomInvert(p=1.0)] 87 | ) 88 | x = get_images("float32", "channels_last") 89 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 90 | for output in ds.take(1): 91 | self.assertIsInstance(output, tf.Tensor) 92 | self.assertEqual(output.shape, (2, 32, 32, 3)) 93 | -------------------------------------------------------------------------------- /docs/generate_object_detection_gif.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | from PIL import Image 4 | 5 | from keras_aug import layers as ka_layers 6 | from keras_aug import ops as ka_ops 7 | from keras_aug import visualization 8 | 9 | size = (320, 320) 10 | mosaic_size = (640, 640) 11 | 12 | 13 | def load_voc(name, split, shuffle, batch_size, position): 14 | def unpack_voc_inputs(x): 15 | image = x["image"] 16 | image_shape = tf.shape(image) 17 | height, width = image_shape[-3], image_shape[-2] 18 | boxes = ka_ops.bounding_box.convert_format( 19 | x["objects"]["bbox"], 20 | source="rel_yxyx", 21 | target="xyxy", 22 | height=height, 23 | width=width, 24 | ) 25 | bounding_boxes = {"classes": x["objects"]["label"], "boxes": boxes} 26 | return {"images": image, "bounding_boxes": bounding_boxes} 27 | 28 | ds = tfds.load(name, split=split, with_info=False, shuffle_files=shuffle) 29 | ds: tf.data.Dataset = ds.map(lambda x: unpack_voc_inputs(x)) 30 | ds = ds.map(ka_layers.vision.MaxBoundingBox(40)) # Pad/truncate to 40 boxes (max: 37 in train) 31 | ds = ds.shuffle(128, reshuffle_each_iteration=True) 32 | ds = ds.map( 33 | ka_layers.vision.Resize( 34 | size[0], 35 |
along_long_edge=True, 36 | bounding_box_format="xyxy", 37 | dtype="uint8", 38 | ) 39 | ) 40 | ds = ds.map( 41 | ka_layers.vision.Pad( 42 | size, 43 | padding_position=position, 44 | padding_value=114, 45 | bounding_box_format="xyxy", 46 | dtype="uint8", 47 | ) 48 | ) 49 | ds = ds.batch(batch_size) 50 | return ds 51 | 52 | 53 | # Load dataset 54 | args = dict(name="voc/2007", split="train", shuffle=True, batch_size=16) 55 | ds_tl = load_voc(**args, position="top_left") 56 | ds_tr = load_voc(**args, position="top_right") 57 | ds_bl = load_voc(**args, position="bottom_left") 58 | ds_br = load_voc(**args, position="bottom_right") 59 | ds = tf.data.Dataset.zip(ds_tl, ds_tr, ds_bl, ds_br) 60 | 61 | # Augment 62 | ds = ds.map( 63 | ka_layers.vision.Mosaic( 64 | mosaic_size, 65 | offset=(0.25, 0.75), 66 | padding_value=114, 67 | bounding_box_format="xyxy", 68 | dtype="uint8", 69 | ) 70 | ) 71 | ds = ds.map( 72 | ka_layers.vision.RandomAffine( 73 | translate=0.05, 74 | scale=0.25, 75 | padding_value=114, 76 | bounding_box_format="xyxy", 77 | dtype="uint8", 78 | ) 79 | ) 80 | ds = ds.map( 81 | ka_layers.vision.CenterCrop(size, bounding_box_format="xyxy", dtype="uint8") 82 | ) 83 | ds = ds.map(ka_layers.vision.RandomGrayscale(p=0.01)) 84 | ds = ds.map(ka_layers.vision.RandomHSV(hue=0.015, saturation=0.7, value=0.4)) 85 | ds = ds.map( 86 | ka_layers.vision.RandomFlip(mode="horizontal", bounding_box_format="xyxy") 87 | ) 88 | 89 | # Draw the boxes on one batch and save its first 10 images as a gif 90 | images = [] 91 | for x in ds.take(1): 92 | drawed_images = visualization.draw_bounding_boxes( 93 | x["images"], x["bounding_boxes"], bounding_box_format="xyxy" 94 | ) 95 | for i in range(drawed_images.shape[0]): 96 | images.append(Image.fromarray(drawed_images[i])) 97 | images[0].save( 98 | "output.gif", 99 | save_all=True, 100 | append_images=images[1:10], 101 | optimize=False, 102 | duration=1000, 103 | loop=0, 104 | ) 105 | 
--------------------------------------------------------------------------------
/keras_aug/_src/layers/vision/to_dtype_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.to_dtype import ToDType 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class ToDTypeTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product( 15 | from_dtype=["uint8", "int16", "int32", "bfloat16", "float32"], 16 | to_dtype=["uint8", "int16", "bfloat16", "float32"], 17 | scale=[True, False], 18 | ) 19 | ) 20 | def test_correctness(self, from_dtype, to_dtype, scale): 21 | import torch 22 | import torchvision.transforms.v2.functional as TF 23 | from keras.src.backend.torch import convert_to_tensor 24 | from keras.src.backend.torch import to_torch_dtype 25 | 26 | # Test channels_last 27 | x = get_images(from_dtype, "channels_last") 28 | layer = ToDType(to_dtype, scale) 29 | y = layer(x) 30 | 31 | ref_y = TF.to_dtype( 32 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 33 | dtype=to_torch_dtype(to_dtype), 34 | scale=scale, 35 | ) 36 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 37 | self.assertDType(y, to_dtype) 38 | if from_dtype == "bfloat16" and to_dtype in ("uint8", "int16"): 39 | return # only the dtype is compared for bfloat16 -> integer casts 40 | self.assertAllClose(y, ref_y) 41 | 42 | def test_shape(self): 43 | # Test dynamic shape 44 | x = keras.KerasTensor((None, None, None, 3)) 45 | y = ToDType("float32", scale=True)(x) 46 | self.assertEqual(y.shape, (None, None, None, 3)) 47 | backend.set_image_data_format("channels_first") 48 | x = keras.KerasTensor((None, 3, None, None)) 49 | y = ToDType("float32", scale=True)(x) 50 | self.assertEqual(y.shape, (None, 3, None, None)) 51 | 52 | # Test static shape 53 | 
backend.set_image_data_format("channels_last") 54 | x =
keras.KerasTensor((None, 32, 32, 3)) 55 | y = ToDType("float32", scale=True)(x) 56 | self.assertEqual(y.shape, (None, 32, 32, 3)) 57 | backend.set_image_data_format("channels_first") 58 | x = keras.KerasTensor((None, 3, 32, 32)) 59 | y = ToDType("float32", scale=True)(x) 60 | self.assertEqual(y.shape, (None, 3, 32, 32)) 61 | 62 | def test_model(self): 63 | layer = ToDType("float32", scale=True) 64 | inputs = keras.layers.Input(shape=[None, None, 5]) 65 | outputs = layer(inputs) 66 | model = keras.models.Model(inputs, outputs) 67 | self.assertEqual(model.output_shape, (None, None, None, 5)) 68 | 69 | def test_config(self): 70 | x = get_images("float32", "channels_last") 71 | layer = ToDType("float32", scale=True) 72 | y = layer(x) 73 | 74 | layer = ToDType.from_config(layer.get_config()) 75 | y2 = layer(x) 76 | self.assertAllClose(y, y2) 77 | 78 | def test_tf_data_compatibility(self): 79 | import tensorflow as tf 80 | 81 | layer = ToDType("float32", scale=True) 82 | x = get_images("float32", "channels_last") 83 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 84 | for output in ds.take(1): 85 | self.assertIsInstance(output, tf.Tensor) 86 | self.assertEqual(output.shape, (2, 32, 32, 3)) 87 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/normalize.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | 6 | from keras_aug._src.keras_aug_export import keras_aug_export 7 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 8 | from keras_aug._src.utils.argument_validation import standardize_data_format 9 | 10 | 11 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 12 | @keras.saving.register_keras_serializable(package="keras_aug") 13 | class Normalize(VisionRandomLayer): 14 | """Normalize the images channel-wise with mean and standard deviation.
15 | 16 | This layer will normalize each channel of the images: 17 | `y[c] = (x[c] - mean[c]) / std[c]`. 18 | 19 | Args: 20 | mean: Sequence of means for each channel. Defaults to 21 | `(0.485, 0.456, 0.406)` which is the mean values from ImageNet. 22 | std: Sequence of standard deviations for each channel. Defaults to 23 | `(0.229, 0.224, 0.225)` which is the std values from ImageNet. 24 | data_format: A string specifying the data format of the input images. 25 | It can be either `"channels_last"` or `"channels_first"`. 26 | If not specified, the value will be interpreted by 27 | `keras.config.image_data_format`. Defaults to `None`. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | mean: typing.Sequence[float] = (0.485, 0.456, 0.406), 33 | std: typing.Sequence[float] = (0.229, 0.224, 0.225), 34 | data_format: typing.Optional[str] = None, 35 | **kwargs, 36 | ): 37 | super().__init__(has_generator=False, **kwargs) 38 | self.mean = tuple(mean) 39 | self.std = tuple(std) 40 | self.data_format = standardize_data_format(data_format) 41 | 42 | if not backend.is_float_dtype(self.compute_dtype): 43 | dtype = self.dtype_policy 44 | raise ValueError( 45 | "The `dtype` of Normalize must be float. 
" 46 | f"Received: dtype={dtype}" 47 | ) 48 | 49 | def compute_output_shape(self, input_shape): 50 | return input_shape 51 | 52 | def augment_images(self, images, transformations, **kwargs): 53 | ops = self.backend 54 | original_dtype = backend.standardize_dtype(images.dtype) 55 | compute_dtype = backend.result_type(original_dtype, float) 56 | mean = ops.cast(self.mean, compute_dtype) 57 | std = ops.cast(self.std, compute_dtype) 58 | if self.data_format == "channels_last": 59 | mean = ops.numpy.expand_dims(mean, axis=[0, 1, 2]) 60 | std = ops.numpy.expand_dims(std, axis=[0, 1, 2]) 61 | else: 62 | mean = ops.numpy.expand_dims(mean, axis=[0, 2, 3]) 63 | std = ops.numpy.expand_dims(std, axis=[0, 2, 3]) 64 | images = ops.numpy.subtract(images, mean) 65 | images = ops.numpy.divide(images, std) 66 | return ops.cast(images, original_dtype) 67 | 68 | def augment_labels(self, labels, transformations, **kwargs): 69 | return labels 70 | 71 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 72 | return bounding_boxes 73 | 74 | def augment_segmentation_masks( 75 | self, segmentation_masks, transformations, **kwargs 76 | ): 77 | return segmentation_masks 78 | 79 | def augment_keypoints(self, keypoints, transformations, **kwargs): 80 | return keypoints 81 | 82 | def get_config(self): 83 | config = super().get_config() 84 | config.update({"mean": self.mean, "std": self.std}) 85 | return config 86 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/mix_up_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.mix_up import MixUp 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 
11 | 12 | class FixedMixUp(MixUp): 13 | def get_params(self, batch_size, images=None, **kwargs): 14 | ops = self.backend 15 | lam = ops.numpy.ones([batch_size]) * 0.5 16 | return lam # fixed 0.5 mixing weight for every sample 17 | 18 | 19 | class MixUpTest(TestCase): 20 | @parameterized.named_parameters( 21 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 22 | ) 23 | def test_correctness(self, dtype): 24 | atol = 1e-2 if "float" in dtype else 0.5 25 | 26 | # Test channels_last 27 | images = get_images(dtype, "channels_last") 28 | labels = np.array([0, 1], "float32") 29 | inputs = {"images": images, "labels": labels} 30 | layer = FixedMixUp(num_classes=2, dtype=dtype) 31 | outputs = layer(inputs) 32 | 33 | self.assertDType(outputs["images"], dtype) 34 | self.assertAllClose( 35 | outputs["images"][0], images[0] / 2.0 + images[1] / 2.0, atol=atol 36 | ) 37 | self.assertAllClose( 38 | outputs["images"][1], images[0] / 2.0 + images[1] / 2.0, atol=atol 39 | ) 40 | self.assertAllClose(outputs["labels"], [[0.5, 0.5], [0.5, 0.5]]) 41 | 42 | # Test channels_first 43 | backend.set_image_data_format("channels_first") 44 | images = get_images(dtype, "channels_first") 45 | labels = np.array([0, 1], "float32") 46 | inputs = {"images": images, "labels": labels} 47 | layer = FixedMixUp(num_classes=2, dtype=dtype) 48 | outputs = layer(inputs) 49 | 50 | self.assertDType(outputs["images"], dtype) 51 | self.assertAllClose( 52 | outputs["images"][0], images[0] / 2.0 + images[1] / 2.0, atol=atol 53 | ) 54 | self.assertAllClose( 55 | outputs["images"][1], images[0] / 2.0 + images[1] / 2.0, atol=atol 56 | ) 57 | self.assertAllClose(outputs["labels"], [[0.5, 0.5], [0.5, 0.5]]) 58 | 59 | def test_shape(self): 60 | # Test dynamic shape 61 | x = keras.KerasTensor((None, None, None, 3)) 62 | y = MixUp()(x) 63 | self.assertEqual(y.shape, (None, None, None, 3)) 64 | 65 | # Test static shape 66 | x = keras.KerasTensor((None, 32, 32, 3)) 67 | y = MixUp()(x) 68 | 
self.assertEqual(y.shape, (None, 32, 32, 3)) 69 | 70 | def
test_model(self): 71 | layer = MixUp() 72 | inputs = keras.layers.Input(shape=(None, None, 3)) 73 | outputs = layer(inputs) 74 | model = keras.models.Model(inputs, outputs) 75 | self.assertEqual(model.output_shape, (None, None, None, 3)) 76 | 77 | def test_config(self): 78 | x = get_images("float32", "channels_last") 79 | layer = FixedMixUp() 80 | y = layer(x) 81 | 82 | layer = FixedMixUp.from_config(layer.get_config()) 83 | y2 = layer(x) 84 | self.assertAllClose(y, y2) 85 | 86 | def test_tf_data_compatibility(self): 87 | import tensorflow as tf 88 | 89 | layer = MixUp() 90 | x = get_images("float32", "channels_last") 91 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 92 | for output in ds.take(1): 93 | self.assertIsInstance(output, tf.Tensor) 94 | self.assertEqual(output.shape, (2, 32, 32, 3)) 95 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/normalize_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.normalize import Normalize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class NormalizeTest(TestCase): 13 | mean = (0.485, 0.456, 0.406) 14 | std = (0.229, 0.224, 0.225) 15 | 16 | @parameterized.named_parameters( 17 | named_product(dtype=["float32", "mixed_bfloat16"]) 18 | ) 19 | def test_correctness(self, dtype): 20 | import torch 21 | import torchvision.transforms.v2.functional as TF 22 | from keras.src.backend.torch import convert_to_tensor 23 | 24 | # Test channels_last against the torchvision reference 25 | x = get_images(dtype, "channels_last") 26 | layer = Normalize(self.mean, self.std, dtype=dtype) 27 | y = layer(x) 28 | 29 | ref_y = 
TF.normalize( 30 |
convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 31 | self.mean, 32 | self.std, 33 | ) 34 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 35 | self.assertDType(y, dtype) 36 | self.assertAllClose(y, ref_y) 37 | 38 | # Test channels_first (inputs are already NCHW, no transpose needed) 39 | backend.set_image_data_format("channels_first") 40 | x = get_images(dtype, "channels_first") 41 | layer = Normalize(self.mean, self.std, dtype=dtype) 42 | y = layer(x) 43 | 44 | ref_y = TF.normalize(convert_to_tensor(x), self.mean, self.std) 45 | self.assertDType(y, dtype) 46 | self.assertAllClose(y, ref_y) 47 | 48 | def test_shape(self): 49 | # Test dynamic shape 50 | x = keras.KerasTensor((None, None, None, 3)) 51 | y = Normalize(self.mean, self.std)(x) 52 | self.assertEqual(y.shape, (None, None, None, 3)) 53 | backend.set_image_data_format("channels_first") 54 | x = keras.KerasTensor((None, 3, None, None)) 55 | y = Normalize(self.mean, self.std)(x) 56 | self.assertEqual(y.shape, (None, 3, None, None)) 57 | 58 | # Test static shape 59 | backend.set_image_data_format("channels_last") 60 | x = keras.KerasTensor((None, 32, 32, 3)) 61 | y = Normalize(self.mean, self.std)(x) 62 | self.assertEqual(y.shape, (None, 32, 32, 3)) 63 | backend.set_image_data_format("channels_first") 64 | x = keras.KerasTensor((None, 3, 32, 32)) 65 | y = Normalize(self.mean, self.std)(x) 66 | self.assertEqual(y.shape, (None, 3, 32, 32)) 67 | 68 | def test_model(self): 69 | layer = Normalize((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)) 70 | inputs = keras.layers.Input(shape=[None, None, 5]) 71 | outputs = layer(inputs) 72 | model = keras.models.Model(inputs, outputs) 73 | self.assertEqual(model.output_shape, (None, None, None, 5)) 74 | 75 | def test_config(self): 76 | x = get_images("float32", "channels_last") 77 | layer = Normalize(self.mean, self.std) 78 | y = layer(x) 79 | 80 | layer = Normalize.from_config(layer.get_config()) 81 | y2 = layer(x) 82 | self.assertAllClose(y, y2) 83 | 84 | def test_tf_data_compatibility(self): 85 | import 
tensorflow as tf 86 | 87 |
layer = Normalize(self.mean, self.std) 88 | x = get_images("float32", "channels_last") 89 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 90 | for output in ds.take(1): 91 | self.assertIsInstance(output, tf.Tensor) 92 | self.assertEqual(output.shape, (2, 32, 32, 3)) 93 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/max_bounding_box.py: -------------------------------------------------------------------------------- 1 | import keras 2 | 3 | from keras_aug._src.keras_aug_export import keras_aug_export 4 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 8 | @keras.saving.register_keras_serializable(package="keras_aug") 9 | class MaxBoundingBox(VisionRandomLayer): 10 | """Ensure the maximum number of bounding boxes. 11 | 12 | Args: 13 | max_number: Desired output number of bounding boxs. 14 | padding_value: The padding value of the `boxes` and `classes` in 15 | `bounding_boxes`. Defaults to `-1`. 16 | """ 17 | 18 | def __init__(self, max_number, fill_value=-1, **kwargs): 19 | super().__init__(has_generator=False, **kwargs) 20 | self.max_number = int(max_number) 21 | self.fill_value = int(fill_value) 22 | 23 | def compute_output_shape(self, input_shape): 24 | if isinstance(input_shape, dict) and "bounding_boxes" in input_shape: 25 | input_keys = set(input_shape["bounding_boxes"].keys()) 26 | extra_keys = input_keys - set(("boxes", "classes")) 27 | if extra_keys: 28 | raise KeyError( 29 | "There are unsupported keys in `bounding_boxes`: " 30 | f"{list(extra_keys)}. " 31 | "Only `boxes` and `classes` are supported." 
32 | ) 33 | boxes_shape = list(input_shape["bounding_boxes"]["boxes"]) 34 | boxes_shape[1] = self.max_number 35 | classes_shape = list(input_shape["bounding_boxes"]["classes"]) 36 | classes_shape[1] = self.max_number 37 | input_shape["bounding_boxes"]["boxes"] = boxes_shape 38 | input_shape["bounding_boxes"]["classes"] = classes_shape 39 | return input_shape 40 | 41 | def augment_images(self, images, transformations, **kwargs): 42 | return images 43 | 44 | def augment_labels(self, labels, transformations, **kwargs): 45 | return labels 46 | 47 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 48 | ops = self.backend 49 | boxes = bounding_boxes["boxes"] 50 | classes = bounding_boxes["classes"] 51 | boxes_shape = ops.shape(boxes) 52 | batch_size = boxes_shape[0] 53 | num_boxes = boxes_shape[1] 54 | 55 | # Get pad size 56 | pad_size = ops.numpy.maximum( 57 | ops.numpy.subtract(self.max_number, num_boxes), 0 58 | ) 59 | boxes = boxes[:, : self.max_number, ...] 60 | boxes = ops.numpy.pad( 61 | boxes, 62 | [[0, 0], [0, pad_size], [0, 0]], 63 | constant_values=self.fill_value, 64 | ) 65 | classes = classes[:, : self.max_number] 66 | classes = ops.numpy.pad( 67 | classes, [[0, 0], [0, pad_size]], constant_values=self.fill_value 68 | ) 69 | 70 | # Ensure shape 71 | boxes = ops.numpy.reshape(boxes, [batch_size, self.max_number, 4]) 72 | classes = ops.numpy.reshape(classes, [batch_size, self.max_number]) 73 | 74 | bounding_boxes = bounding_boxes.copy() 75 | bounding_boxes["boxes"] = boxes 76 | bounding_boxes["classes"] = classes 77 | return bounding_boxes 78 | 79 | def augment_segmentation_masks( 80 | self, segmentation_masks, transformations, **kwargs 81 | ): 82 | return segmentation_masks 83 | 84 | def augment_keypoints(self, keypoints, transformations, **kwargs): 85 | return keypoints 86 | 87 | def get_config(self): 88 | config = super().get_config() 89 | config.update({"max_number": self.max_number}) 90 | return config 91 | 
-------------------------------------------------------------------------------- /.github/workflows/actions.yml: -------------------------------------------------------------------------------- 1 | # Ref: https://github.com/keras-team/keras/blob/master/.github/workflows/actions.yml 2 | name: Tests 3 | 4 | on: 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | release: 9 | types: [created] 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | format: 16 | name: Check the code format 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 3.9 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.9' 24 | - name: Lint 25 | uses: pre-commit/action@v3.0.1 26 | - name: Get pip cache dir 27 | id: pip-cache 28 | run: | 29 | python -m pip install --upgrade pip setuptools 30 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 31 | - name: Cache pip 32 | uses: actions/cache@v4 33 | with: 34 | path: ${{ steps.pip-cache.outputs.dir }} 35 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements_ci.txt') }} 36 | - name: Install dependencies 37 | run: | 38 | pip install -r requirements_ci.txt --progress-bar off --upgrade 39 | pip install -e ".[tests]" --progress-bar off --upgrade 40 | - name: Check for API changes 41 | run: | 42 | bash shell/api_gen.sh 43 | git status 44 | clean=$(git status | grep "nothing to commit") 45 | if [ -z "$clean" ]; then 46 | echo "Please run shell/api_gen.sh to generate API." 
47 | exit 1 48 | fi 49 | 50 | build: 51 | strategy: 52 | fail-fast: false 53 | matrix: 54 | backend: [tensorflow, jax, torch] 55 | version: [keras-stable] 56 | include: 57 | - backend: jax 58 | version: keras-3.4.1 59 | - backend: jax 60 | version: keras-nightly 61 | name: Run tests 62 | runs-on: ubuntu-latest 63 | env: 64 | KERAS_BACKEND: ${{ matrix.backend }} 65 | steps: 66 | - uses: actions/checkout@v4 67 | - name: Set up Python 3.9 68 | uses: actions/setup-python@v5 69 | with: 70 | python-version: '3.9' 71 | - name: Get pip cache dir 72 | id: pip-cache 73 | run: | 74 | python -m pip install --upgrade pip setuptools 75 | echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT 76 | - name: Cache pip 77 | uses: actions/cache@v4 78 | with: 79 | path: ${{ steps.pip-cache.outputs.dir }} 80 | key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements_ci.txt') }} 81 | - name: Install dependencies 82 | run: | 83 | pip install -r requirements_ci.txt --progress-bar off --upgrade 84 | pip install -e ".[tests]" --progress-bar off --upgrade 85 | - name: Pin Keras 3.4.1 86 | if: ${{ matrix.version == 'keras-3.4.1'}} 87 | run: | 88 | pip uninstall -y keras 89 | pip install keras==3.4.1 --progress-bar off 90 | - name: Pin Keras Nightly 91 | if: ${{ matrix.version == 'keras-nightly'}} 92 | run: | 93 | pip uninstall -y keras 94 | pip install keras-nightly --progress-bar off 95 | - name: Test with pytest 96 | run: | 97 | pytest 98 | coverage xml -o coverage.xml 99 | - name: Upload coverage reports to Codecov 100 | uses: codecov/codecov-action@v4 101 | with: 102 | token: ${{ secrets.CODECOV_TOKEN }} 103 | files: coverage.xml 104 | flags: keras-aug,keras-aug-${{ matrix.backend }} 105 | fail_ci_if_error: false 106 | -------------------------------------------------------------------------------- /guides/voc_yolov8_aug.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import keras 3 | import tensorflow as tf 4 | 
import tensorflow_datasets as tfds 5 | 6 | from keras_aug import layers as ka_layers 7 | from keras_aug import ops as ka_ops 8 | from keras_aug import visualization 9 | 10 | 11 | def load_voc(name, split, shuffle, batch_size, position): 12 | def unpack_voc_inputs(x): 13 | image = x["image"] 14 | image_shape = tf.shape(image) 15 | height, width = image_shape[-3], image_shape[-2] 16 | boxes = ka_ops.bounding_box.convert_format( 17 | x["objects"]["bbox"], 18 | source="rel_yxyx", 19 | target="xyxy", 20 | height=height, 21 | width=width, 22 | ) 23 | bounding_boxes = {"classes": x["objects"]["label"], "boxes": boxes} 24 | return {"images": image, "bounding_boxes": bounding_boxes} 25 | 26 | ds = tfds.load(name, split=split, with_info=False, shuffle_files=shuffle) 27 | ds: tf.data.Dataset = ds.map(lambda x: unpack_voc_inputs(x)) 28 | 29 | # You can utilize KerasAug's layers in `tf.data` pipeline. 30 | # The layer will automatically switch to the TensorFlow backend to be 31 | # compatible with `tf.data`. 
32 | ds = ds.map(ka_layers.vision.MaxBoundingBox(40)) # Max: 37 in train 33 | ds = ds.shuffle(128, reshuffle_each_iteration=True) 34 | ds = ds.map( 35 | ka_layers.vision.Resize( 36 | 640, along_long_edge=True, bounding_box_format="xyxy", dtype="uint8" 37 | ) 38 | ) 39 | ds = ds.map( 40 | ka_layers.vision.Pad( 41 | (640, 640), 42 | padding_position=position, 43 | padding_value=114, 44 | bounding_box_format="xyxy", 45 | dtype="uint8", 46 | ) 47 | ) 48 | ds = ds.batch(batch_size) 49 | return ds 50 | 51 | 52 | args = dict(name="voc/2007", split="train", shuffle=True, batch_size=16) 53 | ds_tl = load_voc(**args, position="top_left") 54 | ds_tr = load_voc(**args, position="top_right") 55 | ds_bl = load_voc(**args, position="bottom_left") 56 | ds_br = load_voc(**args, position="bottom_right") 57 | ds = tf.data.Dataset.zip(ds_tl, ds_tr, ds_bl, ds_br) 58 | ds = ds.map( 59 | ka_layers.vision.Mosaic( 60 | (1280, 1280), 61 | offset=(0.25, 0.75), 62 | padding_value=114, 63 | bounding_box_format="xyxy", 64 | dtype="uint8", 65 | ) 66 | ) 67 | 68 | # You can also utilize KerasAug's layers in a typical Keras manner. 69 | # `augmenter`` will be called just like a regular Keras model, benefiting from 70 | # accelerator (such as GPU & TPU) and compilation. 
71 | augmenter = keras.Sequential( 72 | [ 73 | ka_layers.vision.RandomAffine( 74 | translate=0.05, 75 | scale=0.25, 76 | padding_value=114, 77 | bounding_box_format="xyxy", 78 | dtype="uint8", 79 | ), 80 | ka_layers.vision.CenterCrop( 81 | (640, 640), bounding_box_format="xyxy", dtype="uint8" 82 | ), 83 | ka_layers.vision.RandomGrayscale(p=0.01), 84 | ka_layers.vision.RandomHSV(hue=0.015, saturation=0.7, value=0.4), 85 | ka_layers.vision.RandomFlip( 86 | mode="horizontal", bounding_box_format="xyxy" 87 | ), 88 | ] 89 | ) 90 | 91 | for x in ds.take(1): 92 | x = augmenter(x) 93 | drawed_images = visualization.draw_bounding_boxes( 94 | x["images"], x["bounding_boxes"], bounding_box_format="xyxy" 95 | ) 96 | cv2.imwrite("output.jpg", drawed_images[0]) 97 | for i_d in range(drawed_images.shape[0]): 98 | output_path = f"output_{i_d}.jpg" 99 | output_image = cv2.cvtColor(drawed_images[i_d], cv2.COLOR_RGB2BGR) 100 | cv2.imwrite(output_path, output_image) 101 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/cut_mix_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.cut_mix import CutMix 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class FixedCutMix(CutMix): 13 | def get_params(self, batch_size, images=None, **kwargs): 14 | ops = self.backend 15 | top = ops.numpy.zeros([batch_size]) 16 | left = ops.numpy.zeros([batch_size]) 17 | bottom = ops.numpy.ones([batch_size]) * 10 18 | right = ops.numpy.ones([batch_size]) * 10 19 | lam = ops.numpy.ones([batch_size]) * 0.5 20 | return dict(top=top, bottom=bottom, left=left, right=right, lam=lam) 21 | 22 | 23 | class CutMixTest(TestCase): 
24 | @parameterized.named_parameters( 25 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 26 | ) 27 | def test_correctness(self, dtype): 28 | # Test channels_last 29 | images = get_images(dtype, "channels_last") 30 | labels = np.array([0, 1], "float32") 31 | inputs = {"images": images, "labels": labels} 32 | layer = FixedCutMix(num_classes=2, dtype=dtype) 33 | outputs = layer(inputs) 34 | 35 | self.assertDType(outputs["images"], dtype) 36 | self.assertAllClose( 37 | outputs["images"][0, 0:10, 0:10, :], images[1, 0:10, 0:10, :] 38 | ) 39 | self.assertAllClose( 40 | outputs["images"][0, 10:, 10:, :], images[0, 10:, 10:, :] 41 | ) 42 | self.assertAllClose(outputs["labels"], [[0.5, 0.5], [0.5, 0.5]]) 43 | 44 | # Test channels_first 45 | backend.set_image_data_format("channels_first") 46 | images = get_images(dtype, "channels_first") 47 | labels = np.array([0, 1], "float32") 48 | inputs = {"images": images, "labels": labels} 49 | layer = FixedCutMix(num_classes=2, dtype=dtype) 50 | outputs = layer(inputs) 51 | 52 | self.assertDType(outputs["images"], dtype) 53 | self.assertAllClose( 54 | outputs["images"][0, :, 0:10, 0:10], images[1, :, 0:10, 0:10] 55 | ) 56 | self.assertAllClose( 57 | outputs["images"][0, :, 10:, 10:], images[0, :, 10:, 10:] 58 | ) 59 | self.assertAllClose(outputs["labels"], [[0.5, 0.5], [0.5, 0.5]]) 60 | 61 | def test_shape(self): 62 | # Test dynamic shape 63 | x = keras.KerasTensor((None, None, None, 3)) 64 | y = CutMix()(x) 65 | self.assertEqual(y.shape, (None, None, None, 3)) 66 | 67 | # Test static shape 68 | x = keras.KerasTensor((None, 32, 32, 3)) 69 | y = CutMix()(x) 70 | self.assertEqual(y.shape, (None, 32, 32, 3)) 71 | 72 | def test_model(self): 73 | layer = CutMix() 74 | inputs = keras.layers.Input(shape=(None, None, 3)) 75 | outputs = layer(inputs) 76 | model = keras.models.Model(inputs, outputs) 77 | self.assertEqual(model.output_shape, (None, None, None, 3)) 78 | 79 | def test_config(self): 80 | x = get_images("float32", 
"channels_last") 81 | layer = FixedCutMix() 82 | y = layer(x) 83 | 84 | layer = FixedCutMix.from_config(layer.get_config()) 85 | y2 = layer(x) 86 | self.assertAllClose(y, y2) 87 | 88 | def test_tf_data_compatibility(self): 89 | import tensorflow as tf 90 | 91 | layer = CutMix() 92 | x = get_images("float32", "channels_last") 93 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 94 | for output in ds.take(1): 95 | self.assertIsInstance(output, tf.Tensor) 96 | self.assertEqual(output.shape, (2, 32, 32, 3)) 97 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_equalize_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.random_equalize import RandomEqualize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomEqualizeTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | import torch 18 | import torchvision.transforms.v2.functional as TF 19 | from keras.src.backend.torch import convert_to_tensor 20 | 21 | # TODO: Reduce atol 22 | if dtype == "float32": 23 | atol = 0.3 24 | elif "bfloat16" in dtype: 25 | atol = 1.0 26 | elif dtype == "uint8": 27 | atol = 64 28 | np.random.seed(42) 29 | 30 | # Test channels_last 31 | x = get_images(dtype, "channels_last") 32 | layer = RandomEqualize(p=1.0, dtype=dtype) 33 | y = layer(x) 34 | 35 | ref_y = TF.equalize(convert_to_tensor(np.transpose(x, [0, 3, 1, 2]))) 36 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 37 | self.assertDType(y, dtype) 38 | self.assertAllClose(y, ref_y, atol=atol) 39 | 40 | # Test 
channels_first 41 | backend.set_image_data_format("channels_first") 42 | x = np.transpose(x, [0, 3, 1, 2]) 43 | layer = RandomEqualize(p=1.0, dtype=dtype) 44 | y = layer(x) 45 | 46 | ref_y = TF.equalize(convert_to_tensor(x)) 47 | self.assertDType(y, dtype) 48 | self.assertAllClose(y, ref_y, atol=atol) 49 | 50 | # Test p=0.0 51 | backend.set_image_data_format("channels_last") 52 | x = np.transpose(x, [0, 2, 3, 1]) 53 | layer = RandomEqualize(p=0.0, dtype=dtype) 54 | y = layer(x) 55 | 56 | self.assertAllClose(y, x, atol=atol) 57 | 58 | def test_shape(self): 59 | # Test channels_last 60 | x = keras.KerasTensor((None, None, None, 3)) 61 | y = RandomEqualize()(x) 62 | self.assertEqual(y.shape, (None, None, None, 3)) 63 | 64 | # Test channels_first 65 | backend.set_image_data_format("channels_first") 66 | x = keras.KerasTensor((None, 3, None, None)) 67 | y = RandomEqualize()(x) 68 | self.assertEqual(y.shape, (None, 3, None, None)) 69 | 70 | # Test static shape 71 | backend.set_image_data_format("channels_last") 72 | x = keras.KerasTensor((None, 32, 32, 3)) 73 | y = RandomEqualize()(x) 74 | self.assertEqual(y.shape, (None, 32, 32, 3)) 75 | 76 | def test_model(self): 77 | layer = RandomEqualize() 78 | inputs = keras.layers.Input(shape=[None, None, 5]) 79 | outputs = layer(inputs) 80 | model = keras.models.Model(inputs, outputs) 81 | self.assertEqual(model.output_shape, (None, None, None, 5)) 82 | 83 | def test_config(self): 84 | x = get_images("float32", "channels_last") 85 | layer = RandomEqualize(p=1.0) 86 | y = layer(x) 87 | 88 | layer = RandomEqualize.from_config(layer.get_config()) 89 | y2 = layer(x) 90 | self.assertAllClose(y, y2) 91 | 92 | def test_tf_data_compatibility(self): 93 | import tensorflow as tf 94 | 95 | layer = RandomEqualize() 96 | x = get_images("float32", "channels_last") 97 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 98 | for output in ds.take(1): 99 | self.assertIsInstance(output, tf.Tensor) 100 | self.assertEqual(output.shape, 
(2, 32, 32, 3)) 101 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_order.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | from keras import saving 6 | from keras.src.utils.backend_utils import in_tf_graph 7 | 8 | from keras_aug._src.backend.dynamic_backend import DynamicBackend 9 | from keras_aug._src.backend.dynamic_backend import DynamicRandomGenerator 10 | from keras_aug._src.keras_aug_export import keras_aug_export 11 | 12 | 13 | @keras_aug_export(parent_path=["keras_aug.layers.composition"]) 14 | @keras.saving.register_keras_serializable(package="keras_aug") 15 | class RandomOrder(keras.Layer): 16 | """Apply a list of transformations in a random order. 17 | 18 | Note that due to implementation limitations, the randomness occurs in a 19 | batch manner. 20 | 21 | Args: 22 | transforms: A list of transformations or a `keras.Layer`. 23 | """ 24 | 25 | def __init__(self, transforms, seed=None, **kwargs): 26 | super().__init__(**kwargs) 27 | self._backend = DynamicBackend(backend.backend()) 28 | self._random_generator = DynamicRandomGenerator( 29 | backend.backend(), seed=seed 30 | ) 31 | self.seed = seed 32 | 33 | # Check 34 | if not isinstance(transforms, (typing.Sequence, keras.Layer)): 35 | raise ValueError( 36 | "`transforms` must be a sequence (e.g. tuple and list) or a " 37 | "`keras.Layer`. 
" 38 | f"Received: transforms={transforms} of type {type(transforms)}" 39 | ) 40 | if isinstance(transforms, keras.Layer): 41 | transforms = [transforms] 42 | self.transforms = list(transforms) 43 | self.total = len(self.transforms) 44 | 45 | self._convert_input_args = False 46 | self._allow_non_tensor_positional_args = True 47 | self.autocast = False 48 | 49 | @property 50 | def backend(self): 51 | return self._backend.backend 52 | 53 | @property 54 | def random_generator(self): 55 | return self._random_generator.random_generator 56 | 57 | def compute_output_shape(self, input_shape): 58 | output_shape = input_shape 59 | for transfrom in self.transforms: 60 | output_shape = transfrom.compute_output_shape(output_shape) 61 | return output_shape 62 | 63 | def get_params(self): 64 | ops = self.backend 65 | random_generator = self.random_generator 66 | 67 | fn_idx = ops.random.shuffle( 68 | ops.numpy.arange(self.total, dtype="int32"), seed=random_generator 69 | ) 70 | return fn_idx 71 | 72 | def __call__(self, inputs, **kwargs): 73 | if in_tf_graph(): 74 | self._set_backend("tensorflow") 75 | try: 76 | outputs = super().__call__(inputs, **kwargs) 77 | finally: 78 | self._reset_backend() 79 | return outputs 80 | else: 81 | return super().__call__(inputs, **kwargs) 82 | 83 | def call(self, inputs): 84 | ops = self.backend 85 | fn_idx = self.get_params() 86 | 87 | outputs = inputs 88 | for i in range(self.total): 89 | idx = fn_idx[i] 90 | outputs = ops.core.switch(idx, self.transforms, outputs) 91 | return outputs 92 | 93 | def get_config(self): 94 | config = super().get_config() 95 | config.update( 96 | { 97 | "transforms": saving.serialize_keras_object(self.transforms), 98 | "seed": self.seed, 99 | } 100 | ) 101 | return config 102 | 103 | @classmethod 104 | def from_config(cls, config, custom_objects=None): 105 | config = config.copy() 106 | config["transforms"] = saving.deserialize_keras_object( 107 | config["transforms"], custom_objects=custom_objects 108 | ) 109 | 
return cls(**config) 110 | 111 | def _set_backend(self, name): 112 | self._backend.set_backend(name) 113 | self._random_generator.set_generator(name) 114 | 115 | def _reset_backend(self): 116 | self._backend.reset() 117 | self._random_generator.reset() 118 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_posterize_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.random_posterize import RandomPosterize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomPosterizeTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | import torch 18 | import torchvision.transforms.v2.functional as TF 19 | from keras.src.backend.torch import convert_to_tensor 20 | 21 | np.random.seed(42) 22 | 23 | # Test channels_last 24 | x = get_images(dtype, "channels_last") 25 | layer = RandomPosterize(4, p=1.0, dtype=dtype) 26 | y = layer(x) 27 | 28 | ref_y = TF.posterize( 29 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), bits=4 30 | ) 31 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 32 | self.assertDType(y, dtype) 33 | self.assertAllClose(y, ref_y) 34 | 35 | # Test channels_first 36 | backend.set_image_data_format("channels_first") 37 | x = get_images(dtype, "channels_first") 38 | layer = RandomPosterize(4, p=1.0, dtype=dtype) 39 | y = layer(x) 40 | 41 | ref_y = TF.posterize(convert_to_tensor(x), bits=4) 42 | self.assertDType(y, dtype) 43 | self.assertAllClose(y, ref_y) 44 | 45 | # Test p=0.0 46 | backend.set_image_data_format("channels_last") 47 | x = 
get_images(dtype, "channels_last") 48 | layer = RandomPosterize(4, p=0.0, dtype=dtype) 49 | y = layer(x) 50 | 51 | self.assertDType(y, dtype) 52 | self.assertAllClose(y, x) 53 | 54 | def test_shape(self): 55 | # Test channels_last 56 | x = keras.KerasTensor((None, None, None, 3)) 57 | y = RandomPosterize(4)(x) 58 | self.assertEqual(y.shape, (None, None, None, 3)) 59 | 60 | # Test channels_first 61 | backend.set_image_data_format("channels_first") 62 | x = keras.KerasTensor((None, 3, None, None)) 63 | y = RandomPosterize(4)(x) 64 | self.assertEqual(y.shape, (None, 3, None, None)) 65 | 66 | # Test static shape 67 | backend.set_image_data_format("channels_last") 68 | x = keras.KerasTensor((None, 32, 32, 3)) 69 | y = RandomPosterize(4)(x) 70 | self.assertEqual(y.shape, (None, 32, 32, 3)) 71 | 72 | def test_model(self): 73 | layer = RandomPosterize(4) 74 | inputs = keras.layers.Input(shape=(None, None, 3)) 75 | outputs = layer(inputs) 76 | model = keras.models.Model(inputs, outputs) 77 | self.assertEqual(model.output_shape, (None, None, None, 3)) 78 | 79 | def test_data_format(self): 80 | # Test channels_last 81 | x = get_images("float32", "channels_last") 82 | layer = RandomPosterize(4) 83 | y = layer(x) 84 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 85 | 86 | # Test channels_first 87 | backend.set_image_data_format("channels_first") 88 | x = get_images("float32", "channels_first") 89 | layer = RandomPosterize(4) 90 | y = layer(x) 91 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 92 | 93 | def test_config(self): 94 | x = get_images("float32", "channels_last") 95 | layer = RandomPosterize(4, p=1.0) 96 | y = layer(x) 97 | 98 | layer = RandomPosterize.from_config(layer.get_config()) 99 | y2 = layer(x) 100 | self.assertAllClose(y, y2) 101 | 102 | def test_tf_data_compatibility(self): 103 | import tensorflow as tf 104 | 105 | layer = RandomPosterize(4) 106 | x = get_images("float32", "channels_last") 107 | ds = 
tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 108 | for output in ds.take(1): 109 | self.assertIsInstance(output, tf.Tensor) 110 | self.assertEqual(output.shape, (2, 32, 32, 3)) 111 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_hsv_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras import ops 6 | from keras.src.testing.test_utils import named_product 7 | 8 | from keras_aug._src.layers.vision.random_hsv import RandomHSV 9 | from keras_aug._src.testing.test_case import TestCase 10 | from keras_aug._src.utils.test_utils import get_images 11 | 12 | 13 | class FixedRandomHSV(RandomHSV): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | # Set to non-None 17 | self.hue = (0.9, 1.1) 18 | self.saturation = (0.9, 1.1) 19 | self.value = (0.9, 1.1) 20 | 21 | def get_params(self, batch_size, images=None, **kwargs): 22 | return dict( 23 | hue_gain=ops.ones([batch_size]) * 0.9, 24 | saturation_gain=ops.ones([batch_size]) * 1.1, 25 | value_gain=ops.ones([batch_size]) * 0.9, 26 | ) 27 | 28 | 29 | class RandomHSVTest(TestCase): 30 | regular_args = dict(hue=0.015, saturation=0.7, value=0.4) 31 | 32 | @parameterized.named_parameters( 33 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 34 | ) 35 | def test_correctness(self, dtype): 36 | np.random.seed(42) 37 | 38 | # Test channels_last 39 | x = get_images(dtype, "channels_last") 40 | layer = FixedRandomHSV(dtype=dtype) 41 | y = layer(x) 42 | 43 | # TODO: Test correctness 44 | self.assertEqual(y.shape, x.shape) 45 | self.assertDType(y, dtype) 46 | 47 | # Test channels_first 48 | backend.set_image_data_format("channels_first") 49 | x = get_images(dtype, "channels_first") 50 | layer = FixedRandomHSV(dtype=dtype) 51 | y = layer(x) 52 | 53 | # 
TODO: Test correctness 54 | self.assertEqual(y.shape, x.shape) 55 | self.assertDType(y, dtype) 56 | 57 | def test_shape(self): 58 | # Test channels_last 59 | x = keras.KerasTensor((None, None, None, 3)) 60 | y = RandomHSV(**self.regular_args)(x) 61 | self.assertEqual(y.shape, (None, None, None, 3)) 62 | 63 | # Test channels_first 64 | backend.set_image_data_format("channels_first") 65 | x = keras.KerasTensor((None, 3, None, None)) 66 | y = RandomHSV(**self.regular_args)(x) 67 | self.assertEqual(y.shape, (None, 3, None, None)) 68 | 69 | # Test static shape 70 | backend.set_image_data_format("channels_last") 71 | x = keras.KerasTensor((None, 32, 32, 3)) 72 | y = RandomHSV(**self.regular_args)(x) 73 | self.assertEqual(y.shape, (None, 32, 32, 3)) 74 | 75 | def test_model(self): 76 | layer = RandomHSV(**self.regular_args) 77 | inputs = keras.layers.Input(shape=(None, None, 3)) 78 | outputs = layer(inputs) 79 | model = keras.models.Model(inputs, outputs) 80 | self.assertEqual(model.output_shape, (None, None, None, 3)) 81 | 82 | def test_data_format(self): 83 | # Test channels_last 84 | x = get_images("float32", "channels_last") 85 | layer = RandomHSV(**self.regular_args) 86 | y = layer(x) 87 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 88 | 89 | # Test channels_first 90 | backend.set_image_data_format("channels_first") 91 | x = get_images("float32", "channels_first") 92 | layer = RandomHSV(**self.regular_args) 93 | y = layer(x) 94 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 95 | 96 | def test_config(self): 97 | x = get_images("float32", "channels_last") 98 | layer = FixedRandomHSV() 99 | y = layer(x) 100 | 101 | layer = FixedRandomHSV.from_config(layer.get_config()) 102 | y2 = layer(x) 103 | self.assertAllClose(y, y2) 104 | 105 | def test_tf_data_compatibility(self): 106 | import tensorflow as tf 107 | 108 | layer = RandomHSV(**self.regular_args) 109 | x = get_images("float32", "channels_last") 110 | ds = 
tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 111 | for output in ds.take(1): 112 | self.assertIsInstance(output, tf.Tensor) 113 | self.assertEqual(output.shape, (2, 32, 32, 3)) 114 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_grayscale_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.random_grayscale import RandomGrayscale 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomGrayscaleTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | import torch 18 | import torchvision.transforms.v2.functional as TF 19 | from keras.src.backend.torch import convert_to_tensor 20 | 21 | np.random.seed(42) 22 | atol = 1e-2 if "bfloat16" in dtype else 1e-6 23 | 24 | # Test channels_last 25 | x = get_images(dtype, "channels_last") 26 | layer = RandomGrayscale(p=1.0, dtype=dtype) 27 | y = layer(x) 28 | 29 | ref_y = TF.rgb_to_grayscale( 30 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 31 | num_output_channels=3, 32 | ) 33 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 34 | self.assertDType(y, dtype) 35 | self.assertAllClose(y, ref_y, atol=atol) 36 | 37 | # Test channels_first 38 | backend.set_image_data_format("channels_first") 39 | x = get_images(dtype, "channels_first") 40 | layer = RandomGrayscale(p=1.0, dtype=dtype) 41 | y = layer(x) 42 | 43 | ref_y = TF.rgb_to_grayscale(convert_to_tensor(x), num_output_channels=3) 44 | self.assertDType(y, dtype) 45 | self.assertAllClose(y, ref_y, atol=atol) 46 | 47 | # Test p=0.0 48 | 
backend.set_image_data_format("channels_last") 49 | x = get_images(dtype, "channels_last") 50 | layer = RandomGrayscale(p=0.0, dtype=dtype) 51 | y = layer(x) 52 | 53 | self.assertDType(y, dtype) 54 | self.assertAllClose(y, x) 55 | 56 | def test_shape(self): 57 | # Test channels_last 58 | x = keras.KerasTensor((None, None, None, 3)) 59 | y = RandomGrayscale()(x) 60 | self.assertEqual(y.shape, (None, None, None, 3)) 61 | 62 | # Test channels_first 63 | backend.set_image_data_format("channels_first") 64 | x = keras.KerasTensor((None, 3, None, None)) 65 | y = RandomGrayscale()(x) 66 | self.assertEqual(y.shape, (None, 3, None, None)) 67 | 68 | # Test static shape 69 | backend.set_image_data_format("channels_last") 70 | x = keras.KerasTensor((None, 32, 32, 3)) 71 | y = RandomGrayscale()(x) 72 | self.assertEqual(y.shape, (None, 32, 32, 3)) 73 | 74 | def test_model(self): 75 | layer = RandomGrayscale() 76 | inputs = keras.layers.Input(shape=(None, None, 3)) 77 | outputs = layer(inputs) 78 | model = keras.models.Model(inputs, outputs) 79 | self.assertEqual(model.output_shape, (None, None, None, 3)) 80 | 81 | def test_data_format(self): 82 | # Test channels_last 83 | x = get_images("float32", "channels_last") 84 | layer = RandomGrayscale() 85 | y = layer(x) 86 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 87 | 88 | # Test channels_first 89 | backend.set_image_data_format("channels_first") 90 | x = get_images("float32", "channels_first") 91 | layer = RandomGrayscale() 92 | y = layer(x) 93 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 94 | 95 | def test_config(self): 96 | x = get_images("float32", "channels_last") 97 | layer = RandomGrayscale(p=1.0) 98 | y = layer(x) 99 | 100 | layer = RandomGrayscale.from_config(layer.get_config()) 101 | y2 = layer(x) 102 | self.assertAllClose(y, y2) 103 | 104 | def test_tf_data_compatibility(self): 105 | import tensorflow as tf 106 | 107 | layer = RandomGrayscale() 108 | x = get_images("float32", "channels_last") 109 | ds = 
tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 110 | for output in ds.take(1): 111 | self.assertIsInstance(output, tf.Tensor) 112 | self.assertEqual(output.shape, (2, 32, 32, 3)) 113 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_solarize_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.random_solarize import RandomSolarize 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class RandomSolarizeTest(TestCase): 13 | @parameterized.named_parameters( 14 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 15 | ) 16 | def test_correctness(self, dtype): 17 | import torch 18 | import torchvision.transforms.v2.functional as TF 19 | from keras.src.backend.torch import convert_to_tensor 20 | 21 | if "float" in dtype: 22 | threshold = 0.5 23 | elif dtype == "uint8": 24 | threshold = 127 25 | np.random.seed(42) 26 | 27 | # Test channels_last 28 | x = get_images(dtype, "channels_last") 29 | layer = RandomSolarize(threshold, p=1.0, dtype=dtype) 30 | y = layer(x) 31 | 32 | ref_y = TF.solarize( 33 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 34 | threshold=threshold, 35 | ) 36 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 37 | self.assertDType(y, dtype) 38 | self.assertAllClose(y, ref_y) 39 | 40 | # Test channels_first 41 | backend.set_image_data_format("channels_first") 42 | x = get_images(dtype, "channels_first") 43 | layer = RandomSolarize(threshold, p=1.0, dtype=dtype) 44 | y = layer(x) 45 | 46 | ref_y = TF.solarize(convert_to_tensor(x), threshold=threshold) 47 | self.assertDType(y, dtype) 48 | self.assertAllClose(y, ref_y) 49 | 50 | # Test p=0.0 51 
| backend.set_image_data_format("channels_last") 52 | x = get_images(dtype, "channels_last") 53 | layer = RandomSolarize(threshold, p=0.0, dtype=dtype) 54 | y = layer(x) 55 | 56 | self.assertDType(y, dtype) 57 | self.assertAllClose(y, x) 58 | 59 | def test_shape(self): 60 | # Test channels_last 61 | x = keras.KerasTensor((None, None, None, 3)) 62 | y = RandomSolarize(0.5)(x) 63 | self.assertEqual(y.shape, (None, None, None, 3)) 64 | 65 | # Test channels_first 66 | backend.set_image_data_format("channels_first") 67 | x = keras.KerasTensor((None, 3, None, None)) 68 | y = RandomSolarize(0.5)(x) 69 | self.assertEqual(y.shape, (None, 3, None, None)) 70 | 71 | # Test static shape 72 | backend.set_image_data_format("channels_last") 73 | x = keras.KerasTensor((None, 32, 32, 3)) 74 | y = RandomSolarize(0.5)(x) 75 | self.assertEqual(y.shape, (None, 32, 32, 3)) 76 | 77 | def test_model(self): 78 | layer = RandomSolarize(0.5) 79 | inputs = keras.layers.Input(shape=(None, None, 3)) 80 | outputs = layer(inputs) 81 | model = keras.models.Model(inputs, outputs) 82 | self.assertEqual(model.output_shape, (None, None, None, 3)) 83 | 84 | def test_data_format(self): 85 | # Test channels_last 86 | x = get_images("float32", "channels_last") 87 | layer = RandomSolarize(0.5) 88 | y = layer(x) 89 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 90 | 91 | # Test channels_first 92 | backend.set_image_data_format("channels_first") 93 | x = get_images("float32", "channels_first") 94 | layer = RandomSolarize(0.5) 95 | y = layer(x) 96 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 97 | 98 | def test_config(self): 99 | x = get_images("float32", "channels_last") 100 | layer = RandomSolarize(0.5, p=1.0) 101 | y = layer(x) 102 | 103 | layer = RandomSolarize.from_config(layer.get_config()) 104 | y2 = layer(x) 105 | self.assertAllClose(y, y2) 106 | 107 | def test_tf_data_compatibility(self): 108 | import tensorflow as tf 109 | 110 | layer = RandomSolarize(0.5) 111 | x = get_images("float32", 
"channels_last") 112 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 113 | for output in ds.take(1): 114 | self.assertIsInstance(output, tf.Tensor) 115 | self.assertEqual(output.shape, (2, 32, 32, 3)) 116 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_channel_permutation_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras import ops 6 | from keras.src.testing.test_utils import named_product 7 | 8 | from keras_aug._src.layers.vision.random_channel_permutation import ( 9 | RandomChannelPermutation, 10 | ) 11 | from keras_aug._src.testing.test_case import TestCase 12 | from keras_aug._src.utils.test_utils import get_images 13 | 14 | 15 | class FixedRandomChannelPermutation(RandomChannelPermutation): 16 | def get_params(self, batch_size, images=None, **kwargs): 17 | zeros = ops.zeros([batch_size], "int32") 18 | ones = ops.ones([batch_size], "int32") 19 | twos = ops.ones([batch_size], "int32") * 2 20 | return ops.stack([ones, twos, zeros], axis=-1) 21 | 22 | 23 | class RandomChannelPermutationTest(TestCase): 24 | @parameterized.named_parameters( 25 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 26 | ) 27 | def test_correctness(self, dtype): 28 | import torch 29 | import torchvision.transforms.v2.functional as TF 30 | from keras.src.backend.torch import convert_to_tensor 31 | 32 | np.random.seed(42) 33 | 34 | # Test channels_last 35 | x = get_images(dtype, "channels_last") 36 | layer = FixedRandomChannelPermutation(3, dtype=dtype) 37 | y = layer(x) 38 | 39 | ref_y = TF.permute_channels( 40 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), [1, 2, 0] 41 | ) 42 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 43 | self.assertDType(y, dtype) 44 | self.assertAllClose(y, ref_y) 45 | 46 | # Test channels_first 47 
| backend.set_image_data_format("channels_first") 48 | x = get_images(dtype, "channels_first") 49 | layer = FixedRandomChannelPermutation(3, dtype=dtype) 50 | y = layer(x) 51 | 52 | ref_y = TF.permute_channels(convert_to_tensor(x), [1, 2, 0]) 53 | self.assertDType(y, dtype) 54 | self.assertAllClose(y, ref_y) 55 | 56 | def test_shape(self): 57 | # Test channels_last 58 | x = keras.KerasTensor((None, None, None, 3)) 59 | y = RandomChannelPermutation(3)(x) 60 | self.assertEqual(y.shape, (None, None, None, 3)) 61 | 62 | # Test channels_first 63 | backend.set_image_data_format("channels_first") 64 | x = keras.KerasTensor((None, 3, None, None)) 65 | y = RandomChannelPermutation(3)(x) 66 | self.assertEqual(y.shape, (None, 3, None, None)) 67 | 68 | # Test static shape 69 | backend.set_image_data_format("channels_last") 70 | x = keras.KerasTensor((None, 32, 32, 3)) 71 | y = RandomChannelPermutation(3)(x) 72 | self.assertEqual(y.shape, (None, 32, 32, 3)) 73 | 74 | def test_model(self): 75 | layer = RandomChannelPermutation(3) 76 | inputs = keras.layers.Input(shape=(None, None, 3)) 77 | outputs = layer(inputs) 78 | model = keras.models.Model(inputs, outputs) 79 | self.assertEqual(model.output_shape, (None, None, None, 3)) 80 | 81 | def test_data_format(self): 82 | # Test channels_last 83 | x = get_images("float32", "channels_last") 84 | layer = RandomChannelPermutation(3) 85 | y = layer(x) 86 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 87 | 88 | # Test channels_first 89 | backend.set_image_data_format("channels_first") 90 | x = get_images("float32", "channels_first") 91 | layer = RandomChannelPermutation(3) 92 | y = layer(x) 93 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 94 | 95 | def test_config(self): 96 | x = get_images("float32", "channels_last") 97 | layer = FixedRandomChannelPermutation(3) 98 | y = layer(x) 99 | 100 | layer = FixedRandomChannelPermutation.from_config(layer.get_config()) 101 | y2 = layer(x) 102 | self.assertAllClose(y, y2) 103 | 104 | def 
test_tf_data_compatibility(self): 105 | import tensorflow as tf 106 | 107 | layer = RandomChannelPermutation(3) 108 | x = get_images("float32", "channels_last") 109 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 110 | for output in ds.take(1): 111 | self.assertIsInstance(output, tf.Tensor) 112 | self.assertEqual(output.shape, (2, 32, 32, 3)) 113 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | 6 | from keras_aug._src.keras_aug_export import keras_aug_export 7 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 8 | from keras_aug._src.utils.argument_validation import standardize_parameter 9 | 10 | 11 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 12 | @keras.saving.register_keras_serializable(package="keras_aug") 13 | class GaussianBlur(VisionRandomLayer): 14 | """Blurs the input images with randomly chosen Gaussian blur kernel. 15 | 16 | The convolution will be using 'reflect' padding corresponding to the 17 | kernel size, to maintain the input shape. 18 | 19 | Note that due to implementation limitations, a single sampled `sigma` will 20 | be applied to the entire batch of images. 21 | 22 | Args: 23 | kernel_size: An int or a sequence of ints specifying the size of the 24 | Gaussian kernel in x and y directions. The values should be odd and 25 | positive numbers. 26 | sigma: A float or a sequence of floats specifying standard deviation to 27 | be used for the Gaussian kernel. If float, sigma is fixed. If a 28 | sequence of floats, sigma is sampled uniformly in the given range. 29 | data_format: A string specifying the data format of the input images. 30 | It can be either `"channels_last"` or `"channels_first"`. 
31 | If not specified, the value will be interpreted by 32 | `keras.config.image_data_format`. Defaults to `None`. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | kernel_size: typing.Union[int, typing.Sequence[int]], 38 | sigma: typing.Union[float, typing.Sequence[float]] = (0.1, 2.0), 39 | data_format=None, 40 | **kwargs, 41 | ): 42 | super().__init__(**kwargs) 43 | if isinstance(kernel_size, int): 44 | kernel_size = (kernel_size, kernel_size) 45 | if isinstance(sigma, (int, float)): 46 | sigma = (float(sigma), float(sigma)) 47 | self.kernel_size = tuple(kernel_size) 48 | self.sigma = tuple(sigma) 49 | self.data_format = data_format or keras.config.image_data_format() 50 | 51 | if len(kernel_size) != 2: 52 | raise ValueError( 53 | "The length of `kernel_size` should be 2. " 54 | f"Received: kernel_size={kernel_size}" 55 | ) 56 | for ks in kernel_size: 57 | if ks <= 0 or ks % 2 == 0: 58 | raise ValueError( 59 | "The values of `kernel_size` should be odd and positive." 60 | ) 61 | standardize_parameter( 62 | self.sigma, 63 | "sigma", 64 | bound=(0.0, float("inf")), 65 | allow_none=False, 66 | allow_single_number=False, 67 | ) 68 | 69 | def get_params(self, batch_size, images=None, **kwargs): 70 | ops = self.backend 71 | random_generator = self.random_generator 72 | 73 | dtype = backend.result_type(images.dtype, float) 74 | sigma = ops.random.uniform( 75 | [1], 76 | self.sigma[0], 77 | self.sigma[1], 78 | dtype=dtype, 79 | seed=random_generator, 80 | ) 81 | return sigma[0] 82 | 83 | def compute_output_shape(self, input_shape): 84 | return input_shape 85 | 86 | def augment_images(self, images, transformations=None, **kwargs): 87 | sigma = transformations 88 | images = self.image_backend.guassian_blur( 89 | images, 90 | self.kernel_size, 91 | [sigma, sigma], 92 | data_format=self.data_format, 93 | ) 94 | return images 95 | 96 | def augment_labels(self, labels, transformations, **kwargs): 97 | return labels 98 | 99 | def augment_bounding_boxes(self, bounding_boxes, 
transformations, **kwargs): 100 | return bounding_boxes 101 | 102 | def augment_segmentation_masks( 103 | self, segmentation_masks, transformations, **kwargs 104 | ): 105 | return segmentation_masks 106 | 107 | def augment_keypoints(self, keypoints, transformations, **kwargs): 108 | return keypoints 109 | 110 | def get_config(self): 111 | config = super().get_config() 112 | config.update({"kernel_size": self.kernel_size, "sigma": self.sigma}) 113 | return config 114 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_apply.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | from keras import saving 6 | from keras.src.utils.backend_utils import in_tf_graph 7 | 8 | from keras_aug._src.backend.dynamic_backend import DynamicBackend 9 | from keras_aug._src.backend.dynamic_backend import DynamicRandomGenerator 10 | from keras_aug._src.keras_aug_export import keras_aug_export 11 | 12 | 13 | @keras_aug_export(parent_path=["keras_aug.layers.composition"]) 14 | @keras.saving.register_keras_serializable(package="keras_aug") 15 | class RandomApply(keras.Layer): 16 | """Apply randomly a list of transformations with a given probability. 17 | 18 | Note that due to implementation limitations, the randomness occurs in a 19 | batch manner. 20 | 21 | Args: 22 | transforms: A list of transformations or a `keras.Layer`. 23 | p: A float specifying the probability. Defaults to `0.5`. 
24 | """ 25 | 26 | def __init__(self, transforms, p: float = 0.5, seed=None, **kwargs): 27 | super().__init__(**kwargs) 28 | self._backend = DynamicBackend(backend.backend()) 29 | self._random_generator = DynamicRandomGenerator( 30 | backend.backend(), seed=seed 31 | ) 32 | self.seed = seed 33 | 34 | # Check 35 | if not isinstance(transforms, (typing.Sequence, keras.Layer)): 36 | raise ValueError( 37 | "`transforms` must be a sequence (e.g. tuple and list) or a " 38 | "`keras.Layer`. " 39 | f"Received: transforms={transforms} of type {type(transforms)}" 40 | ) 41 | if isinstance(transforms, keras.Layer): 42 | transforms = [transforms] 43 | self.transforms = list(transforms) 44 | self.p = float(p) 45 | 46 | self._convert_input_args = False 47 | self._allow_non_tensor_positional_args = True 48 | self.autocast = False 49 | 50 | @property 51 | def backend(self): 52 | return self._backend.backend 53 | 54 | @property 55 | def random_generator(self): 56 | return self._random_generator.random_generator 57 | 58 | def compute_output_shape(self, input_shape): 59 | output_shape = input_shape 60 | for transfrom in self.transforms: 61 | output_shape = transfrom.compute_output_shape(output_shape) 62 | if output_shape != input_shape: 63 | raise ValueError( 64 | "The output shape must be the same as input shape. 
" 65 | f"Received: input_shape={input_shape}, " 66 | f"output_shape={output_shape}" 67 | ) 68 | return output_shape 69 | 70 | def get_params(self): 71 | ops = self.backend 72 | random_generator = self.random_generator 73 | 74 | p = ops.random.uniform((1,), seed=random_generator) 75 | return p[0] 76 | 77 | def _apply_transforms(self, inputs): 78 | for layer in self.transforms: 79 | inputs = layer(inputs) 80 | return inputs 81 | 82 | def __call__(self, inputs, **kwargs): 83 | if in_tf_graph(): 84 | self._set_backend("tensorflow") 85 | try: 86 | outputs = super().__call__(inputs, **kwargs) 87 | finally: 88 | self._reset_backend() 89 | return outputs 90 | else: 91 | return super().__call__(inputs, **kwargs) 92 | 93 | def call(self, inputs): 94 | ops = self.backend 95 | p = self.get_params() 96 | 97 | ori_inputs = inputs 98 | if isinstance(inputs, dict): 99 | inputs = inputs.copy() 100 | 101 | outputs = ops.core.cond( 102 | p < self.p, 103 | lambda: self._apply_transforms(inputs), 104 | lambda: ori_inputs, 105 | ) 106 | return outputs 107 | 108 | def get_config(self): 109 | config = super().get_config() 110 | config.update( 111 | { 112 | "transforms": saving.serialize_keras_object(self.transforms), 113 | "p": self.p, 114 | "seed": self.seed, 115 | } 116 | ) 117 | return config 118 | 119 | @classmethod 120 | def from_config(cls, config, custom_objects=None): 121 | config = config.copy() 122 | config["transforms"] = saving.deserialize_keras_object( 123 | config["transforms"], custom_objects=custom_objects 124 | ) 125 | return cls(**config) 126 | 127 | def _set_backend(self, name): 128 | self._backend.set_backend(name) 129 | self._random_generator.set_generator(name) 130 | 131 | def _reset_backend(self): 132 | self._backend.reset() 133 | self._random_generator.reset() 134 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_sharpen_test.py: 
-------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from keras_aug._src.layers.vision.random_sharpen import RandomSharpen 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | from keras_aug._src.utils.test_utils import uses_gpu 11 | 12 | 13 | class RandomSharpenTest(TestCase): 14 | @parameterized.named_parameters( 15 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 16 | ) 17 | def test_correctness(self, dtype): 18 | import torch 19 | import torchvision.transforms.v2.functional as TF 20 | from keras.src.backend.torch import convert_to_tensor 21 | 22 | if dtype == "float32": 23 | atol = 1e-6 24 | elif "bfloat16" in dtype: 25 | atol = 1e-2 26 | elif dtype == "uint8": 27 | atol = 1e-6 28 | np.random.seed(42) 29 | 30 | # Test channels_last 31 | x = get_images(dtype, "channels_last") 32 | layer = RandomSharpen(2.0, p=1.0, dtype=dtype) 33 | y = layer(x) 34 | 35 | ref_y = TF.adjust_sharpness( 36 | convert_to_tensor(np.transpose(x, [0, 3, 1, 2])), 37 | sharpness_factor=2.0, 38 | ) 39 | ref_y = torch.permute(ref_y, (0, 2, 3, 1)) 40 | self.assertDType(y, dtype) 41 | self.assertAllClose(y, ref_y, atol=atol) 42 | 43 | # Test channels_first 44 | if backend.backend() == "tensorflow" and not uses_gpu(): 45 | self.skipTest("Tensorflow CPU doesn't support `RandomSharpen`") 46 | backend.set_image_data_format("channels_first") 47 | x = get_images(dtype, "channels_first") 48 | layer = RandomSharpen(2.0, p=1.0, dtype=dtype) 49 | y = layer(x) 50 | 51 | ref_y = TF.adjust_sharpness(convert_to_tensor(x), sharpness_factor=2.0) 52 | self.assertDType(y, dtype) 53 | self.assertAllClose(y, ref_y, atol=atol) 54 | 55 | # Test p=0.0 56 | backend.set_image_data_format("channels_last") 57 | x = get_images(dtype, 
"channels_last") 58 | layer = RandomSharpen(2.0, p=0.0, dtype=dtype) 59 | y = layer(x) 60 | 61 | self.assertDType(y, dtype) 62 | self.assertAllClose(y, x) 63 | 64 | def test_shape(self): 65 | # Test channels_last 66 | x = keras.KerasTensor((None, None, None, 3)) 67 | y = RandomSharpen(2.0)(x) 68 | self.assertEqual(y.shape, (None, None, None, 3)) 69 | 70 | # Test channels_first 71 | backend.set_image_data_format("channels_first") 72 | x = keras.KerasTensor((None, 3, None, None)) 73 | y = RandomSharpen(2.0)(x) 74 | self.assertEqual(y.shape, (None, 3, None, None)) 75 | 76 | # Test static shape 77 | backend.set_image_data_format("channels_last") 78 | x = keras.KerasTensor((None, 32, 32, 3)) 79 | y = RandomSharpen(2.0)(x) 80 | self.assertEqual(y.shape, (None, 32, 32, 3)) 81 | 82 | def test_model(self): 83 | layer = RandomSharpen(2.0) 84 | inputs = keras.layers.Input(shape=(None, None, 3)) 85 | outputs = layer(inputs) 86 | model = keras.models.Model(inputs, outputs) 87 | self.assertEqual(model.output_shape, (None, None, None, 3)) 88 | 89 | def test_data_format(self): 90 | # Test channels_last 91 | x = get_images("float32", "channels_last") 92 | layer = RandomSharpen(2.0) 93 | y = layer(x) 94 | self.assertEqual(tuple(y.shape), (2, 32, 32, 3)) 95 | 96 | # Test channels_first 97 | if backend.backend() == "tensorflow" and not uses_gpu(): 98 | self.skipTest("Tensorflow CPU doesn't support `RandomSharpen`") 99 | backend.set_image_data_format("channels_first") 100 | x = get_images("float32", "channels_first") 101 | layer = RandomSharpen(2.0) 102 | y = layer(x) 103 | self.assertEqual(tuple(y.shape), (2, 3, 32, 32)) 104 | 105 | def test_config(self): 106 | x = get_images("float32", "channels_last") 107 | layer = RandomSharpen(2.0, p=1.0) 108 | y = layer(x) 109 | 110 | layer = RandomSharpen.from_config(layer.get_config()) 111 | y2 = layer(x) 112 | self.assertAllClose(y, y2) 113 | 114 | def test_tf_data_compatibility(self): 115 | import tensorflow as tf 116 | 117 | layer = 
RandomSharpen(2.0) 118 | x = get_images("float32", "channels_last") 119 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 120 | for output in ds.take(1): 121 | self.assertIsInstance(output, tf.Tensor) 122 | self.assertEqual(output.shape, (2, 32, 32, 3)) 123 | -------------------------------------------------------------------------------- /keras_aug/_src/visualization/draw_bounding_boxes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from keras import backend 3 | from keras import ops 4 | 5 | from keras_aug._src import ops as ka_ops 6 | from keras_aug._src.keras_aug_export import keras_aug_export 7 | 8 | 9 | @keras_aug_export(parent_path=["keras_aug.visualization"]) 10 | def draw_bounding_boxes( 11 | images, 12 | bounding_boxes, 13 | bounding_box_format, 14 | class_mapping=None, 15 | color_mapping=None, 16 | thickness=1, 17 | font_scale=1.0, 18 | data_format=None, 19 | ): 20 | try: 21 | import cv2 22 | except ImportError: 23 | raise ImportError( 24 | "Cannot import OpenCV. You can install it by " 25 | "`pip install opencv-python`." 26 | ) 27 | class_mapping = class_mapping or {} 28 | if len(class_mapping) > 0: 29 | num_classes = len(class_mapping) 30 | else: 31 | num_classes = 80 # Defaults to 80 (COCO) 32 | if color_mapping is None: 33 | color_mapping = {} 34 | for i, color in enumerate(_generate_color_palette(num_classes)): 35 | color_mapping[i] = color 36 | thickness = int(thickness) 37 | data_format = data_format or backend.image_data_format() 38 | images_shape = ops.shape(images) 39 | if len(images_shape) != 4: 40 | raise ValueError( 41 | "`images` must be batched 4D tensor. " 42 | f"Received: images.shape={images_shape}" 43 | ) 44 | if not isinstance(bounding_boxes, dict): 45 | raise TypeError( 46 | "`bounding_boxes` should be a dict. 
" 47 | f"Received: bounding_boxes={bounding_boxes} of type " 48 | f"{type(bounding_boxes)}" 49 | ) 50 | if "boxes" not in bounding_boxes or "classes" not in bounding_boxes: 51 | raise ValueError( 52 | "`bounding_boxes` should be a dict containing 'boxes' and " 53 | f"'classes' keys. Received: bounding_boxes={bounding_boxes}" 54 | ) 55 | if data_format == "channels_last": 56 | h_axis = -3 57 | w_axis = -2 58 | else: 59 | h_axis = -2 60 | w_axis = -1 61 | height = images_shape[h_axis] 62 | width = images_shape[w_axis] 63 | bounding_boxes = bounding_boxes.copy() 64 | bounding_boxes = ka_ops.bounding_box.convert_format( 65 | bounding_boxes, bounding_box_format, "xyxy", height, width 66 | ) 67 | 68 | # To numpy array 69 | images = ka_ops.image.transform_dtype(images, images.dtype, "uint8") 70 | images = ops.convert_to_numpy(images) 71 | boxes = ops.convert_to_numpy(bounding_boxes["boxes"]) 72 | classes = ops.convert_to_numpy(bounding_boxes["classes"]) 73 | if "confidences" in bounding_boxes: 74 | confidences = ops.convert_to_numpy(bounding_boxes["confidences"]) 75 | else: 76 | confidences = None 77 | 78 | result = [] 79 | batch_size = images.shape[0] 80 | for i in range(batch_size): 81 | _image = images[i] 82 | _box = boxes[i] 83 | _class = classes[i] 84 | for box_i in range(_box.shape[0]): 85 | x1, y1, x2, y2 = _box[box_i].astype("int32") 86 | c = _class[box_i].astype("int32") 87 | if c == -1: 88 | continue 89 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 90 | c = int(c) 91 | color = color_mapping[c % num_classes] 92 | 93 | # Draw bounding box 94 | cv2.rectangle(_image, (x1, y1), (x2, y2), color, thickness) 95 | 96 | if c in class_mapping: 97 | label = class_mapping[c] 98 | if confidences is not None: 99 | conf = confidences[i][box_i] 100 | label = f"{label} | {conf:.2f}" 101 | 102 | font_x1, font_y1 = _find_text_location( 103 | x1, y1, font_scale, thickness 104 | ) 105 | cv2.putText( 106 | _image, 107 | label, 108 | (font_x1, font_y1), 109 | 
cv2.FONT_HERSHEY_SIMPLEX, 110 | font_scale, 111 | color, 112 | thickness, 113 | ) 114 | result.append(_image) 115 | return np.stack(result, axis=0) 116 | 117 | 118 | def _find_text_location(x, y, font_scale, thickness): 119 | font_height = int(font_scale * 12) 120 | target_y = y - 8 121 | if target_y - (2 * font_height) > 0: 122 | return x, y - 8 123 | 124 | line_offset = thickness 125 | static_offset = 3 126 | 127 | return ( 128 | x + static_offset, 129 | y + (2 * font_height) + line_offset + static_offset, 130 | ) 131 | 132 | 133 | def _generate_color_palette(num_classes: int): 134 | palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) 135 | return [((i * palette) % 255).tolist() for i in range(num_classes)] 136 | -------------------------------------------------------------------------------- /keras_aug/_src/utils/argument_validation.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | from collections.abc import Sequence 3 | 4 | from keras import backend 5 | 6 | 7 | def standardize_parameter( 8 | parameter, 9 | name="parameter", 10 | center=0.0, 11 | bound=None, 12 | allow_none=True, 13 | allow_single_number=True, 14 | ): 15 | if parameter is None and not allow_none: 16 | raise ValueError(f"`{name}` cannot be `None`") 17 | if parameter is None and allow_none: 18 | return parameter 19 | 20 | if not isinstance(parameter, Sequence) and not allow_single_number: 21 | raise ValueError( 22 | f"`{name}` cannot be a single number." 23 | f"Received: {name}={parameter}" 24 | ) 25 | if not isinstance(parameter, Sequence): 26 | parameter = abs(parameter) 27 | parameter = (center - parameter, center + parameter) 28 | elif len(parameter) > 2: 29 | raise ValueError( 30 | f"`{name}` must be a sequence of 2 values. " 31 | f"Received: {name}={parameter}" 32 | ) 33 | if parameter[0] > parameter[1]: 34 | raise ValueError( 35 | f"`{name}` must be in the order that first element is bigger " 36 | f"that second element. 
Received: {name}={parameter}" 37 | ) 38 | if bound is not None: 39 | if parameter[0] < bound[0] or parameter[1] > bound[1]: 40 | raise ValueError( 41 | f"{name} is out of bounds `[{bound[0]}, {bound[1]}]`. " 42 | f"Received: {name}={parameter}" 43 | ) 44 | return tuple(parameter) 45 | 46 | 47 | def standardize_value_range(value_range): 48 | if not isinstance(value_range, Sequence) or len(value_range) != 2: 49 | raise ValueError( 50 | "`value_range` must be a sequence of numbers. " 51 | f"Received: value_range={value_range}" 52 | ) 53 | if value_range[0] > value_range[1]: 54 | raise ValueError( 55 | "`value_range` must be in the order that first element is bigger " 56 | f"that second element. Received: value_range={value_range}" 57 | ) 58 | return tuple(value_range) 59 | 60 | 61 | def standardize_size(size): 62 | if isinstance(size, numbers.Number): 63 | return int(size), int(size) 64 | if isinstance(size, Sequence) and len(size) == 1: 65 | return int(size[0]), int(size[0]) 66 | if len(size) != 2: 67 | raise ValueError( 68 | "`size` must be a single integer or the sequence of 2 " 69 | f"numbers. Received: size={size}" 70 | ) 71 | return int(size[0]), int(size[1]) 72 | 73 | 74 | def standardize_interpolation(interpolation): 75 | if isinstance(interpolation, str): 76 | interpolation = interpolation.lower() 77 | if interpolation not in ("nearest", "bilinear", "bicubic"): 78 | raise ValueError( 79 | "Invalid `interpolation`. Available values are 'nearest', " 80 | "'bilinear' and 'bicubic'. " 81 | f"Received: interpolation={interpolation}" 82 | ) 83 | return interpolation 84 | else: 85 | raise ValueError( 86 | "`interpolation` must be `str`. " 87 | f"Received: interpolation={interpolation} of type " 88 | f"{type(interpolation)}" 89 | ) 90 | 91 | 92 | def standardize_padding_mode(padding_mode): 93 | available_padding_mode = ("constant", "reflect", "symmetric") 94 | if padding_mode not in available_padding_mode: 95 | raise ValueError( 96 | "Invalid `padding_mode`. 
Available values are: " 97 | f"{list(available_padding_mode)}. " 98 | f"Received: padding_mode={padding_mode}" 99 | ) 100 | return padding_mode 101 | 102 | 103 | def standardize_bbox_format(bounding_box_format): 104 | if bounding_box_format is None: 105 | return bounding_box_format 106 | available_bounding_box_format = ( 107 | "xyxy", 108 | "xywh", 109 | "center_xywh", 110 | "rel_xyxy", 111 | "rel_xywh", 112 | "rel_center_xywh", 113 | ) 114 | if bounding_box_format not in available_bounding_box_format: 115 | raise ValueError( 116 | "Invalid `bounding_box_format`. Available values are: " 117 | f"{list(available_bounding_box_format)}. " 118 | f"Received: bounding_box_format={bounding_box_format}" 119 | ) 120 | return bounding_box_format 121 | 122 | 123 | def standardize_data_format(data_format): 124 | if data_format is None: 125 | data_format = backend.image_data_format() 126 | if data_format not in ("channels_last", "channels_first"): 127 | raise ValueError( 128 | "Invalid `data_format`. Available values are: " 129 | f"['channels_last', 'channels_first']. " 130 | f"Received: data_format={data_format}" 131 | ) 132 | return data_format 133 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/mix_up.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | 6 | from keras_aug._src.keras_aug_export import keras_aug_export 7 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 8 | from keras_aug._src.utils.argument_validation import standardize_data_format 9 | 10 | 11 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 12 | @keras.saving.register_keras_serializable(package="keras_aug") 13 | class MixUp(VisionRandomLayer): 14 | """Apply MixUp to the provided batch of images and labels. 15 | 16 | Note that `MixUp` is meant to be used on batches of inputs, not individual 17 | input. 
The sample pairing is deterministic and done by matching consecutive 18 | samples in the batch, so the batch needs to be shuffled. 19 | 20 | Typically, `MixUp` expects the `labels` to be one-hot-encoded format. If 21 | they are not, with provided `num_classes`, this layer will transform the 22 | `labels` into one-hot-encoded format. (e.g. `(batch_size, num_classes)`) 23 | 24 | References: 25 | - [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412) 26 | 27 | Args: 28 | alpha: The hyperparameter of the beta distribution used for cutmix. 29 | Defaults to `1.0`. 30 | num_classes: The number of classes in the inputs. Used for one-hot 31 | encoding. Can be `None` if the labels are already one-hot-encoded. 32 | Defaults to `None`. 33 | data_format: A string specifying the data format of the input images. 34 | It can be either `"channels_last"` or `"channels_first"`. 35 | If not specified, the value will be interpreted by 36 | `keras.config.image_data_format`. Defaults to `None`. 
37 | """ # noqa: E501 38 | 39 | def __init__( 40 | self, 41 | alpha: float = 1.0, 42 | num_classes: typing.Optional[int] = None, 43 | data_format: typing.Optional[str] = None, 44 | **kwargs, 45 | ): 46 | super().__init__(**kwargs) 47 | self.alpha = float(alpha) 48 | self.num_classes = int(num_classes) if num_classes is not None else None 49 | self.data_format = standardize_data_format(data_format) 50 | 51 | if self.data_format == "channels_last": 52 | self.h_axis, self.w_axis = -3, -2 53 | else: 54 | self.h_axis, self.w_axis = -2, -1 55 | 56 | def compute_output_shape(self, input_shape): 57 | return input_shape 58 | 59 | def get_params(self, batch_size, images=None, **kwargs): 60 | ops = self.backend 61 | random_generator = self.random_generator 62 | 63 | dtype = backend.result_type(self.compute_dtype, float) 64 | lam = ops.random.beta( 65 | [batch_size], self.alpha, self.alpha, seed=random_generator 66 | ) 67 | lam = ops.cast(lam, dtype) 68 | return lam 69 | 70 | def augment_images(self, images, transformations, **kwargs): 71 | ops = self.backend 72 | 73 | lam = transformations 74 | original_dtype = backend.standardize_dtype(images.dtype) 75 | images = self.image_backend.transform_dtype( 76 | images, images.dtype, backend.result_type(images.dtype, float) 77 | ) 78 | rolled_images = ops.numpy.roll(images, shift=1, axis=0) 79 | lam = ops.numpy.expand_dims(lam, axis=[1, 2, 3]) 80 | images = ops.numpy.add( 81 | ops.numpy.multiply(rolled_images, 1.0 - lam), 82 | ops.numpy.multiply(images, lam), 83 | ) 84 | images = self.image_backend.transform_dtype( 85 | images, images.dtype, original_dtype 86 | ) 87 | return images 88 | 89 | def augment_labels(self, labels, transformations, **kwargs): 90 | ops = self.backend 91 | 92 | lam = transformations 93 | compute_dtype = backend.result_type(labels.dtype, float) 94 | labels_ndim = len(ops.shape(labels)) 95 | if labels_ndim == 1: 96 | if self.num_classes is None: 97 | raise ValueError( 98 | "If `labels` is not one-hot-encoded, 
you must provide " 99 | "`num_classes` in the constructor. " 100 | f"Received: num_classes={self.num_classes}" 101 | ) 102 | labels = ops.nn.one_hot( 103 | labels, self.num_classes, axis=-1, dtype=compute_dtype 104 | ) 105 | labels = ops.cast(labels, compute_dtype) 106 | rolled_labels = ops.numpy.roll(labels, shift=1, axis=0) 107 | lam = ops.numpy.expand_dims(lam, axis=-1) 108 | labels = ops.numpy.add( 109 | ops.numpy.multiply(rolled_labels, 1.0 - lam), 110 | ops.numpy.multiply(labels, lam), 111 | ) 112 | return labels 113 | 114 | def get_config(self): 115 | config = super().get_config() 116 | config.update({"alpha": self.alpha, "num_classes": self.num_classes}) 117 | return config 118 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/max_bounding_box_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import ml_dtypes 3 | import numpy as np 4 | import pytest 5 | from absl.testing import parameterized 6 | from keras.src.testing.test_utils import named_product 7 | 8 | from keras_aug._src.layers.vision.max_bounding_box import MaxBoundingBox 9 | from keras_aug._src.testing.test_case import TestCase 10 | from keras_aug._src.utils.test_utils import get_images 11 | 12 | 13 | class MaxBoundingBoxTest(TestCase): 14 | @parameterized.named_parameters( 15 | named_product(dtype=["float32", "mixed_bfloat16", "uint8"]) 16 | ) 17 | def test_correctness(self, dtype): 18 | bbox_dtype = ml_dtypes.bfloat16 if dtype == "mixed_bfloat16" else dtype 19 | inputs = { 20 | "images": get_images(dtype, "channels_last"), 21 | "bounding_boxes": { 22 | "boxes": np.ones((2, 4, 4)).astype(bbox_dtype), 23 | "classes": np.ones((2, 4)).astype(bbox_dtype), 24 | }, 25 | } 26 | layer = MaxBoundingBox(max_number=8, dtype=dtype) 27 | outputs = layer(inputs) 28 | self.assertDType(outputs["images"], dtype) 29 | self.assertAllClose( 30 | outputs["bounding_boxes"]["boxes"][:, :4, :], 31 | 
inputs["bounding_boxes"]["boxes"], 32 | ) 33 | self.assertAllClose( 34 | outputs["bounding_boxes"]["classes"][:, :4], 35 | inputs["bounding_boxes"]["classes"], 36 | ) 37 | self.assertAllClose( 38 | outputs["bounding_boxes"]["boxes"][:, 4:, :], 39 | np.ones((2, 4, 4)) * -1, 40 | ) 41 | self.assertAllClose( 42 | outputs["bounding_boxes"]["classes"][:, 4:], 43 | np.ones((2, 4)) * -1, 44 | ) 45 | 46 | def test_shape(self): 47 | # Test static shape 48 | x = { 49 | "images": keras.KerasTensor((None, 32, 32, 3)), 50 | "bounding_boxes": { 51 | "boxes": keras.KerasTensor((None, 4, 4)), 52 | "classes": keras.KerasTensor((None, 4)), 53 | }, 54 | } 55 | y = MaxBoundingBox(max_number=8)(x) 56 | self.assertEqual(y["bounding_boxes"]["boxes"].shape, (None, 8, 4)) 57 | self.assertEqual(y["bounding_boxes"]["classes"].shape, (None, 8)) 58 | 59 | @pytest.mark.skip("keras.models.Model doesn't support nested inputs") 60 | def test_model(self): 61 | layer = MaxBoundingBox(max_number=8) 62 | inputs = { 63 | "images": keras.KerasTensor((None, 32, 32, 3)), 64 | "bounding_boxes": { 65 | "boxes": keras.KerasTensor((None, 4, 4)), 66 | "classes": keras.KerasTensor((None, 4)), 67 | }, 68 | } 69 | outputs = layer(inputs) 70 | model = keras.models.Model(inputs, outputs) 71 | self.assertEqual( 72 | model.output_shape["bounding_boxes"]["boxes"].shape, (None, 8, 4) 73 | ) 74 | self.assertEqual( 75 | model.output_shape["bounding_boxes"]["classes"].shape, (None, 8) 76 | ) 77 | 78 | def test_config(self): 79 | x = get_images("float32", "channels_last") 80 | inputs = { 81 | "images": x, 82 | "bounding_boxes": { 83 | "boxes": np.ones((2, 4, 4)), 84 | "classes": np.ones((2, 4)), 85 | }, 86 | } 87 | layer = MaxBoundingBox(max_number=8) 88 | outputs = layer(inputs) 89 | boxes = keras.ops.convert_to_numpy(outputs["bounding_boxes"]["boxes"]) 90 | classes = keras.ops.convert_to_numpy( 91 | outputs["bounding_boxes"]["classes"] 92 | ) 93 | 94 | layer = MaxBoundingBox.from_config(layer.get_config()) 95 | outputs2 
= layer(inputs) 96 | boxes2 = keras.ops.convert_to_numpy(outputs2["bounding_boxes"]["boxes"]) 97 | classes2 = keras.ops.convert_to_numpy( 98 | outputs2["bounding_boxes"]["classes"] 99 | ) 100 | 101 | self.assertAllClose(boxes, boxes2) 102 | self.assertAllClose(classes, classes2) 103 | 104 | def test_tf_data_compatibility(self): 105 | import tensorflow as tf 106 | 107 | layer = MaxBoundingBox(max_number=8) 108 | x = get_images("float32", "channels_last") 109 | ds = tf.data.Dataset.from_tensor_slices(x) 110 | ds = ds.map( 111 | lambda x: { 112 | "images": x, 113 | "bounding_boxes": { 114 | "boxes": np.ones((4, 4)), 115 | "classes": np.ones((4)), 116 | }, 117 | } 118 | ) 119 | ds = ds.batch(2).map(layer) 120 | for output in ds.take(1): 121 | self.assertIsInstance(output["images"], tf.Tensor) 122 | self.assertEqual(output["images"].shape, (2, 32, 32, 3)) 123 | self.assertEqual(output["bounding_boxes"]["boxes"].shape, (2, 8, 4)) 124 | self.assertEqual(output["bounding_boxes"]["classes"].shape, (2, 8)) 125 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_equalize.py: -------------------------------------------------------------------------------- 1 | import keras 2 | from keras import backend 3 | 4 | from keras_aug._src.keras_aug_export import keras_aug_export 5 | from keras_aug._src.layers.base.vision_random_layer import VisionRandomLayer 6 | from keras_aug._src.utils.argument_validation import standardize_data_format 7 | 8 | 9 | @keras_aug_export(parent_path=["keras_aug.layers.vision"]) 10 | @keras.saving.register_keras_serializable(package="keras_aug") 11 | class RandomEqualize(VisionRandomLayer): 12 | """Equalize the histogram of the images randomly with a given probability. 13 | 14 | This class equalizes the histogram of the images by applying a non-linear 15 | mapping in order to create a uniform distribution of grayscale values in 16 | the outputs. 
17 | 18 | Args: 19 | bins: The number of bins to use in histogram equalization. The value 20 | must be in the range of `[0, 256]`. Defaults to `256`. 21 | p: A float specifying the probability. Defaults to `0.5`. 22 | data_format: A string specifying the data format of the input images. 23 | It can be either `"channels_last"` or `"channels_first"`. 24 | If not specified, the value will be interpreted by 25 | `keras.config.image_data_format`. Defaults to `None`. 26 | """ 27 | 28 | def __init__(self, bins=256, p: float = 0.5, data_format=None, **kwargs): 29 | super().__init__(**kwargs) 30 | self.bins = bins 31 | self.p = float(p) 32 | self.data_format = standardize_data_format(data_format) 33 | 34 | def compute_output_shape(self, input_shape): 35 | return input_shape 36 | 37 | def get_params(self, batch_size, images=None, **kwargs): 38 | ops = self.backend 39 | random_generator = self.random_generator 40 | p = ops.random.uniform([batch_size], seed=random_generator) 41 | return p 42 | 43 | def augment_images(self, images, transformations=None, **kwargs): 44 | ops = self.backend 45 | p = transformations 46 | 47 | def equalize(images): 48 | original_dtype = backend.standardize_dtype(images.dtype) 49 | images = self.image_backend.transform_dtype( 50 | images, images.dtype, "uint8" 51 | ) 52 | images = self.image_backend.equalize( 53 | images, self.bins, self.data_format 54 | ) 55 | images = self.image_backend.transform_dtype( 56 | images, images.dtype, original_dtype 57 | ) 58 | return images 59 | 60 | prob = ops.numpy.expand_dims(p < self.p, axis=[1, 2, 3]) 61 | images = ops.numpy.where(prob, equalize(images), images) 62 | return images 63 | 64 | def augment_labels(self, labels, transformations, **kwargs): 65 | return labels 66 | 67 | def augment_bounding_boxes(self, bounding_boxes, transformations, **kwargs): 68 | return bounding_boxes 69 | 70 | def augment_segmentation_masks( 71 | self, segmentation_masks, transformations, **kwargs 72 | ): 73 | return 
segmentation_masks 74 | 75 | def augment_keypoints(self, keypoints, transformations, **kwargs): 76 | return keypoints 77 | 78 | def _equalize_single_image(self, image): 79 | ops = self.backend 80 | if self.data_format == "channels_last": 81 | return ops.numpy.stack( 82 | [ 83 | self._scale_channel(image[..., c]) 84 | for c in range(image.shape[-1]) 85 | ], 86 | axis=-1, 87 | ) 88 | else: 89 | return ops.numpy.stack( 90 | [self._scale_channel(image[c]) for c in range(image.shape[-3])], 91 | axis=-3, 92 | ) 93 | 94 | def _scale_channel(self, image_channel): 95 | ops = self.backend 96 | hist = ops.numpy.bincount( 97 | ops.numpy.reshape(image_channel, [-1]), minlength=self.bins 98 | ) 99 | nonzero = ops.numpy.where(ops.numpy.not_equal(hist, 0), None, None) 100 | nonzero_hist = ops.numpy.reshape(ops.numpy.take(hist, nonzero), [-1]) 101 | step = ops.numpy.floor_divide( 102 | ops.numpy.sum(hist) - nonzero_hist[-1], 255 103 | ) 104 | 105 | def step_is_0(): 106 | return ops.cast(image_channel, "uint8") 107 | 108 | def step_not_0(): 109 | lut = ops.numpy.floor_divide( 110 | ops.numpy.add( 111 | ops.numpy.cumsum(hist), ops.numpy.floor_divide(step, 2) 112 | ), 113 | step, 114 | ) 115 | lut = ops.numpy.pad(lut[:-1], [[1, 0]]) 116 | lut = ops.numpy.clip(lut, 0, 255) 117 | result = ops.numpy.take(lut, ops.cast(image_channel, "int64")) 118 | return ops.cast(result, "uint8") 119 | 120 | return ops.cond(step == 0, step_is_0, step_not_0) 121 | 122 | def get_config(self): 123 | config = super().get_config() 124 | config.update({"p": self.p, "bins": self.bins}) 125 | return config 126 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/vision/random_erasing_test.py: -------------------------------------------------------------------------------- 1 | import keras 2 | import numpy as np 3 | from absl.testing import parameterized 4 | from keras import backend 5 | from keras.src.testing.test_utils import named_product 6 | 7 | from 
keras_aug._src.layers.vision.random_erasing import RandomErasing 8 | from keras_aug._src.testing.test_case import TestCase 9 | from keras_aug._src.utils.test_utils import get_images 10 | 11 | 12 | class FixedRandomErasing(RandomErasing): 13 | def get_params(self, batch_size, images=None, **kwargs): 14 | ops = self.backend 15 | images_shape = ops.shape(images) 16 | height, width = images_shape[self.h_axis], images_shape[self.w_axis] 17 | top = ops.numpy.zeros([batch_size]) 18 | left = ops.numpy.zeros([batch_size]) 19 | h = ops.numpy.ones([batch_size]) * 10 20 | w = ops.numpy.ones([batch_size]) * 10 21 | if isinstance(self.value, str): 22 | dtype = backend.result_type(images.dtype, float) 23 | v = ops.random.normal(ops.shape(images), dtype=dtype) 24 | elif isinstance(self.value, float): 25 | dtype = backend.standardize_dtype(images.dtype) 26 | v = ops.numpy.full(ops.shape(images), self.value) 27 | v = ops.cast(v, dtype) 28 | elif isinstance(self.value, tuple): 29 | dtype = backend.standardize_dtype(images.dtype) 30 | v = ops.convert_to_tensor(self.value) # [c] 31 | v = ops.cast(v, dtype) 32 | if self.data_format == "channels_last": 33 | v = ops.numpy.expand_dims(v, axis=[0, 1, 2]) 34 | v = ops.numpy.tile(v, [batch_size, height, width, 1]) 35 | else: 36 | v = ops.numpy.expand_dims(v, axis=[0, -1, -2]) 37 | v = ops.numpy.tile(v, [batch_size, 1, height, width]) 38 | return dict(top=top, left=left, height=h, width=w, value=v) 39 | 40 | 41 | class RandomErasingTest(TestCase): 42 | @parameterized.named_parameters( 43 | named_product( 44 | value=[0.0, (1.0, 1.0, 1.0), "random"], 45 | dtype=["float32", "mixed_bfloat16", "uint8"], 46 | ) 47 | ) 48 | def test_correctness(self, value, dtype): 49 | if dtype == "uint8" and value == "random": 50 | self.skipTest("value='random' doesn't support dtype='uint8'") 51 | 52 | np.random.seed(42) 53 | 54 | # Test channels_last 55 | images = get_images(dtype, "channels_last") 56 | layer = FixedRandomErasing(value=value, dtype=dtype) 57 | 
outputs = layer(images) 58 | 59 | self.assertDType(outputs, dtype) 60 | if value == 0.0: 61 | self.assertAllClose( 62 | outputs[:, 0:10, 0:10, :], 63 | np.zeros_like(self.convert_to_numpy(outputs)[:, 0:10, 0:10, :]), 64 | ) 65 | elif value == (1.0, 1.0, 1.0): 66 | self.assertAllClose( 67 | outputs[:, 0:10, 0:10, :], 68 | np.ones_like(self.convert_to_numpy(outputs)[:, 0:10, 0:10, :]), 69 | ) 70 | else: 71 | pass 72 | self.assertAllClose(outputs[:, 10:, 10:, :], images[:, 10:, 10:, :]) 73 | 74 | # Test channels_first 75 | backend.set_image_data_format("channels_first") 76 | images = get_images(dtype, "channels_first") 77 | layer = FixedRandomErasing(value=value, dtype=dtype) 78 | outputs = layer(images) 79 | 80 | self.assertDType(outputs, dtype) 81 | if value == 0.0: 82 | self.assertAllClose( 83 | outputs[:, :, 0:10, 0:10], 84 | np.zeros_like(self.convert_to_numpy(outputs)[:, :, 0:10, 0:10]), 85 | ) 86 | elif value == (1.0, 1.0, 1.0): 87 | self.assertAllClose( 88 | outputs[:, :, 0:10, 0:10], 89 | np.ones_like(self.convert_to_numpy(outputs)[:, :, 0:10, 0:10]), 90 | ) 91 | else: 92 | pass 93 | self.assertAllClose(outputs[:, :, 10:, 10:], images[:, :, 10:, 10:]) 94 | 95 | def test_shape(self): 96 | # Test dynamic shape 97 | x = keras.KerasTensor((None, None, None, 3)) 98 | y = RandomErasing()(x) 99 | self.assertEqual(y.shape, (None, None, None, 3)) 100 | 101 | # Test static shape 102 | x = keras.KerasTensor((None, 32, 32, 3)) 103 | y = RandomErasing()(x) 104 | self.assertEqual(y.shape, (None, 32, 32, 3)) 105 | 106 | def test_model(self): 107 | layer = RandomErasing() 108 | inputs = keras.layers.Input(shape=(None, None, 3)) 109 | outputs = layer(inputs) 110 | model = keras.models.Model(inputs, outputs) 111 | self.assertEqual(model.output_shape, (None, None, None, 3)) 112 | 113 | def test_config(self): 114 | x = get_images("float32", "channels_last") 115 | layer = FixedRandomErasing() 116 | y = layer(x) 117 | 118 | layer = 
FixedRandomErasing.from_config(layer.get_config()) 119 | y2 = layer(x) 120 | self.assertAllClose(y, y2) 121 | 122 | def test_tf_data_compatibility(self): 123 | import tensorflow as tf 124 | 125 | layer = RandomErasing() 126 | x = get_images("float32", "channels_last") 127 | ds = tf.data.Dataset.from_tensor_slices(x).batch(2).map(layer) 128 | for output in ds.take(1): 129 | self.assertIsInstance(output, tf.Tensor) 130 | self.assertEqual(output.shape, (2, 32, 32, 3)) 131 | -------------------------------------------------------------------------------- /keras_aug/_src/layers/composition/random_choice.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import keras 4 | from keras import backend 5 | from keras import saving 6 | from keras.src.utils.backend_utils import in_tf_graph 7 | 8 | from keras_aug._src.backend.dynamic_backend import DynamicBackend 9 | from keras_aug._src.backend.dynamic_backend import DynamicRandomGenerator 10 | from keras_aug._src.keras_aug_export import keras_aug_export 11 | 12 | 13 | @keras_aug_export(parent_path=["keras_aug.layers.composition"]) 14 | @keras.saving.register_keras_serializable(package="keras_aug") 15 | class RandomChoice(keras.Layer): 16 | """Apply single transformation randomly picked from a list. 17 | 18 | Note that due to implementation limitations, the randomness occurs in a 19 | batch manner. 20 | 21 | Args: 22 | transforms: A list of transformations or a `keras.Layer`. 23 | p: A list of probability of each transform being picked. If p doesn't 24 | sum to `1.0`, it is automatically normalized. If `None`, all 25 | transforms have the same probability. Defaults to `None`. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | transforms, 31 | p: typing.Optional[typing.Sequence[keras.Layer]] = None, 32 | seed=None, 33 | **kwargs, 34 | ): 35 | super().__init__(**kwargs) 36 | self._backend = DynamicBackend(backend.backend()) 37 | self._random_generator = DynamicRandomGenerator( 38 | backend.backend(), seed=seed 39 | ) 40 | self.seed = seed 41 | 42 | # Check 43 | if not isinstance(transforms, (typing.Sequence, keras.Layer)): 44 | raise ValueError( 45 | "`transforms` must be a sequence (e.g. tuple and list) or a " 46 | "`keras.Layer`. " 47 | f"Received: transforms={transforms} of type {type(transforms)}" 48 | ) 49 | if isinstance(transforms, keras.Layer): 50 | transforms = [transforms] 51 | if p is not None: 52 | if not isinstance(p, typing.Sequence): 53 | raise TypeError( 54 | "If `p` is provided, it must be a sequence. " 55 | f"Received: p={p} of type {type(p)}" 56 | ) 57 | if len(p) != len(transforms): 58 | raise ValueError( 59 | "If `p` is provided, the length of it should be the same " 60 | "`transforms`. " 61 | f"Received: transforms={transforms}, p={p}" 62 | ) 63 | else: 64 | p = [1.0] * len(transforms) 65 | 66 | self.transforms = list(transforms) 67 | total = sum(p) 68 | self.p = [prob / total for prob in p] 69 | 70 | self._convert_input_args = False 71 | self._allow_non_tensor_positional_args = True 72 | self.autocast = False 73 | 74 | @property 75 | def backend(self): 76 | return self._backend.backend 77 | 78 | @property 79 | def random_generator(self): 80 | return self._random_generator.random_generator 81 | 82 | def compute_output_shape(self, input_shape): 83 | transform_shape = [ 84 | transfrom.compute_output_shape(input_shape) 85 | for transfrom in self.transforms 86 | ] 87 | transform_shape = set(transform_shape) 88 | if len(transform_shape) > 1: 89 | raise ValueError( 90 | "The output shape of all `transforms` must be the same. 
" 91 | f"Received: input_shape={input_shape}, " 92 | f"possible transform_shape={list(transform_shape)}" 93 | ) 94 | output_shape = list(transform_shape)[0] 95 | return output_shape 96 | 97 | def get_params(self): 98 | ops = self.backend 99 | random_generator = self.random_generator 100 | 101 | p = ops.convert_to_tensor([self.p]) 102 | p = ops.random.categorical(ops.numpy.log(p), 1, seed=random_generator) 103 | p = p[0][0] 104 | return p 105 | 106 | def __call__(self, inputs, **kwargs): 107 | if in_tf_graph(): 108 | self._set_backend("tensorflow") 109 | try: 110 | outputs = super().__call__(inputs, **kwargs) 111 | finally: 112 | self._reset_backend() 113 | return outputs 114 | else: 115 | return super().__call__(inputs, **kwargs) 116 | 117 | def call(self, inputs): 118 | ops = self.backend 119 | p = self.get_params() 120 | 121 | outputs = ops.core.switch(p, self.transforms, inputs) 122 | return outputs 123 | 124 | def get_config(self): 125 | config = super().get_config() 126 | config.update( 127 | { 128 | "transforms": saving.serialize_keras_object(self.transforms), 129 | "p": self.p, 130 | "seed": self.seed, 131 | } 132 | ) 133 | return config 134 | 135 | @classmethod 136 | def from_config(cls, config, custom_objects=None): 137 | config = config.copy() 138 | config["transforms"] = saving.deserialize_keras_object( 139 | config["transforms"], custom_objects=custom_objects 140 | ) 141 | return cls(**config) 142 | 143 | def _set_backend(self, name): 144 | self._backend.set_backend(name) 145 | self._random_generator.set_generator(name) 146 | 147 | def _reset_backend(self): 148 | self._backend.reset() 149 | self._random_generator.reset() 150 | -------------------------------------------------------------------------------- /keras_aug/_src/ops/image.py: -------------------------------------------------------------------------------- 1 | from keras.src.utils.backend_utils import in_tf_graph 2 | 3 | from keras_aug._src.backend.image import ImageBackend 4 | from 
keras_aug._src.keras_aug_export import keras_aug_export 5 | 6 | 7 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 8 | def transform_dtype(images, from_dtype, to_dtype, scale=True): 9 | backend = "tensorflow" if in_tf_graph() else None 10 | return ImageBackend(backend).transform_dtype( 11 | images, from_dtype, to_dtype, scale=scale 12 | ) 13 | 14 | 15 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 16 | def crop(images, top, left, height, width, data_format=None): 17 | backend = "tensorflow" if in_tf_graph() else None 18 | return ImageBackend(backend).crop( 19 | images, top, left, height, width, data_format=data_format 20 | ) 21 | 22 | 23 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 24 | def pad( 25 | images, mode, top, bottom, left, right, constant_value=0, data_format=None 26 | ): 27 | backend = "tensorflow" if in_tf_graph() else None 28 | return ImageBackend(backend).pad( 29 | images, 30 | mode, 31 | top, 32 | bottom, 33 | left, 34 | right, 35 | constant_value=constant_value, 36 | data_format=data_format, 37 | ) 38 | 39 | 40 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 41 | def adjust_brightness(images, factor): 42 | backend = "tensorflow" if in_tf_graph() else None 43 | return ImageBackend(backend).adjust_brightness(images, factor) 44 | 45 | 46 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 47 | def adjust_contrast(images, factor, data_format=None): 48 | backend = "tensorflow" if in_tf_graph() else None 49 | return ImageBackend(backend).adjust_contrast( 50 | images, factor, data_format=data_format 51 | ) 52 | 53 | 54 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 55 | def adjust_saturation(images, factor, data_format=None): 56 | backend = "tensorflow" if in_tf_graph() else None 57 | return ImageBackend(backend).adjust_saturation( 58 | images, factor, data_format=data_format 59 | ) 60 | 61 | 62 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 63 | def adjust_hue(images, factor, 
data_format=None): 64 | backend = "tensorflow" if in_tf_graph() else None 65 | return ImageBackend(backend).adjust_hue( 66 | images, factor, data_format=data_format 67 | ) 68 | 69 | 70 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 71 | def affine( 72 | images, 73 | angle, 74 | translate_x, 75 | translate_y, 76 | scale, 77 | shear_x, 78 | shear_y, 79 | center_x=None, 80 | center_y=None, 81 | interpolation="bilinear", 82 | padding_mode="constant", 83 | padding_value=0, 84 | data_format=None, 85 | ): 86 | backend = "tensorflow" if in_tf_graph() else None 87 | return ImageBackend(backend).affine( 88 | images, 89 | angle, 90 | translate_x, 91 | translate_y, 92 | scale, 93 | shear_x, 94 | shear_y, 95 | center_x=center_x, 96 | center_y=center_y, 97 | interpolation=interpolation, 98 | padding_mode=padding_mode, 99 | padding_value=padding_value, 100 | data_format=data_format, 101 | ) 102 | 103 | 104 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 105 | def auto_contrast(images, data_format=None): 106 | backend = "tensorflow" if in_tf_graph() else None 107 | return ImageBackend(backend).auto_contrast(images, data_format=data_format) 108 | 109 | 110 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 111 | def blend(images1, images2, factor): 112 | backend = "tensorflow" if in_tf_graph() else None 113 | return ImageBackend(backend).blend(images1, images2, factor) 114 | 115 | 116 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 117 | def equalize(images, bins=256, data_format=None): 118 | backend = "tensorflow" if in_tf_graph() else None 119 | return ImageBackend(backend).equalize( 120 | images, bins=bins, data_format=data_format 121 | ) 122 | 123 | 124 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 125 | def guassian_blur(images, kernel_size, sigma, data_format=None): 126 | backend = "tensorflow" if in_tf_graph() else None 127 | return ImageBackend(backend).guassian_blur( 128 | images, kernel_size, sigma, data_format=data_format 129 
| ) 130 | 131 | 132 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 133 | def rgb_to_grayscale(images, num_channels=3, data_format=None): 134 | backend = "tensorflow" if in_tf_graph() else None 135 | return ImageBackend(backend).rgb_to_grayscale( 136 | images, num_channels=num_channels, data_format=data_format 137 | ) 138 | 139 | 140 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 141 | def invert(images): 142 | backend = "tensorflow" if in_tf_graph() else None 143 | return ImageBackend(backend).invert(images) 144 | 145 | 146 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 147 | def posterize(images, bits): 148 | backend = "tensorflow" if in_tf_graph() else None 149 | return ImageBackend(backend).posterize(images, bits) 150 | 151 | 152 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 153 | def sharpen(images, factor, data_format=None): 154 | backend = "tensorflow" if in_tf_graph() else None 155 | return ImageBackend(backend).sharpen( 156 | images, factor, data_format=data_format 157 | ) 158 | 159 | 160 | @keras_aug_export(parent_path=["keras_aug.ops.image"]) 161 | def solarize(images, threshold): 162 | backend = "tensorflow" if in_tf_graph() else None 163 | return ImageBackend(backend).solarize(images, threshold) 164 | --------------------------------------------------------------------------------