├── .github └── workflows │ └── publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── example ├── example.py └── images │ ├── results │ └── .gitignore │ └── ximilar-similar.jpg ├── logo.png ├── setup.py ├── test └── core │ └── test_bboxes.py └── tf_image ├── __init__.py ├── application ├── __init__.py ├── augmentation_config.py └── tools.py └── core ├── __init__.py ├── bboxes ├── __init__.py ├── clip.py ├── erase.py ├── flip.py ├── resize.py └── rotate.py ├── clip.py ├── colors.py ├── convert_type_decorator.py ├── erase.py ├── quality.py ├── random.py └── resize.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Package To PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | build-n-publish: 10 | name: Build and publish Python Package 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: '3.x' 18 | - name: Install dependencies 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install setuptools wheel twine 22 | - name: Build and publish 23 | env: 24 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 25 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 26 | run: | 27 | python setup.py sdist bdist_wheel 28 | twine upload dist/* 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ximilar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without 
restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tf-image 2 | 3 | __tf-image__ implements methods for image augmentation for TensorFlow 2+ / tf.data.Dataset. 4 | 5 | __Why?__ 6 | 7 | The official TensorFlow 2+ [tf.image](https://www.tensorflow.org/api_docs/python/tf/image) package contains just 8 | a few simple operations for image augmentation. This is not enough if you want to augment images and use 9 | all the power of tf.data.Dataset. There is also the [tf-addons](https://www.tensorflow.org/addons) project, which 10 | contains more operations (for example rotate), but it is still not enough. 11 | And on top of that, neither of those two supports operations on bounding boxes, and therefore neither is fully usable 12 | for augmenting object detection datasets. 
13 | 14 | If you do not require the operations in graph mode, simply use cv2, [imgaug](https://github.com/aleju/imgaug) 15 | or [albumentations](https://github.com/albumentations-team/albumentations) together with `tf.py_function`. 16 | They have (at the moment) many more operations and options for image augmentation. 17 | 18 | ## Installation 19 | 20 | Use pip: 21 | 22 | pip install tf-image 23 | 24 | For installation from source code, clone the repository and install it from code (`pip install -e .`). 25 | There are no dependencies specified. You have to install TensorFlow 2+ and the appropriate TensorFlow Addons. 26 | The specific versions are up to you; we wanted to keep this library as general as possible. 27 | 28 | ## Image and bounding boxes formats 29 | We use the channels-last format for images. Images can be represented either in 0.0 - 1.0 or 0 - 255 range. 30 | The same is true for bounding boxes. They can be provided either in relative coordinates with range 0.0 - 1.0 using 31 | float dtype or in absolute image coordinates using integer dtype. 32 | Internally, this is handled by the [convert_type](tf_image/core/convert_type_decorator.py) 33 | decorator on functions which need it. This decorator converts the images into the type we use 34 | (tf.float and 0.0 - 1.0 in both cases) and, after the function is done, restores the original format. 35 | If you perform multiple operations, you can use this decorator on your own function. 36 | (Then conversions after each operation will not be needed.) 37 | 38 | ## Quickstart 39 | For your convenience, we included a simple and configurable application, which combines all the provided augmentations. 40 | They are performed in a random order to make the augmentation even more powerful. 41 | 42 | There is also one script which uses this augmentation function and which outputs three augmented 43 | images without bounding boxes and three with bounding boxes. 44 | See [example/example.py](example/example.py) for more information. 
45 | 46 | If you want to use the functions alone, here is how: 47 | ```python 48 | import tensorflow as tf 49 | import tensorflow_addons as tfa 50 | 51 | from tf_image.core.random import random_function 52 | from tf_image.core.colors import rgb_shift, channel_drop 53 | from tf_image.core.convert_type_decorator import convert_type 54 | 55 | 56 | @convert_type 57 | def augment_image(image): 58 | # use TensorFlow library 59 | image = tf.image.random_flip_left_right(image) 60 | image = tf.image.random_flip_up_down(image) 61 | 62 | # use tf-image library 63 | image = random_function( 64 | image, rgb_shift, 0.1, None, **{"r_shift": 0.1, "g_shift": 0.1, "b_shift": 0.1} 65 | ) # do rgb shift with 10 % prob 66 | image = random_function(image, channel_drop, 0.1, None) 67 | # and whatever else you want 68 | 69 | # use TensorFlow Addons library 70 | image = tfa.image.rotate(image, 0.17) # angles are in radians (about 10 degrees) 71 | 72 | return image 73 | 74 | 75 | def map_function(image_file, label): 76 | image = tf.io.read_file(image_file) 77 | image = tf.image.decode_jpeg(image) 78 | image = augment_image(image) 79 | 80 | return image, label 81 | 82 | 83 | def return_dataset(image_files, labels): 84 | dataset = ( 85 | tf.data.Dataset.from_tensor_slices((image_files, labels)) 86 | .cache() 87 | .shuffle(len(image_files)) 88 | .map(map_function) 89 | .batch(20) 90 | .prefetch(tf.data.experimental.AUTOTUNE) 91 | ) 92 | 93 | return dataset 94 | 95 | return_dataset(["images/ximilar-similar.jpg"], [[1,2,3]]) 96 | ``` 97 | 98 | ## Supported operations 99 | 100 | Image augmentations: 101 | * aspect ratio deformations *(inc. bounding boxes)* 102 | * channel drop 103 | * channel swap 104 | * erase, see [arXiv:1708.04552](https://arxiv.org/abs/1708.04552) *(repeated, inc. bounding boxes)* 105 | * flip up-down, left-right *(inc. bounding boxes)* 106 | * grayscale 107 | * gaussian noise 108 | * clip *(inc. bounding boxes)* 109 | * rgb shift 110 | * resize with different methods *(inc. bounding boxes)* 111 | * rotate *(inc. 
bounding boxes)* 112 | 113 | Random operations: 114 | * random_function: calls a function on an image with a given probability from [0.0, 1.0] 115 | * random_function_bboxes: calls a function on an image and its bounding boxes with a given probability from [0.0, 1.0] 116 | 117 | Feel free to improve and add more functions. We are looking forward to your merge requests! 118 | (Please use only plain TensorFlow 2+, no OpenCV.) 119 | 120 | [![](logo.png)](https://ximilar.com) 121 | -------------------------------------------------------------------------------- /example/example.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tf_image.application.augmentation_config import AugmentationConfig 4 | from tf_image.application.tools import random_augmentations 5 | 6 | # 7 | # Loads an image from images/ximilar-similar.jpg, creates some bounding boxes and augments it 8 | # three times without bounding boxes and three times with them. Results are saved to the images/results folder. 9 | # 10 | 11 | # Loads the basic setup, feel free to experiment! 12 | config = AugmentationConfig() 13 | 14 | # Loads the image and creates bounding boxes for three completely visible apples. 
15 | image_encoded = tf.io.read_file("images/ximilar-similar.jpg") 16 | image = tf.image.decode_jpeg(image_encoded) 17 | 18 | bboxes = tf.constant([[262.0, 135.0, 504.0, 371.0], [285.0, 446.0, 494.0, 644.0], [272.0, 688.0, 493.0, 895.0]]) 19 | bboxes /= tf.cast( 20 | tf.stack([tf.shape(image)[0], tf.shape(image)[1], tf.shape(image)[0], tf.shape(image)[1],]), tf.float32 21 | ) 22 | bboxes_colors = [[0, 0, 255], [0, 0, 255], [0, 0, 255]] 23 | 24 | for i in range(3): 25 | image_augmented = random_augmentations(image, config) 26 | 27 | image_augmented_encoded = tf.image.encode_png(image_augmented) 28 | tf.io.write_file(f"images/results/ximilar-similar_{i + 1}.png", image_augmented_encoded) 29 | 30 | for i in range(3): 31 | image_augmented, bboxes_augmented = random_augmentations(image, config, bboxes=bboxes) 32 | 33 | image_augmented = tf.image.draw_bounding_boxes([image_augmented], [bboxes_augmented], bboxes_colors)[0] 34 | image_augmented = tf.cast(image_augmented, tf.uint8) # for some reason, draw_bounding_boxes converts image to float 35 | 36 | image_augmented_encoded = tf.image.encode_png(image_augmented) 37 | tf.io.write_file(f"images/results/ximilar-similar_bboxes_{i + 1}.png", image_augmented_encoded) 38 | -------------------------------------------------------------------------------- /example/images/results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore -------------------------------------------------------------------------------- /example/images/ximilar-similar.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ximilar-com/tf-image/b6218217aceb21481360f8934bedfea8a9190f61/example/images/ximilar-similar.jpg -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ximilar-com/tf-image/b6218217aceb21481360f8934bedfea8a9190f61/logo.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | with open("README.md", "r") as fh: 5 | long_description = fh.read() 6 | 7 | setup( 8 | name="tf-image", 9 | version="0.2.0", 10 | description="Image augmentation operations for TensorFlow 2+.", 11 | url="https://github.com/Ximilar-com/tf-image", 12 | author="Ximilar.com Team, Michal Lukac, Libor Vanek, ...", 13 | author_email="tech@ximilar.com", 14 | license="MIT", 15 | packages=find_packages(), 16 | keywords="machine learning, multimedia, image", 17 | classifiers=[ 18 | "Development Status :: 3 - Alpha", 19 | "License :: OSI Approved :: MIT License", 20 | "Programming Language :: Python :: 3.6", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | ], 23 | include_package_data=True, 24 | zip_safe=False, 25 | namespace_packages=["tf_image"], 26 | long_description=long_description, 27 | long_description_content_type="text/markdown", 28 | ) 29 | -------------------------------------------------------------------------------- /test/core/test_bboxes.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import unittest 3 | 4 | from tf_image.core.bboxes.clip import clip_bboxes, clip_random_with_bboxes 5 | from tf_image.core.bboxes.erase import multiple_erase 6 | from tf_image.core.bboxes.flip import ( 7 | flip_left_right, 8 | flip_up_down, 9 | ) 10 | from tf_image.core.bboxes.resize import resize, random_pad_to_square, random_aspect_ratio_deformation 11 | from tf_image.core.bboxes.rotate import _find_bbox, _unpack_bbox 12 | 13 | 14 | class BboxTest(unittest.TestCase): 15 | def check_types(self, function): 16 | """ 17 | We want our functions to work 
with different types to provided user the freedom to choose. 18 | We test the two standard ones - tf.int8, tf.float32. If this works, other meaningful types are OK as well. 19 | 20 | We provide the image and bounding boxes. Function does not need to use both or return meaningful value for both. 21 | But it has to return at least None. 22 | 23 | :param function: (image, bounding boxes, relative coordinates) -> (image, bounding boxes) 24 | """ 25 | dtypes = [tf.int8, tf.float32] 26 | for dtype in dtypes: 27 | image = tf.ones([24, 8, 3], dtype=dtype) 28 | bboxes = tf.constant([[1, 1, 15, 6]], dtype=dtype) 29 | 30 | if dtype.is_floating: 31 | bboxes /= tf.cast([image.shape[0], image.shape[1], image.shape[0], image.shape[1]], dtype) 32 | 33 | image, bboxes = function(image, bboxes) 34 | 35 | if image is not None: 36 | self.assertEqual(image.dtype, dtype, msg="Image type does not fit.") 37 | 38 | if bboxes is not None: 39 | self.assertEqual(bboxes.dtype, dtype, msg="Bounding box type does not fit.") 40 | 41 | 42 | class TestClip(BboxTest): 43 | def test_clip_types(self): 44 | self.check_types(lambda image, bboxes: clip_random_with_bboxes(image, bboxes)) 45 | 46 | def test_clip_relative(self): 47 | # calculation done for image = tf.ones([24, 8, 3]) 48 | bboxes = tf.constant([[0.1, 0.25, 0.5, 0.75]]) 49 | 50 | clipped = clip_bboxes(bboxes, tf.constant(0.1), tf.constant(0.125), tf.constant(0.5), tf.constant(0.75)) 51 | tf.debugging.assert_equal(tf.constant([[0.0, 1.0 / 6.0, 0.8, 5.0 / 6.0]]), clipped) 52 | 53 | 54 | class TestErase(BboxTest): 55 | def test_flip_left_right_types(self): 56 | self.check_types(lambda image, bboxes: multiple_erase(image, bboxes)) 57 | 58 | 59 | class TestFlip(BboxTest): 60 | def test_flip_left_right_types(self): 61 | self.check_types(lambda image, bboxes: flip_left_right(image, bboxes)) 62 | 63 | def test_flip_left_right_absolute(self): 64 | image = tf.ones([24, 8, 3]) 65 | bboxes = tf.constant([[1, 1, 15, 6]],) 66 | 67 | _, flipped = 
flip_left_right(image, bboxes) 68 | tf.debugging.assert_equal(tf.constant([[1, 2, 15, 7]]), flipped) 69 | 70 | def test_flip_left_right_relative(self): 71 | image = tf.ones([24, 8, 3]) 72 | bboxes = tf.constant([[0.1, 0.2, 0.4, 0.7]]) 73 | 74 | _, flipped = flip_left_right(image, bboxes) 75 | tf.debugging.assert_equal(tf.constant([[0.1, 0.3, 0.4, 0.8]]), flipped) 76 | 77 | def test_flip_up_down_types(self): 78 | self.check_types(lambda image, bboxes: flip_up_down(image, bboxes)) 79 | 80 | def test_flip_up_down_absolute(self): 81 | image = tf.ones([24, 8, 3]) 82 | bboxes = tf.constant([[1, 1, 15, 6]]) 83 | 84 | _, flipped = flip_up_down(image, bboxes) 85 | tf.debugging.assert_equal(tf.constant([[9, 1, 23, 6]]), flipped) 86 | 87 | def test_flip_up_down_relative(self): 88 | image = tf.ones([24, 8, 3]) 89 | bboxes = tf.constant([[0.1, 0.2, 0.4, 0.7]]) 90 | 91 | _, flipped = flip_up_down(image, bboxes) 92 | tf.debugging.assert_equal(tf.constant([[0.6, 0.2, 0.9, 0.7]]), flipped) 93 | 94 | 95 | class TestResize(BboxTest): 96 | def setUp(self): 97 | self.image = tf.ones([24, 8, 3]) 98 | self.bboxes = tf.constant([[0.25, 0.50, 0.75, 0.8]]) 99 | 100 | def test_resize_types(self): 101 | self.check_types(lambda image, bboxes: resize(image, bboxes, 20, 20)) 102 | self.check_types(lambda image, bboxes: random_pad_to_square(image, bboxes)) 103 | self.check_types(lambda image, bboxes: random_aspect_ratio_deformation(image, bboxes)) 104 | 105 | def test_resize_keep_aspect_ratio(self): 106 | image_resized, bboxes_resized = resize(self.image, self.bboxes, 6, 6, keep_aspect_ratio=True) 107 | tf.debugging.assert_equal( 108 | tf.constant([6, 6, 3]), image_resized.shape, message="Dimensions of the resized image are wrong." 
109 | ) 110 | 111 | tf.debugging.assert_equal( 112 | tf.round(tf.ones([6, 2, 3], dtype=tf.float32) * 1000), 113 | tf.round(image_resized[:, 2:4, :] * 1000), 114 | message="Ones from original image should be there.", 115 | ) 116 | tf.debugging.assert_equal(tf.zeros([6, 2, 3]), image_resized[:, 4:, :], message="Padding should be here.") 117 | tf.debugging.assert_equal(tf.zeros([6, 2, 3]), image_resized[:, :2, :], message="Padding should be here.") 118 | 119 | tf.debugging.assert_equal( 120 | tf.constant(tf.round(tf.constant([[0.25, 0.5, 0.75, 0.6]]) * 1000.0)), tf.round(bboxes_resized * 1000.0) 121 | ) 122 | 123 | def test_resize_not_keep_aspect_ratio(self): 124 | image_resized, bboxes_resized = resize(self.image, self.bboxes, 6, 6, keep_aspect_ratio=False) 125 | tf.debugging.assert_equal( 126 | tf.constant([6, 6, 3]), image_resized.shape, message="Dimensions of the resized image are wrong." 127 | ) 128 | 129 | mult = 10 130 | tf.debugging.assert_equal( 131 | tf.ones([6, 6, 3]) * mult, 132 | tf.math.round(image_resized * mult), 133 | message="New image should contain only ones.", 134 | ) 135 | 136 | h, w = self.image.shape[0], self.image.shape[1] 137 | tf.debugging.assert_equal(self.bboxes, bboxes_resized) 138 | 139 | 140 | class TestRotate(unittest.TestCase): 141 | def test_find_bbox(self): 142 | coordinates = tf.constant( 143 | [[10.0, 2.0], [15.0, 6.0], [22.0, 10.0], [1.0, 30.0], [5.0, 2.0], [8.0, 9.0], [10.0, 15.0], [22.0, 3.0]] 144 | ) 145 | bbox = _find_bbox(coordinates) 146 | tf.debugging.assert_equal(tf.constant([1.0, 2.0, 22.0, 30.0]), bbox) 147 | 148 | def test_pack_unpack(self): 149 | bbox = tf.constant([1, 2, 3, 4]) 150 | bbox2 = _find_bbox(_unpack_bbox(bbox)) 151 | tf.debugging.assert_equal(bbox, bbox2) 152 | 153 | 154 | if __name__ == "__main__": 155 | unittest.main() 156 | -------------------------------------------------------------------------------- /tf_image/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Implementation of several augmentation operations in TensorFlow 2+. 3 | """ 4 | 5 | __version__ = "0.1.0" 6 | 7 | from . import core 8 | 9 | from pkg_resources import declare_namespace 10 | 11 | declare_namespace("tf_image") 12 | -------------------------------------------------------------------------------- /tf_image/application/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ximilar-com/tf-image/b6218217aceb21481360f8934bedfea8a9190f61/tf_image/application/__init__.py -------------------------------------------------------------------------------- /tf_image/application/augmentation_config.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | # use int enum because of TensorFlow 5 | class ColorAugmentation(IntEnum): 6 | """ 7 | Enum which splits color related operations into two groups: 8 | - LIGHT: brightness, contrast 9 | - MEDIUM: LIGHT + hue, saturation 10 | - AGGRESSIVE: MEDIUM + channel swap, channel drop, gray scale. 11 | 12 | In addition, there is an option for no augmentations: ColorAugmentation.NONE. 13 | """ 14 | 15 | NONE = 0 16 | LIGHT = 1 17 | MEDIUM = 2 18 | AGGRESSIVE = 3 19 | 20 | 21 | class AspectRatioAugmentation(IntEnum): 22 | """ 23 | There are two posibilities how we can distort aspect ration: 24 | - NORMAL: same maximal distortions in both horizontal and vertical direction or 25 | - TOWARDS_SQUARE: more squeezing for longer side and more stretching for a shorter side. 26 | 27 | In addition, there is an option for no augmentations: AspectRatioAugmentation.NONE. 28 | """ 29 | 30 | NONE = 0 31 | NORMAL = 1 32 | TOWARDS_SQUARE = 2 33 | 34 | 35 | class AugmentationConfig(object): 36 | """ 37 | Specifies which augmentations should be applied. 
38 | """ 39 | 40 | def __init__(self): 41 | self.color = ColorAugmentation.AGGRESSIVE 42 | self.crop = True 43 | self.distort_aspect_ratio = AspectRatioAugmentation.NORMAL 44 | self.quality = True # jpeg quality, noise 45 | self.erasing = True 46 | self.rotate90 = False 47 | self.rotate45 = False # rotate 45 degrees clockwise (other multiples can be done by turning on rotate90) 48 | self.rotate_max = 13 49 | self.flip_vertical = True 50 | self.flip_horizontal = True 51 | self.padding_square = False 52 | -------------------------------------------------------------------------------- /tf_image/application/tools.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tf_image.application.augmentation_config import ColorAugmentation, AugmentationConfig, AspectRatioAugmentation 4 | from tf_image.core.bboxes.clip import clip_random_with_bboxes 5 | from tf_image.core.bboxes.erase import multiple_erase, calculate_bboxes_max_erase_area 6 | from tf_image.core.bboxes.flip import flip_left_right, flip_up_down 7 | from tf_image.core.bboxes.resize import random_aspect_ratio_deformation, random_pad_to_square 8 | from tf_image.core.bboxes.rotate import random_rotate, rot90, rot45 9 | from tf_image.core.clip import clip_random 10 | from tf_image.core.colors import channel_drop, grayscale, channel_swap, rgb_shift 11 | from tf_image.core.convert_type_decorator import convert_type 12 | from tf_image.core.quality import gaussian_noise 13 | from tf_image.core.random import random_function 14 | from tf_image.core.random import random_function_bboxes 15 | 16 | 17 | def random_augmentations(image, augmentation_config: AugmentationConfig, bboxes=None, prob_demanding_ops: float = 0.5): 18 | """ 19 | Apply augmentations in random order. 20 | 21 | WARNING: this is just a testing class and it is likely to change. 22 | 23 | :param image: 3-D Tensor of shape (height, width, channels). 
24 | :param augmentation_config: Config defining which augmentations can be applied. 25 | :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax] 26 | :param prob_demanding_ops: Probability that a time consuming operation (like rotation) will be performed. 27 | :return: augmented image or (augmented image, bboxes) if bboxes parameter is not None 28 | """ 29 | has_bboxes = bboxes is not None 30 | if not has_bboxes: 31 | bboxes = tf.reshape([], (0, 4)) 32 | 33 | # the convert_type decorator needs this special argument order (converting once here saves converting in each operation) 34 | image, bboxes = _random_augmentations(image, bboxes, augmentation_config, prob_demanding_ops) 35 | 36 | if has_bboxes: 37 | return image, bboxes 38 | 39 | return image 40 | 41 | 42 | @tf.function 43 | @convert_type 44 | def _random_augmentations(image, bboxes, augmentation_config: AugmentationConfig, prob_demanding_ops: float): 45 | @tf.function 46 | def apply(idx, image, bboxes): 47 | # List of tuples (precondition, augmentation), augmentation will be applied only if precondition is True. 
48 | functions = [ 49 | ( 50 | tf.math.equal(augmentation_config.color, ColorAugmentation.AGGRESSIVE), 51 | lambda: ( 52 | ( 53 | random_function(image, rgb_shift, 0.2, **{"r_shift": 0.15, "g_shift": 0.15, "b_shift": 0.15}), 54 | bboxes, 55 | ) 56 | ), 57 | ), 58 | ( 59 | tf.math.equal(augmentation_config.color, ColorAugmentation.AGGRESSIVE), 60 | lambda: ( 61 | random_function(image, channel_swap, 0.1), 62 | bboxes, 63 | ), 64 | ), 65 | ( 66 | tf.math.equal(augmentation_config.color, ColorAugmentation.AGGRESSIVE), 67 | lambda: (random_function(image, grayscale, 0.1), bboxes), 68 | ), 69 | ( 70 | tf.math.equal(augmentation_config.color, ColorAugmentation.AGGRESSIVE), 71 | lambda: (random_function(image, channel_drop, 0.1), bboxes), 72 | ), 73 | ( 74 | tf.math.greater_equal(augmentation_config.color, ColorAugmentation.LIGHT), 75 | lambda: (tf.image.random_brightness(image, 0.2), bboxes), 76 | ), 77 | ( 78 | tf.math.greater_equal(augmentation_config.color, ColorAugmentation.LIGHT), 79 | lambda: (tf.image.random_contrast(image, 0.8, 1.2), bboxes), 80 | ), 81 | ( 82 | tf.math.greater_equal(augmentation_config.color, ColorAugmentation.MEDIUM), 83 | lambda: (tf.image.random_saturation(image, 0.8, 1.2), bboxes), 84 | ), 85 | ( 86 | tf.math.greater_equal(augmentation_config.color, ColorAugmentation.MEDIUM), 87 | lambda: (tf.image.random_hue(image, 0.2), bboxes), 88 | ), 89 | ( 90 | tf.math.equal(augmentation_config.crop, True), 91 | lambda: tf.cond( 92 | tf.greater(tf.shape(bboxes)[0], 0), 93 | lambda: clip_random_with_bboxes(image, bboxes), 94 | lambda: ( 95 | clip_random( 96 | image, 97 | min_shape=( 98 | tf.cast(tf.cast(tf.shape(image)[0], dtype=tf.float32) * 0.9, dtype=tf.int32), 99 | tf.cast(tf.cast(tf.shape(image)[1], dtype=tf.float32) * 0.9, dtype=tf.int32), 100 | ), 101 | ), 102 | bboxes, 103 | ), 104 | ), 105 | ), 106 | ( 107 | tf.math.equal(augmentation_config.distort_aspect_ratio, AspectRatioAugmentation.NORMAL), 108 | lambda: random_function_bboxes( 109 | 
image, 110 | bboxes, 111 | random_aspect_ratio_deformation, 112 | prob=prob_demanding_ops, 113 | unify_dims=False, 114 | max_squeeze=0.6, 115 | max_stretch=1.3, 116 | ), 117 | ), 118 | ( 119 | tf.math.equal(augmentation_config.distort_aspect_ratio, AspectRatioAugmentation.TOWARDS_SQUARE), 120 | lambda: random_function_bboxes( 121 | image, 122 | bboxes, 123 | random_aspect_ratio_deformation, 124 | prob=prob_demanding_ops, 125 | unify_dims=True, 126 | max_squeeze=0.6, 127 | max_stretch=1.3, 128 | ), 129 | ), 130 | ( 131 | tf.math.equal(augmentation_config.quality, True), 132 | lambda: (random_function(image, gaussian_noise, prob=0.15, stddev_max=0.05), bboxes), 133 | ), 134 | ( 135 | tf.math.equal(augmentation_config.erasing, True), 136 | lambda: multiple_erase( 137 | image, 138 | bboxes, 139 | iterations=tf.random.uniform((), 0, 7, tf.int32), 140 | max_area=calculate_bboxes_max_erase_area(bboxes, max_area=0.1), 141 | ), 142 | ), 143 | (tf.math.equal(augmentation_config.rotate90, True), lambda: rot90(image, bboxes)), 144 | ( 145 | tf.math.equal(augmentation_config.rotate45, True), 146 | lambda: (random_function_bboxes(image, bboxes, rot45, prob=0.5)), 147 | ), 148 | ( 149 | tf.math.greater(augmentation_config.rotate_max, 0), 150 | lambda: random_function_bboxes( 151 | image, 152 | bboxes, 153 | random_rotate, 154 | prob=prob_demanding_ops, 155 | min_rotate=-augmentation_config.rotate_max, 156 | max_rotate=augmentation_config.rotate_max, 157 | ), 158 | ), 159 | ( 160 | tf.math.equal(augmentation_config.flip_horizontal, True), 161 | lambda: random_function_bboxes(image, bboxes, flip_left_right, 0.5), 162 | ), 163 | ( 164 | tf.math.equal(augmentation_config.flip_vertical, True), 165 | lambda: random_function_bboxes(image, bboxes, flip_up_down, 0.5), 166 | ), 167 | ] 168 | 169 | # We cannot simply index by i, this loop will find the given augmentation 170 | # and perform it if the precondition is satisfied. 
171 | for i in range(len(functions)): 172 | image, bboxes = tf.cond( 173 | tf.math.logical_and(tf.equal(i, idx), functions[i][0]), functions[i][1], lambda: (image, bboxes) 174 | ) 175 | 176 | return image, bboxes 177 | 178 | # TODO we had some problems if random_jpeg_quality was inside the random operations ... find out why 179 | image = tf.cond( 180 | tf.math.equal(augmentation_config.quality, True), 181 | lambda: tf.image.random_jpeg_quality(image, 35, 98), 182 | lambda: image, 183 | ) 184 | 185 | # Randomize the sequence of augmentation indices. 186 | augmentation_count = 17 187 | order = tf.random.shuffle(tf.range(augmentation_count)) 188 | 189 | # Loop over all augmentation and apply them. 190 | i = tf.constant(0, dtype=tf.int32) 191 | condition = lambda i, _image, _bboxes: tf.greater(augmentation_count, i) 192 | body = lambda i, image, bboxes: (i + 1, *apply(order[i], image, bboxes)) 193 | _, image, bboxes = tf.while_loop( 194 | condition, 195 | body, 196 | (i, image, bboxes), 197 | shape_invariants=( 198 | i.get_shape(), 199 | tf.TensorShape([None, None, None]), 200 | tf.TensorShape([None, 4]), 201 | ), 202 | ) 203 | 204 | # this ned to be at the end, otherwise we are not guaranteed to get the square 205 | # (and it could interact with the other augmentation in such way that we would have too much empty space) 206 | image, bboxes = tf.cond( 207 | tf.math.equal(augmentation_config.padding_square, True), 208 | lambda: random_pad_to_square(image, bboxes), 209 | lambda: ( 210 | image, 211 | bboxes, 212 | ), 213 | ) 214 | 215 | return image, bboxes 216 | -------------------------------------------------------------------------------- /tf_image/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ximilar-com/tf-image/b6218217aceb21481360f8934bedfea8a9190f61/tf_image/core/__init__.py -------------------------------------------------------------------------------- 
# =============================================================================
# /tf_image/core/bboxes/__init__.py
# The dump lists only the raw URL for this file; its content is not shown here:
# https://raw.githubusercontent.com/Ximilar-com/tf-image/b6218217aceb21481360f8934bedfea8a9190f61/tf_image/core/bboxes/__init__.py
# =============================================================================


# =============================================================================
# /tf_image/core/bboxes/clip.py
# =============================================================================
import tensorflow as tf

from tf_image.core.clip import clip_random
from tf_image.core.convert_type_decorator import convert_type


@tf.function
@convert_type
def clip_random_with_bboxes(image, bboxes, min_shape=(1, 1)):
    """
    Randomly clips an image in such a way that all bounding boxes are still fully present in the resulting image.
    Updates the bounding boxes to match the new image.

    If the bounding-box coordinates sum to zero (no boxes, or only all-zero boxes), the image is clipped
    by `clip_random` instead, which respects `min_shape`.

    :param image: 3-D Tensor of shape (height, width, channels).
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :param min_shape: Smallest clip size (height, width) in pixels; only used when there are no bounding boxes.
    :return: (clipped image, updated bounding boxes)
    """
    with tf.name_scope("clip_random_with_bboxes"):
        return tf.cond(
            tf.equal(tf.reduce_sum(bboxes), 0),
            lambda: (clip_random(image, min_shape), bboxes),
            lambda: _clip_random_with_bboxes(image, bboxes),
        )


@tf.function
def _clip_random_with_bboxes(image, bboxes):
    """
    Randomly crop `image` to a window that still contains every bounding box and remap the boxes into it.

    :param image: 3-D float Tensor of shape (height, width, channels).
    :param bboxes: 2-D Tensor (box_number, 4) of relative [ymin, xmin, ymax, xmax] coordinates. Expected to be
        non-empty: the public wrapper only calls this when the coordinates do not sum to zero.
    :return: (cropped image, updated bounding boxes)
    """
    image_height = tf.cast(tf.shape(image)[0], dtype=tf.float32)
    image_width = tf.cast(tf.shape(image)[1], dtype=tf.float32)

    # The new top-left corner lies anywhere above/left of every box,
    # the new bottom-right corner anywhere below/right of every box.
    new_miny = tf.random.uniform([], 0.0, tf.math.reduce_min(bboxes[:, 0]))
    new_minx = tf.random.uniform([], 0.0, tf.math.reduce_min(bboxes[:, 1]))
    new_maxy = tf.random.uniform([], tf.math.reduce_max(bboxes[:, 2]), 1.0)
    new_maxx = tf.random.uniform([], tf.math.reduce_max(bboxes[:, 3]), 1.0)
    new_height, new_width = new_maxy - new_miny, new_maxx - new_minx

    # prepare parameters: relative values for the boxes, absolute pixels for the image
    args_clip_bboxes = [new_miny, new_minx, new_height, new_width]
    args_clip_image = [
        tf.cast(new_miny * image_height, dtype=tf.int32),
        tf.cast(new_minx * image_width, dtype=tf.int32),
        tf.cast(new_height * image_height, dtype=tf.int32),
        tf.cast(new_width * image_width, dtype=tf.int32),
    ]

    # Only crop when the window is non-empty in BOTH dimensions. A zero-sized crop would make
    # crop_to_bounding_box fail and clip_bboxes divide by zero. A positive pixel size implies a
    # positive relative size, so checking the pixel sizes is sufficient.
    crop_is_valid = tf.math.logical_and(
        tf.math.greater(args_clip_image[2], 0),
        tf.math.greater(args_clip_image[3], 0),
    )
    image, bboxes = tf.cond(
        crop_is_valid,
        lambda: (tf.image.crop_to_bounding_box(image, *args_clip_image), clip_bboxes(bboxes, *args_clip_bboxes)),
        lambda: (image, bboxes),
    )

    return image, bboxes


@tf.function
def clip_bboxes(bboxes_relative, new_miny, new_minx, new_height, new_width):
    """
    Calculates new coordinates for given bounding boxes given the cut area of an image.

    :param bboxes_relative: 2-D Tensor (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax].
    :param new_miny: Relative top coordinate of the cut area.
    :param new_minx: Relative left coordinate of the cut area.
    :param new_height: Relative height of the cut area (must be non-zero).
    :param new_width: Relative width of the cut area (must be non-zero).
    :return: clipped bounding boxes
    """
    # move the coordinates according to the new min value
    bboxes_move_min = tf.stack([new_miny, new_minx, new_miny, new_minx])
    bboxes = bboxes_relative - bboxes_move_min

    # we use relative coordinates, so we have to scale them to be between 0 and 1 again
    bboxes_scale = [new_height, new_width, new_height, new_width]
    bboxes = bboxes / bboxes_scale

    return bboxes


# =============================================================================
# /tf_image/core/bboxes/erase.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type
from tf_image.core.erase import random_erasing


@convert_type
def multiple_erase(image, bboxes, iterations=10, max_area=1.0):
    """
    Repeatedly erase rectangular areas from given image.

    :param image: 3-D Tensor of shape [height, width, channels].
    :param bboxes: 2-D Tensor (box_number, 4) of bounding boxes [ymin, xmin, ymax, xmax]; returned unchanged.
    :param iterations: How many random rectangles we are going to erase.
    :param max_area: Maximum part of the image to be erased in one iteration. (Range: 0.0 to 1.0)
    :return: (augmented image, unchanged bboxes)
    """
    with tf.name_scope("multiple_erase"):
        max_area = tf.clip_by_value(tf.cast(max_area, dtype=tf.float32), 0.0, 1.0)

        i = tf.constant(0)
        condition = lambda i, _image: i < iterations
        body = lambda i, image: (i + 1, random_erasing(image, max_area=max_area))
        _, image = tf.while_loop(condition, body, (i, image))

        return image, bboxes


def calculate_bboxes_max_erase_area(bboxes, max_area=1.0, erase_smallest=0.5):
    """
    Calculates the biggest relative area (width * height) that can be erased on an image with given bounding boxes.

    Result = smaller of (area of the smallest bounding box * erase_smallest) and max_area

    :param bboxes: 2-D Tensor (box_number, 4) of relative bounding boxes [ymin, xmin, ymax, xmax].
    :param max_area: Maximum part of the image to be erased in one iteration. (0.0 to 1.0)
    :param erase_smallest: Multiple of the smallest bounding box area that we could erase. (0.0 none, 1.0 full or more.)
    :return: relative max_area (0.0 - 1.0)
    """
    max_area = tf.clip_by_value(tf.cast(max_area, dtype=tf.float32), 0.0, 1.0)

    sizes = bboxes[:, 2:] - bboxes[:, :2]  # (height, width) of every box
    areas = tf.math.reduce_prod(sizes, axis=1)
    smallest_bbox_area = tf.math.reduce_min(areas)  # inf if there are no bounding boxes
    return tf.minimum(smallest_bbox_area * erase_smallest, max_area)


# =============================================================================
# /tf_image/core/bboxes/flip.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type


@tf.function
@convert_type
def flip_left_right(image, bboxes):
    """
    Flip an image and bounding boxes horizontally (left to right).

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :return: image, bounding boxes
    """
    with tf.name_scope("flip_left_right"):
        # x -> 1 - x, then swap xmin/xmax so that min <= max still holds
        bboxes = bboxes * tf.constant([1, -1, 1, -1], dtype=tf.float32) + tf.stack([0.0, 1.0, 0.0, 1.0])
        bboxes = tf.stack([bboxes[:, 0], bboxes[:, 3], bboxes[:, 2], bboxes[:, 1]], axis=1)

        image = tf.image.flip_left_right(image)

        return image, bboxes


@tf.function
@convert_type
def flip_up_down(image, bboxes):
    """
    Flip an image and bounding boxes vertically (upside down).

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :return: image, bounding boxes
    """
    with tf.name_scope("flip_up_down"):
        # y -> 1 - y, then swap ymin/ymax so that min <= max still holds
        bboxes = bboxes * tf.constant([-1, 1, -1, 1], dtype=tf.float32) + tf.stack([1.0, 0.0, 1.0, 0.0])
        bboxes = tf.stack([bboxes[:, 2], bboxes[:, 1], bboxes[:, 0], bboxes[:, 3]], axis=1)

        image = tf.image.flip_up_down(image)

        return image, bboxes


# =============================================================================
# /tf_image/core/bboxes/resize.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type
from tf_image.core.resize import random_resize_pad, random_resize


@convert_type
@tf.function(
    input_signature=[
        tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
        tf.TensorSpec(shape=(None, 4), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
        tf.TensorSpec(shape=(), dtype=tf.bool),
        tf.TensorSpec(shape=(), dtype=tf.bool),
    ]
)
def resize(image, bboxes, height, width, keep_aspect_ratio=True, random_method=False):
    """
    Resize given image and bounding boxes.

    :param image: 3-D Tensor of shape [height, width, channels].
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :param height: Height of the resized image.
    :param width: Width of the resized image.
    :param keep_aspect_ratio: True if we should add padding instead of fully spreading out the image to given size
    :param random_method: Whether we should use random resize method (tf.image.ResizeMethod) or the default one.
    :return: (resized image, resized bounding boxes)
    """
    with tf.name_scope("resize"):

        def _keep_aspect_ratio(img, boxes, h, w):
            # size of the input image, needed to find where it lands inside the padded canvas
            image_shape = tf.cast(tf.shape(img), tf.float32)
            image_height, image_width = image_shape[0], image_shape[1]

            img = tf.cond(
                random_method,
                lambda: random_resize_pad(img, h, w),
                lambda: tf.image.resize_with_pad(img, h, w),
            )

            target_h, target_w = tf.cast(h, dtype=tf.float32), tf.cast(w, dtype=tf.float32)
            resize_coef = tf.math.minimum(target_h / image_height, target_w / image_width)
            resized_height, resized_width = image_height * resize_coef, image_width * resize_coef

            # resize_with_pad centers the image, so the padding is split evenly on both sides
            pad_y, pad_x = (target_h - resized_height) / 2, (target_w - resized_width) / 2
            boxes = boxes * tf.stack([resized_height, resized_width, resized_height, resized_width]) + tf.stack(
                [pad_y, pad_x, pad_y, pad_x]
            )

            boxes /= tf.stack([target_h, target_w, target_h, target_w])

            return img, boxes

        def _dont_keep_aspect_ratio(img, boxes, h, w):
            # relative coordinates are unchanged by a plain (stretching) resize
            img = tf.cond(random_method, lambda: random_resize(img, h, w), lambda: tf.image.resize(img, (h, w)))

            return img, boxes

        image, bboxes = tf.cond(
            keep_aspect_ratio,
            lambda: _keep_aspect_ratio(image, bboxes, height, width),
            lambda: _dont_keep_aspect_ratio(image, bboxes, height, width),
        )
        return image, bboxes


@tf.function
@convert_type
def random_aspect_ratio_deformation(image, bboxes, max_squeeze=0.7, max_stretch=1.3, unify_dims=False):
    """
    Randomly pick the width or height dimension and squeeze or stretch given image in that dimension.

    Often, we train on square images. To fill a bigger part of this square, we can set parameter unify_dims to True.
    This allows us to stretch the short side / squeeze the long side by a bigger ratio than max_squeeze/max_stretch
    and get a more square image.

    :param image: 3-D Tensor of shape [height, width, channels].
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :param max_squeeze: Minimal scale coefficient for the chosen side (squeezing). (0.0 to 1.0)
    :param max_stretch: Maximal scale coefficient for the chosen side (stretching). (1.0 or more)
    :param unify_dims: extend max_squeeze / max_stretch so that the chosen side can reach the length of the other one
    :return: (augmented image, updated bounding boxes)
    """
    with tf.name_scope("random_aspect_ratio_deformation"):
        image_shape = tf.cast(tf.shape(image), dtype=tf.float32)
        height, width = image_shape[0], image_shape[1]

        # Do we squeeze/stretch the y or the x side?
        side = tf.random.uniform([], 0, 2, dtype=tf.int32)

        # With unify_dims, widen the range so that the scaled side can match the other side exactly:
        # scaling height by width/height (or width by height/width) gives a square image. That factor is
        # a squeeze when < 1 (lower bound) and a stretch when > 1 (upper bound).
        max_squeeze_h = tf.math.minimum(max_squeeze, width / height) if unify_dims else max_squeeze
        max_squeeze_w = tf.math.minimum(max_squeeze, height / width) if unify_dims else max_squeeze
        max_stretch_h = tf.math.maximum(max_stretch, width / height) if unify_dims else max_stretch
        max_stretch_w = tf.math.maximum(max_stretch, height / width) if unify_dims else max_stretch

        # new size
        height = height * tf.cond(side == 0, lambda: tf.random.uniform([], max_squeeze_h, max_stretch_h), lambda: 1.0)
        width = width * tf.cond(side != 0, lambda: tf.random.uniform([], max_squeeze_w, max_stretch_w), lambda: 1.0)
        height, width = tf.cast(height, dtype=tf.int32), tf.cast(width, dtype=tf.int32)

        image, bboxes = resize(image, bboxes, height, width, keep_aspect_ratio=False)

        return image, bboxes


@tf.function
@convert_type
def random_pad_to_square(image, bboxes, max_extend=0.1):
    """
    Creates a square image from a given input. The final size is given by the longer input image side + random padding
    limited by the max_extend parameter. The position of the original image inside this square is random as well.

    :param image: 3-D Tensor of shape [height, width, channels].
    :param bboxes: 2-D Tensor (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :param max_extend: maximal free space that could be added to the longer side, relative to that side (e.g. 0.1)
    :return: (padded image, updated bounding boxes)
    """
    with tf.name_scope("random_pad_to_square"):
        height, width = tf.shape(image)[0], tf.shape(image)[1]

        # how much empty space (in pixels) we may add to the longer side
        longer_side = tf.math.maximum(height, width)
        max_extend = tf.cast(tf.cast(longer_side, dtype=tf.float32) * max_extend, dtype=tf.int32)
        extend = tf.cond(
            tf.math.greater_equal(max_extend, 1),
            lambda: tf.random.uniform([], 0, max_extend, dtype=tf.int32),
            lambda: 0,
        )

        # find out how much we can extend the longer side and shift the shorter side at most
        max_padding_top = tf.math.maximum(extend, width - height)
        max_padding_left = tf.math.maximum(extend, height - width)

        # now, take random padding values
        padding_top = tf.cond(
            tf.math.greater(max_padding_top, 1),
            lambda: tf.random.uniform([], 0, max_padding_top, dtype=tf.int32),
            lambda: 0,
        )
        padding_left = tf.cond(
            tf.math.greater(max_padding_left, 1),
            lambda: tf.random.uniform([], 0, max_padding_left, dtype=tf.int32),
            lambda: 0,
        )

        # this will be the final width and height of the image
        size = tf.math.maximum(height + padding_top, width + padding_left)

        # pad image on all sides to get a square image
        image = tf.pad(
            image, [[padding_top, size - padding_top - height], [padding_left, size - padding_left - width], [0, 0]]
        )

        # we need to cast all dimensions to float to continue with bounding box calculations
        height, width = tf.cast(height, dtype=tf.float32), tf.cast(width, dtype=tf.float32)
        padding_top, padding_left = tf.cast(padding_top, dtype=tf.float32), tf.cast(padding_left, dtype=tf.float32)
        size = tf.cast(size, dtype=tf.float32)

        # update positions of bounding boxes: to pixels, shift by padding, back to relative
        padding = tf.stack([padding_top, padding_left, padding_top, padding_left])
        bboxes = bboxes * tf.stack([height, width, height, width]) + padding

        bboxes /= tf.stack([size, size, size, size])

        return image, bboxes


# =============================================================================
# /tf_image/core/bboxes/rotate.py
# =============================================================================
import tensorflow as tf
import tensorflow_addons as tfa

from math import pi

from tf_image.core.bboxes.clip import clip_bboxes
from tf_image.core.convert_type_decorator import convert_type


@tf.function
@convert_type
def random_rotate(image, bboxes, min_rotate=-20, max_rotate=20):
    """
    Randomly rotates an image and its bounding boxes.
    The angle is drawn from a truncated normal distribution centered in the middle of [min_rotate, max_rotate]
    with stddev = (max_rotate - min_rotate) / 4. Truncation at two standard deviations keeps it inside the range.

    We do not cut any part of the image. Zero padding is added around it to fill the empty space after rotation.

    Rotating bounding boxes has one significant drawback. The result (except for a few special cases)
    is not a proper bounding box! Its edges are not vertical and horizontal. We need to fix that.
    We took the same approach as with the whole image. To be sure no part of the object is left out,
    we make a bounding box around the rotated bounding box. Unfortunately, this approach means that
    we also increase its size.

        #######################
        ##        ###        ##
        ##      ##   ##      ##
        ##    ##       ##    ##
        ##  ##           ##  ##
        ##    ##       ##    ##
        ##      ##   ##      ##
        ##        ###        ##
        #######################

    :param image: 3-D Tensor of shape (height, width, channels).
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax].
    :param min_rotate: Minimal scalar angle in degrees that can be used for counterclockwise rotation.
    :param max_rotate: Maximal scalar angle in degrees that can be used for counterclockwise rotation.
    :return: (rotated image, rotated bounding boxes)
    :raises ValueError: if min_rotate >= max_rotate (unless both are 0).
    """
    if min_rotate == 0 and max_rotate == 0:
        return image, bboxes

    if min_rotate >= max_rotate:
        raise ValueError(f"Minimum has to be smaller than maximum! {min_rotate} {max_rotate}")

    with tf.name_scope("random_rotate"):
        rotate_mean = (min_rotate + max_rotate) / 2
        rotate_stdev = (max_rotate - min_rotate) / 4
        rotate = tf.random.truncated_normal([], mean=rotate_mean, stddev=rotate_stdev)
        image, bboxes = _rotate(image, bboxes, rotate)

        return image, bboxes


@tf.function
def _rotate(image, bboxes, angle):
    """
    Rotate an image (counterclockwise, by `angle` degrees) and its relative bounding boxes without cutting the image.

    :param image: 3-D float Tensor of shape (height, width, channels).
    :param bboxes: 2-D float Tensor (box_number, 4) of relative [ymin, xmin, ymax, xmax] coordinates.
    :param angle: Angle in degrees.
    :return: (rotated image, relative bounding boxes enclosing the rotated boxes)
    """
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
    # epsilon keeps sin/cos away from exact special values
    rotate = angle * pi / 180 + tf.keras.backend.epsilon()

    # find the new width and height bounds
    abs_cos = tf.math.abs(tf.math.cos(rotate))
    abs_sin = tf.math.abs(tf.math.sin(rotate))
    bound_h = tf.cast(
        tf.cast(image_height, dtype=tf.float32) * abs_cos + tf.cast(image_width, dtype=tf.float32) * abs_sin,
        dtype=tf.int32,
    )
    bound_w = tf.cast(
        tf.cast(image_height, dtype=tf.float32) * abs_sin + tf.cast(image_width, dtype=tf.float32) * abs_cos,
        dtype=tf.int32,
    )

    # if the new bounds are bigger than the old ones on some side, add some padding
    pad_bound_h = tf.math.maximum(image_height, bound_h)
    pad_bound_w = tf.math.maximum(image_width, bound_w)
    pad_y = (pad_bound_h - image_height) // 2
    pad_x = (pad_bound_w - image_width) // 2

    image = tf.image.pad_to_bounding_box(image, pad_y, pad_x, pad_bound_h, pad_bound_w)

    # convert boxes to absolute pixel coordinates inside the padded image
    # (explicit float32 casts: the sizes and paddings are int32 while the boxes are float32)
    bboxes_resize = tf.cast(tf.stack([image_height, image_width, image_height, image_width]), dtype=tf.float32)
    bboxes_pad = tf.cast(tf.stack([pad_y, pad_x, pad_y, pad_x]), dtype=tf.float32)
    bboxes = tf.cast(bboxes, dtype=tf.float32) * bboxes_resize + bboxes_pad

    # rotate the four corners of every box and take their enclosing axis-aligned box
    bboxes_points = tf.map_fn(_unpack_bbox, bboxes)
    bboxes_points = _rotate_points(bboxes_points, -rotate, image)
    bboxes = tf.map_fn(
        _find_bbox,
        bboxes_points,
    )

    image = tfa.image.rotate(image, rotate)
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]

    # back to relative coordinates
    bboxes_resize = tf.cast(tf.stack([image_height, image_width, image_height, image_width]), dtype=tf.float32)
    bboxes = bboxes / bboxes_resize

    # if the new bounds are smaller than the old ones on some side, remove the empty space
    clip_y, clip_x = image_height - bound_h, image_width - bound_w
    clip_args = [clip_y // 2, clip_x // 2, image_height - clip_y, image_width - clip_x]
    clip_args_rel = [
        tf.cast(clip_args[0] / image_height, dtype=tf.float32),
        tf.cast(clip_args[1] / image_width, dtype=tf.float32),
        tf.cast(clip_args[2] / image_height, dtype=tf.float32),
        tf.cast(clip_args[3] / image_width, dtype=tf.float32),
    ]
    image, bboxes = tf.cond(
        tf.math.logical_or(tf.math.greater(clip_y, 0), tf.math.greater(clip_x, 0)),
        lambda: (tf.image.crop_to_bounding_box(image, *clip_args), clip_bboxes(bboxes, *clip_args_rel)),
        lambda: (image, bboxes),
    )

    return image, bboxes


@tf.function
def _unpack_bbox(bbox):
    """
    Translate a bounding box into corner coordinates.

    :param bbox: Bounding box of a shape [ymin, xmin, ymax, xmax].
    :return: Tensor (4, 2) of corner coordinates [y, x].
    """
    ymin, xmin, ymax, xmax = bbox[0], bbox[1], bbox[2], bbox[3]
    return tf.stack([[ymin, xmin], [ymin, xmax], [ymax, xmin], [ymax, xmax]])


@tf.function
def _find_bbox(points):
    """
    Return the smallest bounding box containing all given points.

    :param points: List of 2D points [y, x].
    :return: Bounding box of a shape [ymin, xmin, ymax, xmax].
    """
    return tf.stack(
        [
            tf.math.reduce_min(points[:, 0]),
            tf.math.reduce_min(points[:, 1]),
            tf.math.reduce_max(points[:, 0]),
            tf.math.reduce_max(points[:, 1]),
        ]
    )


@tf.function
def _rotate_points(points, angle, image):
    """
    Rotate all points in a given list around the center of a given image.

    :param points: List of 2D points [y, x].
    :param angle: Angle in radians.
    :param image: A reference image.
    :return: List of rotated points.
    """
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
    center = tf.cast(tf.stack([image_height / 2, image_width / 2]), dtype=tf.float32)
    rotation_matrix = tf.stack([tf.math.cos(angle), -tf.math.sin(angle), tf.math.sin(angle), tf.math.cos(angle)])
    rotation_matrix = tf.reshape(rotation_matrix, (2, 2))
    points = tf.matmul(points - center, rotation_matrix) + center
    return points


@tf.function
@convert_type
def rot90(image, bboxes, k=(0, 1, 2, 3)):
    """
    Rotate image and bounding boxes counter-clockwise by a random multiple of 90 degrees.

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :param k: array with multiples of 90 to choose from
    :return: (rotated image, rotated bounding boxes)
    """
    with tf.name_scope("rot90"):
        selected_k = tf.math.floormod(tf.random.shuffle(k)[0], 4)

        image = tf.image.rot90(image, k=selected_k)

        # relative box transform for each k (counter-clockwise): k=1 maps (y, x) -> (1 - x, y)
        rotate_bboxes = [
            lambda: bboxes,
            lambda: tf.stack(
                [tf.math.subtract(1.0, bboxes[:, 3]), bboxes[:, 0], tf.math.subtract(1.0, bboxes[:, 1]), bboxes[:, 2]],
                axis=1,
            ),
            lambda: tf.math.subtract(
                1.0,
                tf.stack(
                    [
                        bboxes[:, 2],
                        bboxes[:, 3],
                        bboxes[:, 0],
                        bboxes[:, 1],
                    ],
                    axis=1,
                ),
            ),
            lambda: tf.stack(
                [bboxes[:, 1], tf.math.subtract(1.0, bboxes[:, 2]), bboxes[:, 3], tf.math.subtract(1.0, bboxes[:, 0])],
                axis=1,
            ),
        ]

        bboxes = tf.cond(
            tf.greater(tf.shape(bboxes)[0], 0),
            lambda: tf.switch_case(selected_k, rotate_bboxes),
            lambda: bboxes,
        )

        return image, bboxes


@tf.function
@convert_type
def rot45(image, bboxes):
    """
    Rotate image and bounding boxes counter-clockwise by 45 degrees.

    :param image: 3-D Tensor of shape [height, width, channels]
    :param bboxes: 2-D Tensor of shape (box_number, 4) containing bounding boxes in format [ymin, xmin, ymax, xmax]
    :return: (rotated image, rotated bounding boxes)
    """
    with tf.name_scope("rot45"):
        return _rotate(image, bboxes, 45)


# =============================================================================
# /tf_image/core/clip.py
# =============================================================================
import tensorflow as tf


@tf.function
def clip_random(image, min_shape):
    """
    Randomly cuts out part of the image. Useful for images with no bounding boxes. It provides an additional
    parameter for the minimum size, which is not needed when we have bounding boxes.

    If the height or width of an image is not greater than min_shape, we keep the given dimension.

    :param image: 3-D Tensor of shape (height, width, channels).
    :param min_shape: smallest image cut size, (height, width).
    :return: clipped image
    """
    img_height, img_width = tf.shape(image)[0], tf.shape(image)[1]
    min_height, min_width = min_shape[0], min_shape[1]

    height = tf.cond(
        tf.math.greater(img_height, min_height),
        lambda: tf.random.uniform([], min_height, img_height, dtype=tf.int32),
        lambda: img_height,
    )
    width = tf.cond(
        tf.math.greater(img_width, min_width),
        lambda: tf.random.uniform([], min_width, img_width, dtype=tf.int32),
        lambda: img_width,
    )

    image = tf.image.random_crop(image, size=(height, width, tf.shape(image)[-1]))
    return image


# =============================================================================
# /tf_image/core/colors.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type
from tf_image.core.random import random_choice


@tf.function
def channel_swap(image):
    """
    Randomly swaps image channels (expects exactly 3 channels in the last dimension).

    :param image: An image, last dimension is a channel.
    :return: Image with swapped channels.
    """
    indices = tf.range(start=0, limit=3, dtype=tf.int32)
    shuffled_indices = tf.random.shuffle(indices)
    image = tf.gather(image, shuffled_indices, axis=2)
    return image


@tf.function
def channel_drop(image):
    """
    Randomly zeroes out one of the image channels (expects exactly 3 channels in the last dimension).

    :param image: An image, last dimension is a channel.
    :return: Image with a dropped channel.
    """
    orig_dtype = image.dtype

    r, g, b = tf.split(image, 3, axis=2)
    zeros = tf.zeros_like(r, dtype=orig_dtype)

    indexes_r = tf.concat([zeros, g, b], axis=2)
    indexes_g = tf.concat([r, zeros, b], axis=2)
    indexes_b = tf.concat([r, g, zeros], axis=2)

    image = random_choice([indexes_r, indexes_g, indexes_b], 1)[0]
    return image


@tf.function
@convert_type
def rgb_shift(image, r_shift=0.0, g_shift=0.0, b_shift=0.0):
    """
    Randomly shift channels in a given image.

    :param image: An image, last dimension is a channel.
    :param r_shift: Maximal red shift delta. Range: from 0.0 to 1.0.
    :param g_shift: Maximal green shift delta. Range: from 0.0 to 1.0.
    :param b_shift: Maximal blue shift delta. Range: from 0.0 to 1.0.
    :return: Augmented image.
    """
    r, g, b = tf.split(image, 3, axis=2)

    r = r + tf.random.uniform([], -r_shift, r_shift)
    g = g + tf.random.uniform([], -g_shift, g_shift)
    b = b + tf.random.uniform([], -b_shift, b_shift)

    image = tf.concat([r, g, b], axis=2)
    return image


@tf.function
@convert_type
def grayscale(image):
    """
    Convert image to grayscale, but keep 3 channels.

    :param image: An image.
    :return: Grayscale image.
    """
    image = tf.image.rgb_to_grayscale(image)  # this will create one channel
    image = tf.image.grayscale_to_rgb(image)  # this will create three channels again
    return image


# =============================================================================
# /tf_image/core/convert_type_decorator.py
# =============================================================================
import tensorflow as tf
from functools import wraps


def convert_type(function):
    """
    Often, we need the input image and bounding boxes to have a specific dtype and format.
    This decorator converts the image (and bounding boxes) and provides them to the decorated function.
    After the function is done, the reverse conversion is done to return the same format to the caller.

    Image formats are the standard ones accepted by TensorFlow. For bounding boxes, we use:
        - integer types for absolute coordinates or
        - float types for relative coordinates.

    Be careful, this decorator expects a specific format of function parameters:
        - without bounding boxes: image, kwargs
        - with bounding boxes: image, bounding boxes, other args, kwargs

    :param function: function to be decorated, see the requirements in the description of the decorator!
    :return: decorated function
    """

    @wraps(function)
    def wrap(image, *args, **kwargs):
        image_type = image.dtype
        image = tf.image.convert_image_dtype(image, tf.float32)

        # the first positional argument after the image is treated as bounding boxes
        if len(args) >= 1:
            bboxes = args[0]
            bboxes_type = bboxes.dtype
            bboxes_absolute = bboxes_type.is_integer

            bboxes = tf.cast(bboxes, tf.float32)
            bboxes = _bboxes_to_relative(image, bboxes) if bboxes_absolute else bboxes
            bboxes = tf.clip_by_value(bboxes, 0.0, 1.0)

            image, bboxes = function(image, bboxes, *args[1:], **kwargs)
            image = tf.clip_by_value(image, 0.0, 1.0)
            image = tf.image.convert_image_dtype(image, image_type, saturate=True)

            # absolute boxes are converted back using the size of the (possibly resized) output image
            bboxes = tf.clip_by_value(bboxes, 0.0, 1.0)
            bboxes = _bboxes_to_absolute(image, bboxes) if bboxes_absolute else bboxes
            bboxes = tf.cast(bboxes, bboxes_type)

            return image, bboxes

        image = function(image, **kwargs)
        image = tf.clip_by_value(image, 0.0, 1.0)
        image = tf.image.convert_image_dtype(image, image_type, saturate=True)
        return image

    return wrap


@tf.function
def _bboxes_to_relative(image, bboxes):
    """Divide absolute [ymin, xmin, ymax, xmax] boxes by the image height/width."""
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
    bboxes_update = tf.cast(tf.stack([image_height, image_width, image_height, image_width]), dtype=tf.float32)
    return bboxes / bboxes_update


@tf.function
def _bboxes_to_absolute(image, bboxes):
    """Multiply relative [ymin, xmin, ymax, xmax] boxes by the image height/width."""
    image_height, image_width = tf.shape(image)[0], tf.shape(image)[1]
    bboxes_update = tf.cast(tf.stack([image_height, image_width, image_height, image_width]), dtype=tf.float32)
    return bboxes * bboxes_update


# =============================================================================
# /tf_image/core/erase.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type


@tf.function
@convert_type
def random_erasing(image, max_area=0.1, erased_value=0):
    """
    Randomly removes a rectangular part of given image.

    Images with a side shorter than 4 pixels (or where max_area amounts to less than one pixel) are returned unchanged.

    :param image: An image. (Float 0-1 or integer 0-255.)
    :param max_area: Maximum part of the image to be erased. (Range: 0.0 to 1.0)
    :param erased_value: The value which will be used for the empty area.
    :return: Augmented image.
    """
    image_height, image_width = tf.shape(image)[-3], tf.shape(image)[-2]
    max_area = tf.cast(max_area * tf.cast(image_height * image_width, tf.float32), tf.int32)

    # _random_erasing samples the rectangle center from [1, size - 2); integer tf.random.uniform
    # requires minval < maxval, so every side must be at least 4 pixels long.
    can_erase = tf.math.logical_and(
        tf.greater_equal(max_area, 1),
        tf.math.logical_and(tf.greater(image_height, 3), tf.greater(image_width, 3)),
    )
    return tf.cond(can_erase, lambda: _random_erasing(image, max_area, erased_value), lambda: image)


@tf.function
def _random_erasing(image, max_area, erased_value):
    """
    Erase one random rectangle of at most `max_area` pixels from `image`.

    NOTE(review): the mask is expanded with grayscale_to_rgb, so this assumes a 3-channel image.
    """
    image_height, image_width = tf.shape(image)[-3], tf.shape(image)[-2]

    # Get the center of the rectangle to be removed.
    y = tf.random.uniform([], 1, image_height - 2, dtype=tf.int32)
    x = tf.random.uniform([], 1, image_width - 2, dtype=tf.int32)

    # Functions for calculating the size of the erased space.
    def random(max_val):
        return tf.cond(tf.greater(max_val, 1), lambda: tf.random.uniform([], 1, max_val, dtype=tf.int32), lambda: 1)

    def get_size(center1, center2, max1, max2):
        # keep the rectangle inside the image and within max_area
        size1 = random(tf.math.reduce_min([center1, max1 - center1, max_area]))
        size2 = random(tf.math.reduce_min([center2, max2 - center2, max_area // size1]))
        return size1, size2

    def swap(size1, size2):
        return size2, size1

    # If we used only one of those, we would get a lot of vertical / horizontal rectangles.
    # Changing which size is generated first fixes the distribution.
    height, width = tf.cond(
        tf.math.greater(tf.random.uniform([], 0, 1), 0.5),
        lambda: get_size(y, x, image_height, image_width),
        lambda: swap(*get_size(x, y, image_width, image_height)),
    )

    # Create a mask for the generated rectangle.
    mask = tf.ones((height, width), dtype=image.dtype)
    top, left = y - height // 2, x - width // 2
    mask = tf.pad(mask, [[top, image_height - top - height], [left, image_width - left - width]])
    mask = tf.image.grayscale_to_rgb(tf.expand_dims(mask, -1))

    # Now, we can erase the rectangle from the image.
    image = image * (1 - mask) + mask * erased_value
    return image


# =============================================================================
# /tf_image/core/quality.py
# =============================================================================
import tensorflow as tf

from tf_image.core.convert_type_decorator import convert_type


@tf.function
@convert_type
def gaussian_noise(image, stddev_max=0.1):
    """
    Add Gaussian noise to a given image.

    :param image: An image. (Float 0-1 or integer 0-255.)
    :param stddev_max: Standard deviation maximum for added Gaussian noise. Range: from 0.0 to 1.0.
    :return: Image with Gaussian noise.
    """
    stddev = tf.random.uniform([], 0.0, stddev_max)
    noise = tf.random.normal(shape=tf.shape(image), mean=0, stddev=stddev)
    image = image + noise

    return image


# =============================================================================
# /tf_image/core/random.py
# =============================================================================
import tensorflow as tf


@tf.function
def random_choice(x, size, axis=0):
    """
    Pick `size` distinct random elements of `x` along `axis` (without replacement).

    :param x: Tensor (or list convertible to one) to sample from.
    :param size: Number of elements to take.
    :param axis: Axis along which to sample.
    :return: Tensor with the sampled elements.
    """
    dim_x = tf.cast(tf.shape(x)[axis], tf.int64)
    indices = tf.range(0, dim_x, dtype=tf.int64)
    sample_index = tf.random.shuffle(indices)[:size]
    sample = tf.gather(x, sample_index, axis=axis)
    return sample


def random_function(image, function, prob, seed=None, **kwargs):
    """
    Apply `function(image, **kwargs)` with probability `prob`, otherwise return `image` unchanged.
    """
    with tf.name_scope("random_" + function.__name__):
        uniform_random = tf.random.uniform([], 0, 1.0, seed=seed)
        apply_cond = tf.math.less(uniform_random, prob)
        result = tf.cond(apply_cond, lambda: function(image, **kwargs), lambda: image)
        return result


def random_function_bboxes(image, bboxes, function, prob, seed=None, **kwargs):
    """
    Apply `function(image, bboxes, **kwargs)` with probability `prob`, otherwise return `(image, bboxes)` unchanged.
    """
    with tf.name_scope("random_" + function.__name__):
        uniform_random = tf.random.uniform([], 0, 1.0, seed=seed)
        apply_cond = tf.math.less(uniform_random, prob)
        result = tf.cond(apply_cond, lambda: function(image, bboxes, **kwargs), lambda: (image, bboxes))
        return result


# =============================================================================
# /tf_image/core/resize.py
# =============================================================================
import tensorflow as tf


@tf.function
def random_resize_pad(images, height, width):
    """
    tf.image.resize_with_pad to (height, width) using a uniformly chosen resize method (bilinear is not included).
    """
    methods = {
        0: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR),
        1: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.BICUBIC),
        2: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.AREA),
        3: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.LANCZOS3),
        4: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.LANCZOS5),
        5: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.MITCHELLCUBIC),
        6: lambda: tf.image.resize_with_pad(images, height, width, method=tf.image.ResizeMethod.GAUSSIAN),
    }
    # uniform in [0, 1) scaled by the number of methods -> index in [0, len(methods))
    return tf.switch_case(tf.cast(tf.random.uniform([], 0, 1.0) * len(methods), tf.int32), branch_fns=methods)


@tf.function
def random_resize(images, height, width):
    """
    tf.image.resize to (height, width) using a uniformly chosen resize method (bilinear is not included).
    """
    methods = {
        0: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.NEAREST_NEIGHBOR),
        1: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.BICUBIC),
        2: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.AREA),
        3: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.LANCZOS3),
        4: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.LANCZOS5),
        5: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.MITCHELLCUBIC),
        6: lambda: tf.image.resize(images, (height, width), method=tf.image.ResizeMethod.GAUSSIAN),
    }
    # uniform in [0, 1) scaled by the number of methods -> index in [0, len(methods))
    return tf.switch_case(tf.cast(tf.random.uniform([], 0, 1.0) * len(methods), tf.int32), branch_fns=methods)