├── .github └── workflows │ └── publish-to-test-pypi.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE.md ├── README.md ├── audios └── sample-6s.wav ├── images ├── cat.jpg ├── dog.jpg └── smelu.png ├── requirements.txt ├── setup.py └── tensorflow_extra ├── __init__.py ├── activations.py ├── layers.py ├── utils.py └── version.py /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI 8 | runs-on: ubuntu-18.04 9 | steps: 10 | - uses: actions/checkout@master 11 | - name: Set up Python 3.6 12 | uses: actions/setup-python@v1 13 | with: 14 | python-version: 3.6 15 | - name: Install pypa/build 16 | run: >- 17 | python -m 18 | pip install 19 | build 20 | --user 21 | - name: Build a binary wheel and a source tarball 22 | run: >- 23 | python -m 24 | build 25 | --sdist 26 | --wheel 27 | --outdir dist/ 28 | . 29 | - name: Publish distribution 📦 to PyPI 30 | if: startsWith(github.ref, 'refs/tags') 31 | uses: pypa/gh-action-pypi-publish@master 32 | with: 33 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | 118 | # mypy 119 | .mypy_cache/ 120 | .dmypy.json 121 | dmypy.json 122 | 123 | # Pyre type checker 124 | .pyre/ 125 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.linting.enabled": false 3 | } -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2021 Awsaf 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow Extra 2 | > TensorFlow GPU & TPU compatible operations: MelSpectrogram, TimeFreqMask, CutMix, MixUp, ZScore, and more 3 | 4 | # Installation 5 | For the stable version 6 | ```shell 7 | !pip install tensorflow-extra 8 | ``` 9 | or 10 | For the latest version 11 | ```shell 12 | !pip install git+https://github.com/awsaf49/tensorflow_extra 13 | ``` 14 | 15 | # Usage 16 | To see a use case of this library, check out the [BirdCLEF23: Pretraining is All you Need](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train) notebook. It uses this library along with **Multi-Stage Transfer Learning** for the bird call identification task. 17 | 18 | 19 | # Layers 20 | ## MelSpectrogram 21 | Converts audio data to a mel-spectrogram on GPU/TPU. 22 | ```py 23 | import tensorflow_extra as tfe 24 | audio2spec = tfe.layers.MelSpectrogram() 25 | spec = audio2spec(audio) 26 | ``` 27 | 28 | 29 | 30 | 31 | ## Time Frequency Masking 32 | Applies time and frequency masking to the spectrogram; the number of mask stripes can also be controlled. 33 | ```py 34 | time_freq_mask = tfe.layers.TimeFreqMask() 35 | spec = time_freq_mask(spec) 36 | ``` 37 | 38 | 39 | ## CutMix 40 | Can be used with audio, spectrograms, or images. For spectrograms, full frequency resolution can be used via `full_height=True`.
41 | ```py 42 | cutmix = tfe.layers.CutMix() 43 | audio = cutmix(audio, training=True) # accepts both audio & spectrogram 44 | ``` 45 | 46 | 47 | 48 | ## MixUp 49 | Can be used with audio, spectrograms, or images. For spectrograms, full frequency resolution can be used via `full_height=True`. 50 | ```py 51 | mixup = tfe.layers.MixUp() 52 | audio = mixup(audio, training=True) # accepts both audio & spectrogram 53 | ``` 54 | 55 | 56 | 57 | 58 | ## Normalization 59 | Applies standardization and rescaling. 60 | ```py 61 | norm = tfe.layers.ZScoreMinMax() 62 | spec = norm(spec) 63 | ``` 64 | 65 | 66 | 67 | # Activations 68 | ## SmeLU: Smooth ReLU 69 | ```py 70 | import tensorflow as tf 71 | import tensorflow_extra as tfe 72 | 73 | a = tf.constant([-2.5, -1.0, 0.5, 1.0, 2.5]) 74 | b = tfe.activations.smelu(a) # array([0., 0.04166667, 0.6666667 , 1.0416666 , 2.5]) 75 | ``` 76 | 77 | -------------------------------------------------------------------------------- /audios/sample-6s.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/audios/sample-6s.wav -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/cat.jpg -------------------------------------------------------------------------------- /images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/dog.jpg -------------------------------------------------------------------------------- /images/smelu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/smelu.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow 2 | tensorflow-probability -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from codecs import open 3 | from os import path 4 | 5 | here = path.abspath(path.dirname(__file__)) 6 | 7 | # Get the long description from the README file 8 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 9 | long_description = f.read() 10 | 11 | with open(path.join(here, "requirements.txt")) as f: 12 | install_requires = [x for x in f.read().splitlines() if len(x)] 13 | 14 | exec(open("tensorflow_extra/version.py").read()) 15 | 16 | setup( 17 | name="tensorflow_extra", 18 | version=__version__, 19 | description="Tensorflow Extra Utilities. https://github.com/awsaf49/tensorflow_extra", 20 | long_description=long_description, 21 | long_description_content_type="text/markdown", 22 | url="https://github.com/awsaf49/tensorflow_extra", 23 | author="Awsaf", 24 | author_email="awsaf49@gmail.com", 25 | classifiers=[ 26 | # How mature is this project?
Common values are 27 | # 3 - Alpha 28 | # 4 - Beta 29 | # 5 - Production/Stable 30 | "Development Status :: 3 - Alpha", 31 | "Intended Audience :: Developers", 32 | "Intended Audience :: Science/Research", 33 | "License :: OSI Approved :: MIT License", 34 | "Programming Language :: Python :: 3.6", 35 | "Programming Language :: Python :: 3.7", 36 | "Programming Language :: Python :: 3.8", 37 | "Topic :: Scientific/Engineering", 38 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 39 | "Topic :: Software Development", 40 | "Topic :: Software Development :: Libraries", 41 | "Topic :: Software Development :: Libraries :: Python Modules", 42 | ], 43 | # Note that this is a string of words separated by whitespace, not a list. 44 | keywords="tensorflow extra utilities", 45 | packages=find_packages(exclude=["tests"]), 46 | include_package_data=True, 47 | install_requires=install_requires, 48 | python_requires=">=3.6", 49 | license="MIT", 50 | ) 51 | -------------------------------------------------------------------------------- /tensorflow_extra/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from tensorflow_extra import activations 3 | from tensorflow_extra import layers 4 | -------------------------------------------------------------------------------- /tensorflow_extra/activations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 4 | def smelu(x, beta=1.5): 5 | """Smooth ReLU (SmeLU): Smooth activations and reproducibility in deep networks, https://arxiv.org/abs/2010.09931 6 | 7 | Args: 8 | x : numpy or tensorflow tensor 9 | beta (float): smooth value. Defaults to 1.5. 10 | 11 | Returns: 12 | tensorflow tensor 13 | """ 14 | x = tf.convert_to_tensor(x) 15 | return tf.where( 16 | tf.math.abs(x) <= beta, ((x + beta) ** 2) / (4 * beta), tf.nn.relu(x) 17 | ) 18 | -------------------------------------------------------------------------------- /tensorflow_extra/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | 4 | 5 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 6 | class MelSpectrogram(tf.keras.layers.Layer): 7 | """ 8 | Mel Spectrogram Layer to convert audio to mel spectrogram which works with single or batched inputs. 9 | 10 | Args: 11 | n_fft (int): Size of the FFT window. 12 | hop_length (int): Number of samples between successive STFT columns. 13 | win_length (int): Size of the STFT window. If None, defaults to n_fft. 14 | window (str): Name of the window function to use. 15 | sr (int): Sample rate of the input signal. 16 | n_mels (int): Number of mel bins to generate. 17 | fmin (float): Minimum frequency of the mel bins. 18 | fmax (float): Maximum frequency of the mel bins. If None, defaults to sr / 2. 19 | power (float): Exponent for the magnitude spectrogram. 20 | power_to_db (bool): Whether to convert the power spectrogram to decibels. 21 | top_db (float): Maximum decibel value for the output spectrogram. 22 | amin (float): Minimum threshold for spectrogram values when converting to decibels. 23 | out_channels (int): Number of output channels. If None, no channel is created.
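ref (float or callable): Reference value for the decibel conversion. If a callable is given, it is applied to the dB-scaled spectrogram to compute the reference value.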
24 | 25 | Call Args: 26 | input (tf.Tensor): Audio signal of shape (audio_len,) or (None, audio_len) 27 | 28 | Returns: 29 | tf.Tensor: Mel spectrogram of shape (..., n_mels, time, out_channels) 30 | or (..., n_mels, time) if out_channels is None. 31 | 32 | """ 33 | 34 | def __init__( 35 | self, 36 | n_fft=2048, 37 | hop_length=512, 38 | win_length=None, 39 | window="hann_window", 40 | sr=16000, 41 | n_mels=128, 42 | fmin=20.0, 43 | fmax=None, 44 | power_to_db=True, 45 | top_db=80.0, 46 | power=2.0, 47 | amin=1e-10, 48 | ref=1.0, 49 | out_channels=None, 50 | name="mel_spectrogram", 51 | **kwargs, 52 | ): 53 | super(MelSpectrogram, self).__init__(name=name, **kwargs) 54 | self.n_fft = n_fft 55 | self.hop_length = hop_length 56 | self.win_length = win_length or n_fft 57 | self.window = window 58 | self.sr = sr 59 | self.n_mels = n_mels 60 | self.fmin = fmin 61 | self.fmax = fmax or int(sr / 2) 62 | self.power_to_db = power_to_db 63 | self.top_db = top_db 64 | self.power = power 65 | self.amin = amin 66 | self.ref = ref 67 | self.out_channels = out_channels 68 | 69 | @tf.function 70 | def call(self, input): 71 | spec = self.spectrogram(input) # audio to spectrogram with shape 72 | spec = self.melscale(spec) # spectrogram to mel spectrogram 73 | if self.power_to_db: 74 | spec = self.dbscale(spec) # mel spectrogram to decibel mel spectrogram 75 | spec = tf.linalg.matrix_transpose( 76 | spec 77 | ) # (..., time, n_mels) to (..., n_mels, time) 78 | if self.out_channels is not None: 79 | spec = self.update_channels(spec) 80 | return spec 81 | 82 | def spectrogram(self, input): 83 | spec = tf.signal.stft( 84 | input, 85 | frame_length=self.win_length, 86 | frame_step=self.hop_length, 87 | fft_length=self.n_fft, 88 | window_fn=getattr(tf.signal, self.window), 89 | pad_end=True, 90 | ) 91 | spec = tf.math.pow(tf.math.abs(spec), self.power) 92 | return spec 93 | 94 | def melscale(self, input): 95 | nbin = tf.shape(input)[-1] 96 | matrix = tf.signal.linear_to_mel_weight_matrix( 97 | num_mel_bins=self.n_mels, 98 | num_spectrogram_bins=nbin, 99 | sample_rate=self.sr, 100 | lower_edge_hertz=self.fmin, 101 | upper_edge_hertz=self.fmax, 102 | ) 103 | return tf.tensordot(input, matrix, axes=1) 104 | 105 | def dbscale(self, input): 106 | log_spec = 10.0 * ( 107 | tf.math.log(tf.math.maximum(input, self.amin)) / tf.math.log(10.0) 108 | ) 109 | if callable(self.ref): 110 | ref_value = self.ref(log_spec) 111 | else: 112 | ref_value = tf.math.abs(self.ref) 113 | log_spec -= ( 114 | 10.0 115 | * tf.math.log(tf.math.maximum(ref_value, self.amin)) 116 | / tf.math.log(10.0) 117 | ) 118 | log_spec = tf.math.maximum(log_spec, tf.math.reduce_max(log_spec) - self.top_db) 119 | return log_spec 120 | 121 | def update_channels(self, input): 122 | spec = input[..., tf.newaxis] 123 | if self.out_channels > 1: 124 | multiples = tf.concat( 125 | [ 126 | tf.ones(tf.rank(spec) - 1, dtype=tf.int32), 127 | tf.constant([self.out_channels], dtype=tf.int32), 128 | ], 129 | axis=0, 130 | ) 131 | spec = tf.tile(spec, multiples) 132 | return spec 133 | 134 | def get_config(self): 135 | config = super(MelSpectrogram, self).get_config() 136 | config.update( 137 | { 138 | "n_fft": self.n_fft, 139 | "hop_length": self.hop_length, 140 | "win_length": self.win_length, 141 | "window": self.window, 142 | "sr": self.sr, 143 | "n_mels": self.n_mels, 144 | "fmin": self.fmin, 145 | "fmax": self.fmax, 146 | "power_to_db": self.power_to_db, 147 | "top_db": self.top_db, 148 | "power": self.power, 149 | "amin": self.amin, 150 | "ref": self.ref, 151 | 
"out_channels": self.out_channels, 152 | } 153 | ) 154 | return config 155 | 156 | 157 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 158 | class MixUp(tf.keras.layers.Layer): 159 | """ 160 | MixUp Augmentation Layer to apply MixUp to one batch. 161 | 162 | Args: 163 | alpha (float): Alpha parameter for beta distribution. 164 | prob (float): Probability of applying MixUp. 165 | 166 | Call Args: 167 | images (tf.Tensor): Batch of images. 168 | labels (tf.Tensor): Batch of labels. 169 | 170 | Returns: 171 | tf.Tensor: Batch of image. 172 | tf.Tensor: Batch of labels. 173 | 174 | """ 175 | 176 | def __init__(self, alpha=0.2, prob=0.5, name="mix_up", **kwargs): 177 | super(MixUp, self).__init__(name=name, **kwargs) 178 | self.alpha = alpha 179 | self.prob = prob 180 | 181 | @tf.function 182 | def call(self, images, labels=None, training=False): 183 | 184 | # Skip batch if not training or if prob is not met or if labels are not provided 185 | if tf.random.uniform([]) > self.prob or not training or labels is None: 186 | return (images, labels) if labels is not None else images 187 | 188 | # Get original shape 189 | spec_shape = tf.shape(images) 190 | label_shape = tf.shape(labels) 191 | 192 | # Select lambda from beta distribution 193 | beta = tfp.distributions.Beta(self.alpha, self.alpha) 194 | lam = beta.sample(1)[0] 195 | 196 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 197 | images = lam * images + (1 - lam) * tf.roll(images, shift=1, axis=0) 198 | labels = lam * labels + (1 - lam) * tf.roll(labels, shift=1, axis=0) 199 | 200 | # Ensure original shape 201 | images = tf.reshape(images, spec_shape) 202 | labels = tf.reshape(labels, label_shape) 203 | 204 | return images, labels 205 | 206 | def get_config(self): 207 | config = super(MixUp, self).get_config() 208 | config.update( 209 | { 210 | "alpha": self.alpha, 211 | "prob": self.prob, 212 | } 213 | ) 214 | return config 215 | 216 | 217 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 218 | class CutMix(tf.keras.layers.Layer): 219 | """ 220 | Augmentation layer to apply CutMix to one batch. 221 | 222 | Args: 223 | alpha (float): Alpha parameter for beta distribution. 224 | prob (float): Probability of applying CutMix. 225 | full_height (bool): If True, the patch will be cut with full height of the image. 226 | full_width (bool): If True, the patch will be cut with full width of the image. 227 | 228 | Call Args: 229 | images (tf.Tensor): Batch of images. 230 | labels (tf.Tensor): Batch of labels. 231 | 232 | Returns: 233 | tf.Tensor: Batch of image. 234 | tf.Tensor: Batch of labels. 
235 | """ 236 | 237 | def __init__( 238 | self, 239 | alpha=0.2, 240 | prob=0.5, 241 | full_height=False, 242 | full_width=False, 243 | name="cut_mix", 244 | **kwargs, 245 | ): 246 | super(CutMix, self).__init__(name=name, **kwargs) 247 | self.alpha = alpha 248 | self.prob = prob 249 | self.full_height = full_height 250 | self.full_width = full_width 251 | 252 | @tf.function 253 | def call(self, images, labels=None, training=False): 254 | # Skip batch if not training or if prob is not met or if labels are not provided 255 | if tf.random.uniform([]) > self.prob or not training or labels is None: 256 | return (images, labels) if labels is not None else images 257 | 258 | # Ensure 4D input 259 | images, was_2d = self._ensure_4d(images) 260 | 261 | # Get original shapes 262 | image_shape = tf.shape(images) 263 | label_shape = tf.shape(labels) 264 | 265 | # Select lambda from beta distribution 266 | beta = tfp.distributions.Beta(self.alpha, self.alpha) 267 | lam = beta.sample(1)[0] 268 | 269 | # It's faster to roll the batch by one instead of shuffling it to create image pairs 270 | images_rolled = tf.roll(images, shift=1, axis=0) 271 | labels_rolled = tf.roll(labels, shift=1, axis=0) 272 | 273 | # Find dimensions of patch 274 | H = tf.cast(image_shape[1], tf.int32) 275 | W = tf.cast(image_shape[2], tf.int32) 276 | r_x = ( 277 | tf.random.uniform([], maxval=W, dtype=tf.int32) 278 | if not self.full_width 279 | else 0 280 | ) 281 | r_y = ( 282 | tf.random.uniform([], maxval=H, dtype=tf.int32) 283 | if not self.full_height 284 | else 0 285 | ) 286 | r = 0.5 * tf.math.sqrt(1.0 - lam) 287 | r_w_p = r if not self.full_width else 1.0 288 | r_h_p = r if not self.full_height else 1.0 289 | r_w_half = tf.cast(r_w_p * tf.cast(W, tf.float32), tf.int32) 290 | r_h_half = tf.cast(r_h_p * tf.cast(H, tf.float32), tf.int32) 291 | 292 | # Find the coordinates of the patch 293 | x1 = tf.cast(tf.clip_by_value(r_x - r_w_half, 0, W), tf.int32) 294 | x2 = tf.cast(tf.clip_by_value(r_x + r_w_half, 1, W), tf.int32) 295 | y1 = tf.cast(tf.clip_by_value(r_y - r_h_half, 0, H), tf.int32) 296 | y2 = tf.cast(tf.clip_by_value(r_y + r_h_half, 1, H), tf.int32) 297 | 298 | # Extract outer-pad patch -> [0, 0, 1, 1, 0, 0] 299 | patch1 = images[:, y1:y2, x1:x2, :] # [batch, height, width, channel] 300 | patch1 = tf.pad( 301 | patch1, [[0, 0], [y1, H - y2], [x1, W - x2], [0, 0]] 302 | ) # outer-pad 303 | 304 | # Extract inner-pad patch -> [2, 2, 0, 0, 2, 2] 305 | patch2 = images_rolled[:, y1:y2, x1:x2, :] 306 | patch2 = tf.pad( 307 | patch2, [[0, 0], [y1, H - y2], [x1, W - x2], [0, 0]] 308 | ) # outer-pad 309 | patch2 = images_rolled - patch2 # inner-pad = img - outer-pad 310 | 311 | # Combine patches [0, 0, 1, 1, 0, 0] + [2, 2, 0, 0, 2, 2] -> [2, 2, 1, 1, 2, 2] 312 | images = patch1 + patch2 313 | 314 | # Combine labels 315 | lam = tf.cast((1.0 - (x2 - x1) * (y2 - y1) / (W * H)), tf.float32) 316 | labels = lam * labels + (1.0 - lam) * labels_rolled 317 | 318 | # Ensure original shape 319 | images = tf.reshape(images, image_shape) 320 | labels = tf.reshape(labels, label_shape) 321 | 322 | # Ensure original shape 323 | images = self._ensure_original_shape(images, was_2d) 324 | 325 | return images, labels 326 | 327 | def _ensure_4d(self, tensor): 328 | if len(tensor.shape) == 2: 329 | tensor = tf.expand_dims(tensor, axis=1) 330 | tensor = tf.expand_dims(tensor, axis=-1) 331 | return tensor, True 332 | return tensor, False 333 | 334 | def _ensure_original_shape(self, tensor, was_2d): 335 | if was_2d: 336 | tensor = tf.squeeze(tensor, 
axis=-1) 337 | tensor = tf.squeeze(tensor, axis=1) 338 | return tensor 339 | 340 | def get_config(self): 341 | config = super(CutMix, self).get_config() 342 | config.update( 343 | { 344 | "alpha": self.alpha, 345 | "prob": self.prob, 346 | "full_height": self.full_height, 347 | "full_width": self.full_width, 348 | } 349 | ) 350 | return config 351 | 352 | 353 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 354 | class TimeFreqMask(tf.keras.layers.Layer): 355 | """ 356 | Applies Time Freq Mask to spectrogram input 357 | Ref: https://pytorch.org/audio/main/_modules/torchaudio/functional/functional.html#mask_along_axis_iid 358 | """ 359 | 360 | def __init__( 361 | self, 362 | freq_mask_prob=0.5, 363 | num_freq_masks=2, 364 | freq_mask_param=10, 365 | time_mask_prob=0.5, 366 | num_time_masks=2, 367 | time_mask_param=20, 368 | time_last=True, 369 | name="time_freq_mask", 370 | **kwargs, 371 | ): 372 | super(TimeFreqMask, self).__init__(name=name, **kwargs) 373 | self.freq_mask_prob = freq_mask_prob 374 | self.num_freq_masks = num_freq_masks 375 | self.freq_mask_param = freq_mask_param 376 | self.time_mask_prob = time_mask_prob 377 | self.num_time_masks = num_time_masks 378 | self.time_mask_param = time_mask_param 379 | self.time_last = time_last 380 | 381 | @tf.function 382 | def call(self, inputs, training=False): 383 | if not training: 384 | return inputs 385 | x = inputs 386 | # Adjust input shape 387 | ndims = tf.rank(x) 388 | shape = tf.shape(x) 389 | 390 | # if ndims == 3: 391 | # x = x[tf.newaxis, ...] 392 | # x = tf.reshape(x, shape=(1, tf.split(shape, 3))) 393 | # elif ndims == 2: 394 | # x = x[tf.newaxis, ..., tf.newaxis] 395 | # x = tf.reshape(x, shape=(1, tf.split(shape, 2), 1)) 396 | # else: 397 | # pass 398 | # elif ndims > 4 or ndims < 2: 399 | # raise ValueError("Input tensor must be 2, 3, or 4-dimensional.") 400 | # Apply time mask 401 | for _ in tf.range(self.num_time_masks): 402 | x = self.mask_along_axis_iid( 403 | x, 404 | self.time_mask_param, 405 | 0, 406 | 2 + int(self.time_last), 407 | self.time_mask_prob, 408 | ) 409 | # Apply freq mask 410 | for _ in tf.range(self.num_freq_masks): 411 | x = self.mask_along_axis_iid( 412 | x, 413 | self.freq_mask_param, 414 | 0, 415 | 2 + int(not self.time_last), 416 | self.freq_mask_prob, 417 | ) 418 | # Re-adjust output shape 419 | # if ndims == 3: 420 | # x = x[0] 421 | # elif ndims == 2: 422 | # x = x[0, ..., 0] 423 | return x 424 | 425 | def mask_along_axis_iid(self, specs, mask_param, mask_value, axis, p): 426 | if axis not in [2, 3]: 427 | raise ValueError("Only Frequency and Time masking are supported") 428 | 429 | if not 0.0 <= p <= 1.0: 430 | raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).") 431 | 432 | mask_param = mask_param # self._get_mask_param(mask_param, p, specs.shape[axis]) 433 | if tf.random.uniform([]) > p: 434 | return specs 435 | 436 | specs = tf.transpose(specs, perm=[0, 3, 1, 2]) # (batch, channel, freq, time) 437 | 438 | dtype = specs.dtype 439 | shape = tf.shape(specs) 440 | 441 | value = tf.random.uniform(shape=shape[:2], dtype=dtype) * mask_param 442 | min_value = tf.random.uniform(shape=shape[:2], dtype=dtype) * ( 443 | specs.shape[axis] - value 444 | ) 445 | 446 | # Create broadcastable mask 447 | mask_start = tf.cast(min_value, tf.float32)[..., None, None] 448 | mask_end = (tf.cast(min_value, tf.float32) + tf.cast(value, tf.float32))[ 449 | ..., None, None 450 | ] 451 | mask = tf.range(0, specs.shape[axis], dtype=dtype) 452 | 453 | # Per batch example 
masking 454 | specs = tf.linalg.matrix_transpose(specs) if axis == 2 else specs 455 | cond = (mask >= mask_start) & (mask < mask_end) 456 | specs = tf.where( 457 | cond, tf.fill(tf.shape(specs), tf.cast(mask_value, dtype=dtype)), specs 458 | ) 459 | specs = tf.linalg.matrix_transpose(specs) if axis == 2 else specs 460 | 461 | specs = tf.transpose(specs, perm=[0, 2, 3, 1]) # (batch, freq, time, channel) 462 | 463 | return specs 464 | 465 | def get_config(self): 466 | config = super(TimeFreqMask, self).get_config() 467 | config.update( 468 | { 469 | "freq_mask_prob": self.freq_mask_prob, 470 | "num_freq_masks": self.num_freq_masks, 471 | "freq_mask_param": self.freq_mask_param, 472 | "time_mask_prob": self.time_mask_prob, 473 | "num_time_masks": self.num_time_masks, 474 | "time_mask_param": self.time_mask_param, 475 | "time_last": self.time_last, 476 | } 477 | ) 478 | return config 479 | 480 | 481 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra") 482 | class ZScoreMinMax(tf.keras.layers.Layer): 483 | """ 484 | Applies Z-score normalization and Min-Max normalization to the input tensor. 485 | """ 486 | def __init__(self, name="z_score_min_max", **kwargs): 487 | super(ZScoreMinMax, self).__init__(name=name, **kwargs) 488 | 489 | @tf.function 490 | def call(self, inputs): 491 | # Standardize using Z-score 492 | mean = tf.math.reduce_mean(inputs) 493 | std = tf.math.reduce_std(inputs) 494 | standardized = tf.where(tf.math.equal(std, 0), inputs - mean, (inputs - mean) / std) 495 | 496 | # Normalize using Min-Max 497 | min_val = tf.math.reduce_min(standardized) 498 | max_val = tf.math.reduce_max(standardized) 499 | normalized = tf.where(tf.math.equal(max_val - min_val, 0), standardized - min_val, 500 | (standardized - min_val) / (max_val - min_val)) 501 | 502 | return normalized 503 | 504 | def get_config(self): 505 | config = super(ZScoreMinMax, self).get_config() 506 | return config -------------------------------------------------------------------------------- /tensorflow_extra/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | # Generates random integer 4 | def random_int(shape=[], minval=0, maxval=1): 5 | return tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.int32) 6 | 7 | 8 | # Generates random float 9 | def random_float(shape=[], minval=0.0, maxval=1.0): 10 | rnd = tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.float32) 11 | return rnd 12 | -------------------------------------------------------------------------------- /tensorflow_extra/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.0.2" 2 | --------------------------------------------------------------------------------
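The snippet below sketches one way the layers above could be combined into a single `tf.data` augmentation step. It is an illustrative example assembled only from the APIs defined in `tensorflow_extra/layers.py`; the sample rate, clip length, batch size, and label depth are placeholder assumptions, and `tensorflow_probability` must be installed since `layers.py` imports it.

```py
import tensorflow as tf
import tensorflow_extra as tfe

# Placeholder data (assumed): 8 clips of 5 s audio at 16 kHz with 10-class one-hot labels.
audio = tf.random.normal([8, 5 * 16000])
labels = tf.one_hot(tf.random.uniform([8], maxval=10, dtype=tf.int32), depth=10)

# Layers from tensorflow_extra.layers; out_channels=1 adds the channel axis that
# TimeFreqMask and CutMix work on, and full_height=True keeps full frequency resolution.
audio2spec = tfe.layers.MelSpectrogram(sr=16000, n_mels=128, out_channels=1)
time_freq_mask = tfe.layers.TimeFreqMask()
mixup = tfe.layers.MixUp(alpha=0.2, prob=0.5)
cutmix = tfe.layers.CutMix(alpha=0.2, prob=0.5, full_height=True)
norm = tfe.layers.ZScoreMinMax()

def augment(batch_audio, batch_labels):
    spec = audio2spec(batch_audio)                    # (batch, n_mels, time, 1)
    spec = time_freq_mask(spec, training=True)        # SpecAugment-style time/freq masks
    spec, batch_labels = mixup(spec, batch_labels, training=True)
    spec, batch_labels = cutmix(spec, batch_labels, training=True)
    spec = norm(spec)                                 # z-score + min-max rescaling
    return spec, batch_labels

ds = (
    tf.data.Dataset.from_tensor_slices((audio, labels))
    .batch(4)
    .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
)
specs, labs = next(iter(ds))
print(specs.shape, labs.shape)  # e.g. (4, 128, 157, 1) (4, 10)
```

`TimeFreqMask` and `CutMix` operate on `(batch, freq, time, channel)` spectrograms, which is why the example sets `out_channels=1` on `MelSpectrogram`, and since the augmentation layers are no-ops unless called with `training=True`, a map like this would normally be applied only to the training split.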