├── .github
│   └── workflows
│       └── publish-to-test-pypi.yml
├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE.md
├── README.md
├── audios
│   └── sample-6s.wav
├── images
│   ├── cat.jpg
│   ├── dog.jpg
│   └── smelu.png
├── requirements.txt
├── setup.py
└── tensorflow_extra
    ├── __init__.py
    ├── activations.py
    ├── layers.py
    ├── utils.py
    └── version.py
/.github/workflows/publish-to-test-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python 🐍 distributions 📦 to PyPI
2 |
3 | on: push
4 |
5 | jobs:
6 | build-n-publish:
7 | name: Build and publish Python 🐍 distributions 📦 to PyPI
8 | runs-on: ubuntu-18.04
9 | steps:
10 | - uses: actions/checkout@master
11 | - name: Set up Python 3.6
12 | uses: actions/setup-python@v1
13 | with:
14 | python-version: 3.6
15 | - name: Install pypa/build
16 | run: >-
17 | python -m
18 | pip install
19 | build
20 | --user
21 | - name: Build a binary wheel and a source tarball
22 | run: >-
23 | python -m
24 | build
25 | --sdist
26 | --wheel
27 | --outdir dist/
28 | .
29 | - name: Publish distribution 📦 to PyPI
30 | if: startsWith(github.ref, 'refs/tags')
31 | uses: pypa/gh-action-pypi-publish@master
32 | with:
33 | password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 |
123 | # Pyre type checker
124 | .pyre/
125 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.linting.enabled": false
3 | }
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 |
2 | The MIT License (MIT)
3 |
4 | Copyright (c) 2021 Awsaf
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tensorflow Extra
2 | > TensorFlow GPU & TPU compatible operations: MelSpectrogram, TimeFreqMask, CutMix, MixUp, ZScore, and more
3 |
4 | # Installation
5 | For the stable version:
6 | ```shell
7 | pip install tensorflow-extra
8 | ```
9 | or
10 | For the latest version from GitHub:
11 | ```shell
12 | pip install git+https://github.com/awsaf49/tensorflow_extra
13 | ```
14 |
15 | # Usage
16 | To see a use case of this library, check out the [BirdCLEF23: Pretraining is All you Need](https://www.kaggle.com/code/awsaf49/birdclef23-pretraining-is-all-you-need-train) notebook. It uses this library along with **Multi-Stage Transfer Learning** for a Bird Call Identification task.
17 |
18 |
19 | # Layers
20 | ## MelSpectrogram
21 | Converts audio data to a mel-spectrogram on GPU/TPU.
22 | ```py
23 | import tensorflow_extra as tfe
24 | audio2spec = tfe.layers.MelSpectrogram()
25 | spec = audio2spec(audio)
26 | ```
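The constructor exposes the usual mel-spectrogram parameters (the names below are the actual arguments of `tfe.layers.MelSpectrogram`; the 32 kHz values are only illustrative). A minimal sketch of a custom configuration:

```py
import tensorflow as tf
import tensorflow_extra as tfe

audio2spec = tfe.layers.MelSpectrogram(
    n_fft=2048,        # FFT window size
    hop_length=512,    # samples between successive frames
    sr=32000,          # sample rate of the input audio (illustrative)
    n_mels=128,        # number of mel bins
    fmin=20.0,         # lowest mel-bin frequency
    fmax=16000.0,      # highest mel-bin frequency (defaults to sr / 2)
    out_channels=3,    # tile the spectrogram to 3 channels, e.g. for ImageNet backbones
)
audio = tf.random.normal([8, 6 * 32000])  # dummy batch of 6-second clips
spec = audio2spec(audio)                  # -> (8, n_mels, time, 3)
```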
27 |
28 |
29 |
30 |
31 | ## Time Frequency Masking
32 | Applies time and frequency masking to a spectrogram; the number of stripes can also be controlled (see the sketch below).
33 | ```py
34 | time_freq_mask = tfe.layers.TimeFreqMask()
35 | spec = time_freq_mask(spec, training=True)
36 | ```
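The stripe counts and widths map directly to the constructor arguments defined in `layers.py` (defaults shown); a sketch:

```py
time_freq_mask = tfe.layers.TimeFreqMask(
    num_time_masks=2,    # number of time stripes
    time_mask_param=20,  # maximum width of each time stripe
    time_mask_prob=0.5,  # probability of applying time masking
    num_freq_masks=2,    # number of frequency stripes
    freq_mask_param=10,  # maximum height of each frequency stripe
    freq_mask_prob=0.5,  # probability of applying frequency masking
)
spec = time_freq_mask(spec, training=True)  # masks are applied only in training mode
```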
37 |
38 |
39 | ## CutMix
40 | Can be used with audio, spectrograms, or images. For spectrograms, the patch can cover the full frequency range using `full_height=True` (see the example below).
41 | ```py
42 | cutmix = tfe.layers.CutMix()
43 | audio = cutmix(audio, training=True) # accepts both audio & spectrogram
44 | ```
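When labels are passed alongside the inputs, CutMix mixes them with the same area ratio (the optional `labels` argument comes from the layer's `call` signature in `layers.py`). A sketch, assuming a batched spectrogram with a channel axis and one-hot `labels`:

```py
cutmix = tfe.layers.CutMix(alpha=0.2, prob=0.5, full_height=True)  # full-height patch for spectrograms
spec, labels = cutmix(spec, labels, training=True)
```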
45 |
46 |
47 |
48 | ## MixUp
49 | Can be used with audio, spectrograms, or images. Each example in the batch is blended with another example using a weight drawn from a beta distribution.
50 | ```py
51 | mixup = tfe.layers.MixUp()
52 | audio = mixup(audio, training=True) # accepts both audio & spectrogram
53 | ```
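Like CutMix, MixUp also blends labels when they are provided; a sketch assuming one-hot `labels`:

```py
mixup = tfe.layers.MixUp(alpha=0.2, prob=0.5)
audio, labels = mixup(audio, labels, training=True)  # inputs and labels are mixed with the same lambda
```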
54 |
55 |
56 |
57 |
58 | ## Normalization
59 | Applies Z-score standardization followed by min-max rescaling.
60 | ```py
61 | norm = tfe.layers.ZScoreMinMax()
62 | spec = norm(spec)
63 | ```
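Since these are ordinary Keras layers, they can be chained into a single preprocessing front end; a minimal sketch (the composition below is an example, not a prescribed pipeline):

```py
import tensorflow as tf
import tensorflow_extra as tfe

frontend = tf.keras.Sequential([
    tfe.layers.MelSpectrogram(out_channels=3),  # waveform -> mel spectrogram
    tfe.layers.TimeFreqMask(),                  # active only when training=True
    tfe.layers.ZScoreMinMax(),                  # standardize and rescale
])
spec = frontend(audio, training=True)  # audio: batch of waveforms, shape (batch, samples)
```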
64 |
65 |
66 |
67 | # Activations
68 | ## SmeLU: Smooth ReLU
69 | ```py
70 | import tensorflow as tf
71 | import tensorflow_extra as tfe
72 |
73 | a = tf.constant([-2.5, -1.0, 0.5, 1.0, 2.5])
74 | b = tfe.activations.smelu(a) # array([0., 0.04166667, 0.6666667 , 1.0416666 , 2.5])
75 | ```
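Since `smelu` is quadratic on `[-beta, beta]` and matches ReLU outside that range, it can be used wherever Keras accepts an activation callable; a small sketch (the layer size and `beta=2.0` are arbitrary):

```py
from functools import partial

dense = tf.keras.layers.Dense(64, activation=tfe.activations.smelu)                          # default beta=1.5
dense_wide = tf.keras.layers.Dense(64, activation=partial(tfe.activations.smelu, beta=2.0))  # wider smooth region
```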
76 |
77 |
--------------------------------------------------------------------------------
/audios/sample-6s.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/audios/sample-6s.wav
--------------------------------------------------------------------------------
/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/cat.jpg
--------------------------------------------------------------------------------
/images/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/dog.jpg
--------------------------------------------------------------------------------
/images/smelu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/awsaf49/tensorflow_extra/d7646de433cab9b671154c2e07192c9ae28bd4ba/images/smelu.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from codecs import open
3 | from os import path
4 |
5 | here = path.abspath(path.dirname(__file__))
6 |
7 | # Get the long description from the README file
8 | with open(path.join(here, "README.md"), encoding="utf-8") as f:
9 | long_description = f.read()
10 |
11 | with open(path.join(here, "requirements.txt")) as f:
12 | install_requires = [x for x in f.read().splitlines() if len(x)]
13 |
14 | exec(open("tensorflow_extra/version.py").read())
15 |
16 | setup(
17 | name="tensorflow_extra",
18 | version=__version__,
19 | description="Tensorflow Extra Utilities. https://github.com/awsaf49/tensorflow_extra",
20 | long_description=long_description,
21 | long_description_content_type="text/markdown",
22 | url="https://github.com/awsaf49/tensorflow_extra",
23 | author="Awsaf",
24 | author_email="awsaf49@gmail.com",
25 | classifiers=[
26 | # How mature is this project? Common values are
27 | # 3 - Alpha
28 | # 4 - Beta
29 | # 5 - Production/Stable
30 | "Development Status :: 3 - Alpha",
31 | "Intended Audience :: Developers",
32 | "Intended Audience :: Science/Research",
33 | "License :: OSI Approved :: Apache Software License",
34 | "Programming Language :: Python :: 3.6",
35 | "Programming Language :: Python :: 3.7",
36 | "Programming Language :: Python :: 3.8",
37 | "Topic :: Scientific/Engineering",
38 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
39 | "Topic :: Software Development",
40 | "Topic :: Software Development :: Libraries",
41 | "Topic :: Software Development :: Libraries :: Python Modules",
42 | ],
43 | # Note that this is a string of words separated by whitespace, not a list.
44 | keywords="tensorflow extra utilities",
45 | packages=find_packages(exclude=["tests"]),
46 | include_package_data=True,
47 | install_requires=install_requires,
48 | python_requires=">=3.6",
49 | license="MIT",
50 | )
51 |
--------------------------------------------------------------------------------
/tensorflow_extra/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 | from tensorflow_extra import activations
3 | from tensorflow_extra import layers
4 |
--------------------------------------------------------------------------------
/tensorflow_extra/activations.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
4 | def smelu(x, beta=1.5):
5 | """Smooth ReLU (SmeLU): Smooth activations and reproducibility in deep networks, https://arxiv.org/abs/2010.09931
6 |
7 | Args:
8 | x: NumPy array or TensorFlow tensor.
9 | beta (float): Smoothing parameter; the output is quadratic on [-beta, beta] and matches ReLU outside. Defaults to 1.5.
10 |
11 | Returns:
12 | tensorflow tensor
13 | """
14 | x = tf.convert_to_tensor(x)
15 | return tf.where(
16 | tf.math.abs(x) <= beta, ((x + beta) ** 2) / (4 * beta), tf.nn.relu(x)
17 | )
18 |
--------------------------------------------------------------------------------
/tensorflow_extra/layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_probability as tfp
3 |
4 |
5 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
6 | class MelSpectrogram(tf.keras.layers.Layer):
7 | """
8 | Mel Spectrogram Layer to convert audio to mel spectrogram which works with single or batched inputs.
9 |
10 | Args:
11 | n_fft (int): Size of the FFT window.
12 | hop_length (int): Number of samples between successive STFT columns.
13 | win_length (int): Size of the STFT window. If None, defaults to n_fft.
14 | window (str): Name of the window function in tf.signal to use.
15 | sr (int): Sample rate of the input signal.
16 | n_mels (int): Number of mel bins to generate.
17 | fmin (float): Minimum frequency of the mel bins.
18 | fmax (float): Maximum frequency of the mel bins. If None, defaults to sr / 2.
19 | power (float): Exponent for the magnitude spectrogram.
20 | power_to_db (bool): Whether to convert the power spectrogram to decibels.
21 | top_db (float): Maximum decibel value for the output spectrogram.
22 | amin (float): Minimum threshold applied to values before converting to decibels.
23 | out_channels (int): Number of output channels. If None, no channel is created.
24 |
25 | Call Args:
26 | input (tf.Tensor): Audio signal of shape (audio_len,) or (None, audio_len)
27 |
28 | Returns:
29 | tf.Tensor: Mel spectrogram of shape (..., n_mels, time, out_channels)
30 | or (..., n_mels, time) if out_channels is None.
31 |
32 | """
33 |
34 | def __init__(
35 | self,
36 | n_fft=2048,
37 | hop_length=512,
38 | win_length=None,
39 | window="hann_window",
40 | sr=16000,
41 | n_mels=128,
42 | fmin=20.0,
43 | fmax=None,
44 | power_to_db=True,
45 | top_db=80.0,
46 | power=2.0,
47 | amin=1e-10,
48 | ref=1.0,
49 | out_channels=None,
50 | name="mel_spectrogram",
51 | **kwargs,
52 | ):
53 | super(MelSpectrogram, self).__init__(name=name, **kwargs)
54 | self.n_fft = n_fft
55 | self.hop_length = hop_length
56 | self.win_length = win_length or n_fft
57 | self.window = window
58 | self.sr = sr
59 | self.n_mels = n_mels
60 | self.fmin = fmin
61 | self.fmax = fmax or int(sr / 2)
62 | self.power_to_db = power_to_db
63 | self.top_db = top_db
64 | self.power = power
65 | self.amin = amin
66 | self.ref = ref
67 | self.out_channels = out_channels
68 |
69 | @tf.function
70 | def call(self, input):
71 | spec = self.spectrogram(input) # audio to power spectrogram
72 | spec = self.melscale(spec) # spectrogram to mel spectrogram
73 | if self.power_to_db:
74 | spec = self.dbscale(spec) # mel spectrogram to decibel mel spectrogram
75 | spec = tf.linalg.matrix_transpose(
76 | spec
77 | ) # (..., time, n_mels) to (..., n_mels, time)
78 | if self.out_channels is not None:
79 | spec = self.update_channels(spec)
80 | return spec
81 |
82 | def spectrogram(self, input):
83 | spec = tf.signal.stft(
84 | input,
85 | frame_length=self.win_length,
86 | frame_step=self.hop_length,
87 | fft_length=self.n_fft,
88 | window_fn=getattr(tf.signal, self.window),
89 | pad_end=True,
90 | )
91 | spec = tf.math.pow(tf.math.abs(spec), self.power)
92 | return spec
93 |
94 | def melscale(self, input):
95 | nbin = tf.shape(input)[-1]
96 | matrix = tf.signal.linear_to_mel_weight_matrix(
97 | num_mel_bins=self.n_mels,
98 | num_spectrogram_bins=nbin,
99 | sample_rate=self.sr,
100 | lower_edge_hertz=self.fmin,
101 | upper_edge_hertz=self.fmax,
102 | )
103 | return tf.tensordot(input, matrix, axes=1)
104 |
105 | def dbscale(self, input):
106 | log_spec = 10.0 * (
107 | tf.math.log(tf.math.maximum(input, self.amin)) / tf.math.log(10.0)
108 | )
109 | if callable(self.ref):
110 | ref_value = self.ref(log_spec)
111 | else:
112 | ref_value = tf.math.abs(self.ref)
113 | log_spec -= (
114 | 10.0
115 | * tf.math.log(tf.math.maximum(ref_value, self.amin))
116 | / tf.math.log(10.0)
117 | )
118 | log_spec = tf.math.maximum(log_spec, tf.math.reduce_max(log_spec) - self.top_db)
119 | return log_spec
120 |
121 | def update_channels(self, input):
122 | spec = input[..., tf.newaxis]
123 | if self.out_channels > 1:
124 | multiples = tf.concat(
125 | [
126 | tf.ones(tf.rank(spec) - 1, dtype=tf.int32),
127 | tf.constant([self.out_channels], dtype=tf.int32),
128 | ],
129 | axis=0,
130 | )
131 | spec = tf.tile(spec, multiples)
132 | return spec
133 |
134 | def get_config(self):
135 | config = super(MelSpectrogram, self).get_config()
136 | config.update(
137 | {
138 | "n_fft": self.n_fft,
139 | "hop_length": self.hop_length,
140 | "win_length": self.win_length,
141 | "window": self.window,
142 | "sr": self.sr,
143 | "n_mels": self.n_mels,
144 | "fmin": self.fmin,
145 | "fmax": self.fmax,
146 | "power_to_db": self.power_to_db,
147 | "top_db": self.top_db,
148 | "power": self.power,
149 | "amin": self.amin,
150 | "ref": self.ref,
151 | "out_channels": self.out_channels,
152 | }
153 | )
154 | return config
155 |
156 |
157 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
158 | class MixUp(tf.keras.layers.Layer):
159 | """
160 | MixUp Augmentation Layer to apply MixUp to one batch.
161 |
162 | Args:
163 | alpha (float): Alpha parameter for beta distribution.
164 | prob (float): Probability of applying MixUp.
165 |
166 | Call Args:
167 | images (tf.Tensor): Batch of images.
168 | labels (tf.Tensor): Batch of labels.
169 |
170 | Returns:
171 | tf.Tensor: Batch of images.
172 | tf.Tensor: Batch of labels.
173 |
174 | """
175 |
176 | def __init__(self, alpha=0.2, prob=0.5, name="mix_up", **kwargs):
177 | super(MixUp, self).__init__(name=name, **kwargs)
178 | self.alpha = alpha
179 | self.prob = prob
180 |
181 | @tf.function
182 | def call(self, images, labels=None, training=False):
183 |
184 | # Skip batch if not training or if prob is not met or if labels are not provided
185 | if tf.random.uniform([]) > self.prob or not training or labels is None:
186 | return (images, labels) if labels is not None else images
187 |
188 | # Get original shape
189 | spec_shape = tf.shape(images)
190 | label_shape = tf.shape(labels)
191 |
192 | # Select lambda from beta distribution
193 | beta = tfp.distributions.Beta(self.alpha, self.alpha)
194 | lam = beta.sample(1)[0]
195 |
196 | # It's faster to roll the batch by one instead of shuffling it to create image pairs
197 | images = lam * images + (1 - lam) * tf.roll(images, shift=1, axis=0)
198 | labels = lam * labels + (1 - lam) * tf.roll(labels, shift=1, axis=0)
199 |
200 | # Ensure original shape
201 | images = tf.reshape(images, spec_shape)
202 | labels = tf.reshape(labels, label_shape)
203 |
204 | return images, labels
205 |
206 | def get_config(self):
207 | config = super(MixUp, self).get_config()
208 | config.update(
209 | {
210 | "alpha": self.alpha,
211 | "prob": self.prob,
212 | }
213 | )
214 | return config
215 |
216 |
217 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
218 | class CutMix(tf.keras.layers.Layer):
219 | """
220 | Augmentation layer to apply CutMix to one batch.
221 |
222 | Args:
223 | alpha (float): Alpha parameter for beta distribution.
224 | prob (float): Probability of applying CutMix.
225 | full_height (bool): If True, the patch will be cut with full height of the image.
226 | full_width (bool): If True, the patch will be cut with full width of the image.
227 |
228 | Call Args:
229 | images (tf.Tensor): Batch of images.
230 | labels (tf.Tensor): Batch of labels.
231 |
232 | Returns:
233 | tf.Tensor: Batch of images.
234 | tf.Tensor: Batch of labels.
235 | """
236 |
237 | def __init__(
238 | self,
239 | alpha=0.2,
240 | prob=0.5,
241 | full_height=False,
242 | full_width=False,
243 | name="cut_mix",
244 | **kwargs,
245 | ):
246 | super(CutMix, self).__init__(name=name, **kwargs)
247 | self.alpha = alpha
248 | self.prob = prob
249 | self.full_height = full_height
250 | self.full_width = full_width
251 |
252 | @tf.function
253 | def call(self, images, labels=None, training=False):
254 | # Skip batch if not training or if prob is not met or if labels are not provided
255 | if tf.random.uniform([]) > self.prob or not training or labels is None:
256 | return (images, labels) if labels is not None else images
257 |
258 | # Ensure 4D input
259 | images, was_2d = self._ensure_4d(images)
260 |
261 | # Get original shapes
262 | image_shape = tf.shape(images)
263 | label_shape = tf.shape(labels)
264 |
265 | # Select lambda from beta distribution
266 | beta = tfp.distributions.Beta(self.alpha, self.alpha)
267 | lam = beta.sample(1)[0]
268 |
269 | # It's faster to roll the batch by one instead of shuffling it to create image pairs
270 | images_rolled = tf.roll(images, shift=1, axis=0)
271 | labels_rolled = tf.roll(labels, shift=1, axis=0)
272 |
273 | # Find dimensions of patch
274 | H = tf.cast(image_shape[1], tf.int32)
275 | W = tf.cast(image_shape[2], tf.int32)
276 | r_x = (
277 | tf.random.uniform([], maxval=W, dtype=tf.int32)
278 | if not self.full_width
279 | else 0
280 | )
281 | r_y = (
282 | tf.random.uniform([], maxval=H, dtype=tf.int32)
283 | if not self.full_height
284 | else 0
285 | )
286 | r = 0.5 * tf.math.sqrt(1.0 - lam)
287 | r_w_p = r if not self.full_width else 1.0
288 | r_h_p = r if not self.full_height else 1.0
289 | r_w_half = tf.cast(r_w_p * tf.cast(W, tf.float32), tf.int32)
290 | r_h_half = tf.cast(r_h_p * tf.cast(H, tf.float32), tf.int32)
291 |
292 | # Find the coordinates of the patch
293 | x1 = tf.cast(tf.clip_by_value(r_x - r_w_half, 0, W), tf.int32)
294 | x2 = tf.cast(tf.clip_by_value(r_x + r_w_half, 1, W), tf.int32)
295 | y1 = tf.cast(tf.clip_by_value(r_y - r_h_half, 0, H), tf.int32)
296 | y2 = tf.cast(tf.clip_by_value(r_y + r_h_half, 1, H), tf.int32)
297 |
298 | # Extract outer-pad patch -> [0, 0, 1, 1, 0, 0]
299 | patch1 = images[:, y1:y2, x1:x2, :] # [batch, height, width, channel]
300 | patch1 = tf.pad(
301 | patch1, [[0, 0], [y1, H - y2], [x1, W - x2], [0, 0]]
302 | ) # outer-pad
303 |
304 | # Extract inner-pad patch -> [2, 2, 0, 0, 2, 2]
305 | patch2 = images_rolled[:, y1:y2, x1:x2, :]
306 | patch2 = tf.pad(
307 | patch2, [[0, 0], [y1, H - y2], [x1, W - x2], [0, 0]]
308 | ) # outer-pad
309 | patch2 = images_rolled - patch2 # inner-pad = img - outer-pad
310 |
311 | # Combine patches [0, 0, 1, 1, 0, 0] + [2, 2, 0, 0, 2, 2] -> [2, 2, 1, 1, 2, 2]
312 | images = patch1 + patch2
313 |
314 | # Combine labels
315 | lam = tf.cast((1.0 - (x2 - x1) * (y2 - y1) / (W * H)), tf.float32)
316 | labels = lam * labels + (1.0 - lam) * labels_rolled
317 |
318 | # Ensure original shape
319 | images = tf.reshape(images, image_shape)
320 | labels = tf.reshape(labels, label_shape)
321 |
322 | # Ensure original shape
323 | images = self._ensure_original_shape(images, was_2d)
324 |
325 | return images, labels
326 |
327 | def _ensure_4d(self, tensor):
328 | if len(tensor.shape) == 2:
329 | tensor = tf.expand_dims(tensor, axis=1)
330 | tensor = tf.expand_dims(tensor, axis=-1)
331 | return tensor, True
332 | return tensor, False
333 |
334 | def _ensure_original_shape(self, tensor, was_2d):
335 | if was_2d:
336 | tensor = tf.squeeze(tensor, axis=-1)
337 | tensor = tf.squeeze(tensor, axis=1)
338 | return tensor
339 |
340 | def get_config(self):
341 | config = super(CutMix, self).get_config()
342 | config.update(
343 | {
344 | "alpha": self.alpha,
345 | "prob": self.prob,
346 | "full_height": self.full_height,
347 | "full_width": self.full_width,
348 | }
349 | )
350 | return config
351 |
352 |
353 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
354 | class TimeFreqMask(tf.keras.layers.Layer):
355 | """
356 | Applies Time Freq Mask to spectrogram input
357 | Ref: https://pytorch.org/audio/main/_modules/torchaudio/functional/functional.html#mask_along_axis_iid
358 | """
359 |
360 | def __init__(
361 | self,
362 | freq_mask_prob=0.5,
363 | num_freq_masks=2,
364 | freq_mask_param=10,
365 | time_mask_prob=0.5,
366 | num_time_masks=2,
367 | time_mask_param=20,
368 | time_last=True,
369 | name="time_freq_mask",
370 | **kwargs,
371 | ):
372 | super(TimeFreqMask, self).__init__(name=name, **kwargs)
373 | self.freq_mask_prob = freq_mask_prob
374 | self.num_freq_masks = num_freq_masks
375 | self.freq_mask_param = freq_mask_param
376 | self.time_mask_prob = time_mask_prob
377 | self.num_time_masks = num_time_masks
378 | self.time_mask_param = time_mask_param
379 | self.time_last = time_last
380 |
381 | @tf.function
382 | def call(self, inputs, training=False):
383 | if not training:
384 | return inputs
385 | x = inputs
386 | # Adjust input shape
387 | ndims = tf.rank(x)
388 | shape = tf.shape(x)
389 |
390 | # if ndims == 3:
391 | # x = x[tf.newaxis, ...]
392 | # x = tf.reshape(x, shape=(1, tf.split(shape, 3)))
393 | # elif ndims == 2:
394 | # x = x[tf.newaxis, ..., tf.newaxis]
395 | # x = tf.reshape(x, shape=(1, tf.split(shape, 2), 1))
396 | # else:
397 | # pass
398 | # elif ndims > 4 or ndims < 2:
399 | # raise ValueError("Input tensor must be 2, 3, or 4-dimensional.")
400 | # Apply time mask
401 | for _ in tf.range(self.num_time_masks):
402 | x = self.mask_along_axis_iid(
403 | x,
404 | self.time_mask_param,
405 | 0,
406 | 2 + int(self.time_last),
407 | self.time_mask_prob,
408 | )
409 | # Apply freq mask
410 | for _ in tf.range(self.num_freq_masks):
411 | x = self.mask_along_axis_iid(
412 | x,
413 | self.freq_mask_param,
414 | 0,
415 | 2 + int(not self.time_last),
416 | self.freq_mask_prob,
417 | )
418 | # Re-adjust output shape
419 | # if ndims == 3:
420 | # x = x[0]
421 | # elif ndims == 2:
422 | # x = x[0, ..., 0]
423 | return x
424 |
425 | def mask_along_axis_iid(self, specs, mask_param, mask_value, axis, p):
426 | if axis not in [2, 3]:
427 | raise ValueError("Only Frequency and Time masking are supported")
428 |
429 | if not 0.0 <= p <= 1.0:
430 | raise ValueError(f"The value of p must be between 0.0 and 1.0 ({p} given).")
431 |
432 | mask_param = mask_param # self._get_mask_param(mask_param, p, specs.shape[axis])
433 | if tf.random.uniform([]) > p:
434 | return specs
435 |
436 | specs = tf.transpose(specs, perm=[0, 3, 1, 2]) # (batch, channel, freq, time)
437 |
438 | dtype = specs.dtype
439 | shape = tf.shape(specs)
440 |
441 | value = tf.random.uniform(shape=shape[:2], dtype=dtype) * mask_param
442 | min_value = tf.random.uniform(shape=shape[:2], dtype=dtype) * (
443 | specs.shape[axis] - value
444 | )
445 |
446 | # Create broadcastable mask
447 | mask_start = tf.cast(min_value, tf.float32)[..., None, None]
448 | mask_end = (tf.cast(min_value, tf.float32) + tf.cast(value, tf.float32))[
449 | ..., None, None
450 | ]
451 | mask = tf.range(0, specs.shape[axis], dtype=dtype)
452 |
453 | # Per batch example masking
454 | specs = tf.linalg.matrix_transpose(specs) if axis == 2 else specs
455 | cond = (mask >= mask_start) & (mask < mask_end)
456 | specs = tf.where(
457 | cond, tf.fill(tf.shape(specs), tf.cast(mask_value, dtype=dtype)), specs
458 | )
459 | specs = tf.linalg.matrix_transpose(specs) if axis == 2 else specs
460 |
461 | specs = tf.transpose(specs, perm=[0, 2, 3, 1]) # (batch, freq, time, channel)
462 |
463 | return specs
464 |
465 | def get_config(self):
466 | config = super(TimeFreqMask, self).get_config()
467 | config.update(
468 | {
469 | "freq_mask_prob": self.freq_mask_prob,
470 | "num_freq_masks": self.num_freq_masks,
471 | "freq_mask_param": self.freq_mask_param,
472 | "time_mask_prob": self.time_mask_prob,
473 | "num_time_masks": self.num_time_masks,
474 | "time_mask_param": self.time_mask_param,
475 | "time_last": self.time_last,
476 | }
477 | )
478 | return config
479 |
480 |
481 | @tf.keras.utils.register_keras_serializable(package="tensorflow_extra")
482 | class ZScoreMinMax(tf.keras.layers.Layer):
483 | """
484 | Applies Z-score normalization and Min-Max normalization to the input tensor.
485 | """
486 | def __init__(self, name="z_score_min_max", **kwargs):
487 | super(ZScoreMinMax, self).__init__(name=name, **kwargs)
488 |
489 | @tf.function
490 | def call(self, inputs):
491 | # Standardize using Z-score
492 | mean = tf.math.reduce_mean(inputs)
493 | std = tf.math.reduce_std(inputs)
494 | standardized = tf.where(tf.math.equal(std, 0), inputs - mean, (inputs - mean) / std)
495 |
496 | # Normalize using Min-Max
497 | min_val = tf.math.reduce_min(standardized)
498 | max_val = tf.math.reduce_max(standardized)
499 | normalized = tf.where(tf.math.equal(max_val - min_val, 0), standardized - min_val,
500 | (standardized - min_val) / (max_val - min_val))
501 |
502 | return normalized
503 |
504 | def get_config(self):
505 | config = super(ZScoreMinMax, self).get_config()
506 | return config
--------------------------------------------------------------------------------
/tensorflow_extra/utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | # Generates random integer
4 | def random_int(shape=[], minval=0, maxval=1):
5 | return tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.int32)
6 |
7 |
8 | # Generates random float
9 | def random_float(shape=[], minval=0.0, maxval=1.0):
10 | rnd = tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.float32)
11 | return rnd
12 |
--------------------------------------------------------------------------------
/tensorflow_extra/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "1.0.2"
2 |
--------------------------------------------------------------------------------