├── .gitignore ├── DEMO.ipynb ├── DEMO_train_audio_transform.ipynb ├── LICENSE.md ├── README.md ├── Wave-U-Net ├── InterpolationLayer.py ├── LICENSE ├── OutputLayer.py ├── README.md ├── UnetAudioSeparator.py └── unet_utils.py ├── audio ├── ex01_unprocessed_input.wav ├── ex02_unprocessed_input.wav ├── ex03_unprocessed_input.wav ├── ex04_unprocessed_input.wav ├── ex05_unprocessed_input.wav ├── ex06_unprocessed_input.wav ├── ex07_unprocessed_input.wav ├── ex08_unprocessed_input.wav ├── ex09_unprocessed_input.wav └── ex10_unprocessed_input.wav ├── data ├── mturk_experiment_1_results_interspeech2021.csv ├── mturk_experiment_2_results_interspeech2021.csv └── toy_dataset_0000.tfrecords ├── models ├── .gitkeep ├── audio_transforms │ └── .gitkeep └── recognition_networks │ ├── .gitkeep │ ├── arch1.json │ ├── arch2.json │ ├── arch3.json │ └── deep_feature_loss_weights.json ├── util_audio_preprocess.py ├── util_audio_transform.py ├── util_auditory_model_loss.py ├── util_cochlear_model.py └── util_recognition_network.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # SLURM 107 | *.out 108 | 109 | *.wav 110 | *.ckpt-* 111 | -------------------------------------------------------------------------------- /DEMO_train_audio_transform.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import sys\n", 11 | "import glob\n", 12 | "import numpy as np\n", 13 | "import tensorflow as tf\n", 14 | "import importlib\n", 15 | "import IPython.display as ipd\n", 16 | "\n", 17 | "import util_audio_preprocess\n", 18 | "import util_audio_transform\n", 19 | "import util_auditory_model_loss\n", 20 | "import 
util_cochlear_model\n", 21 | "import util_recognition_network\n" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [ 29 | { 30 | "name": "stdout", 31 | "output_type": "stream", 32 | "text": [ 33 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/control_flow_ops.py:3632: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 34 | "Instructions for updating:\n", 35 | "Colocations handled automatically by placer.\n" 36 | ] 37 | } 38 | ], 39 | "source": [ 40 | "\"\"\"\n", 41 | "Data inputput pipeline\n", 42 | "\n", 43 | "Recommended training data format is tfrecords files containing foreground\n", 44 | "(speech) and background (noise) waveforms stored separately. Model was\n", 45 | "originally trained with 2-second audio clips (20 kHz sampling rate).\n", 46 | "\n", 47 | "NOTE: the full dataset used to train models in the paper is not included in\n", 48 | "this code release as it was compiled from previously published datasets\n", 49 | "(speech from Wall Street Journal and Spoken Wikipedia Corpora; background\n", 50 | "noise from Audioset). Data available upon request to authors.\n", 51 | "\"\"\"\n", 52 | "\n", 53 | "filenames = glob.glob('data/toy_dataset*.tfrecords')\n", 54 | "batch_size = 8\n", 55 | "feature_description = {\n", 56 | " 'background/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),\n", 57 | " 'foreground/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),\n", 58 | "}\n", 59 | "bytes_description = {\n", 60 | " 'background/signal': {'dtype': tf.float32, 'shape': [40000]}, \n", 61 | " 'foreground/signal': {'dtype': tf.float32, 'shape': [40000]},\n", 62 | "}\n", 63 | "\n", 64 | "\n", 65 | "def parse_tfrecord(tfrecord):\n", 66 | " tfrecord = tf.parse_single_example(tfrecord, features=feature_description)\n", 67 | " for key in bytes_description.keys():\n", 68 | " tfrecord[key] = tf.decode_raw(tfrecord[key], bytes_description[key]['dtype'])\n", 69 | " tfrecord[key] = tf.reshape(tfrecord[key], bytes_description[key]['shape'])\n", 70 | " return tfrecord\n", 71 | "\n", 72 | "\n", 73 | "def preprocess_audio_batch(batch):\n", 74 | " \"\"\"\n", 75 | " Function combines foreground (speech) and background (noise) audio\n", 76 | " signals with signal-to-noise ratios drawn uniformly between -20 and\n", 77 | " +10 dB. 
The returned dictionary contains the noisy speech signal,\n", 78 | " the clean speech signal, and the SNR.\n", 79 | " \"\"\"\n", 80 | " foreground_signal = batch['foreground/signal']\n", 81 | " background_signal = batch['background/signal']\n", 82 | " snr = tf.random.uniform(\n", 83 | " [tf.shape(foreground_signal)[0], 1],\n", 84 | " minval=-20.0,\n", 85 | " maxval=10.0,\n", 86 | " dtype=foreground_signal.dtype)\n", 87 | " signal_in_noise, signal, noise_scaled = util_audio_preprocess.tf_set_snr(\n", 88 | " foreground_signal,\n", 89 | " background_signal,\n", 90 | " snr)\n", 91 | " batch = {\n", 92 | " 'snr': snr,\n", 93 | " 'waveform_noisy': signal_in_noise,\n", 94 | " 'waveform_clean': signal,\n", 95 | " }\n", 96 | " return batch\n", 97 | "\n", 98 | "\n", 99 | "tf.reset_default_graph()\n", 100 | "tf.random.set_random_seed(0)\n", 101 | "\n", 102 | "dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')\n", 103 | "dataset = dataset.map(parse_tfrecord)\n", 104 | "dataset = dataset.batch(batch_size)\n", 105 | "dataset = dataset.map(preprocess_audio_batch)\n", 106 | "dataset = dataset.prefetch(buffer_size=4)\n", 107 | "dataset = dataset.shuffle(buffer_size=32)\n", 108 | "dataset = dataset.repeat(count=None)\n", 109 | "\n", 110 | "iterator = dataset.make_one_shot_iterator()\n", 111 | "input_tensor_dict = iterator.get_next()\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 3, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "WARNING:tensorflow:From Wave-U-Net/UnetAudioSeparator.py:97: conv1d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", 124 | "Instructions for updating:\n", 125 | "Use keras.layers.conv1d instead.\n", 126 | "1 recognition networks included for deep feature loss:\n", 127 | "|__ arch1_taskA: models/recognition_networks/arch1_taskA.ckpt-550000\n", 128 | "Building waveform loss\n", 129 | "Building cochlear model loss\n", 130 | "[make_cos_filters_nx] using filter_spacing=`erb`\n", 131 | "[make_cos_filters_nx] using filter_spacing=`erb`\n", 132 | "Building deep feature loss (recognition network: arch1_taskA)\n", 133 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 134 | "Instructions for updating:\n", 135 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n", 136 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 137 | "Instructions for updating:\n", 138 | "Use tf.cast instead.\n", 139 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/signal/fft_ops.py:315: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 140 | "Instructions for updating:\n", 141 | "Use tf.cast instead.\n" 142 | ] 143 | }, 144 | { 145 | "name": "stderr", 146 | "output_type": "stream", 147 | "text": [ 148 | "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. 
This may consume a large amount of memory.\n", 149 | " \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n" 150 | ] 151 | } 152 | ], 153 | "source": [ 154 | "\"\"\"\n", 155 | "Model training graph\n", 156 | "\n", 157 | "Components:\n", 158 | "1. U-Net audio transform (`util_audio_transform.build_unet`)\n", 159 | "2. Auditory model loss function (`util_auditory_model_loss.AuditoryModelLoss`)\n", 160 | "3. Tensorflow optimizer to train U-Net weights\n", 161 | "\"\"\"\n", 162 | "\n", 163 | "### U-Net audio transform\n", 164 | "tensor_waveform_noisy = input_tensor_dict['waveform_noisy']\n", 165 | "tensor_waveform_clean = input_tensor_dict['waveform_clean']\n", 166 | "tensor_waveform_denoised = util_audio_transform.build_unet(tensor_waveform_noisy)\n", 167 | "\n", 168 | "### Build auditory model loss function (specify recognition networks\n", 169 | "### to include in the deep feature loss)\n", 170 | "list_recognition_networks = [\n", 171 | " 'arch1_taskA',\n", 172 | "# 'arch2_taskA',\n", 173 | "# 'arch3_taskA',\n", 174 | "]\n", 175 | "auditory_model = util_auditory_model_loss.AuditoryModelLoss(\n", 176 | " list_recognition_networks=list_recognition_networks,\n", 177 | " tensor_wave0=tensor_waveform_clean,\n", 178 | " tensor_wave1=tensor_waveform_denoised)\n", 179 | "\n", 180 | "transform_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='separator')\n", 181 | "transform_saver = tf.train.Saver(var_list=transform_var_list, max_to_keep=0)\n", 182 | "\n", 183 | "### Specify loss function (waveform, cochlear model, or deep features)\n", 184 | "### and build optimizer object + training operation\n", 185 | "loss = auditory_model.loss_cochlear_model # <-- cochlear model loss is lightweight and works well\n", 186 | "# loss = auditory_model.loss_deep_features\n", 187 | "# loss = auditory_model.loss_waveform\n", 188 | "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)\n", 189 | "train_op = optimizer.minimize(\n", 190 | " loss=loss,\n", 191 | " global_step=None,\n", 192 | " var_list=transform_var_list)\n" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": 4, 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "Loading `arch1_taskA` variables from models/recognition_networks/arch1_taskA.ckpt-550000\n", 205 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", 206 | "Instructions for updating:\n", 207 | "Use standard file APIs to check for files with this prefix.\n", 208 | "INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000\n", 209 | "INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000\n", 210 | "Loss after training step 000000 = 759.71\n", 211 | "Loss after training step 000010 = 594.60\n", 212 | "Loss after training step 000020 = 547.26\n", 213 | "Loss after training step 000030 = 454.63\n", 214 | "Loss after training step 000040 = 478.92\n", 215 | "Loss after training step 000050 = 450.72\n", 216 | "Loss after training step 000060 = 395.44\n", 217 | "Loss after training step 000070 = 427.17\n", 218 | "Loss after training step 000080 = 456.67\n", 219 | "Loss after training step 000090 = 438.67\n", 220 | "Loss after training step 000100 = 431.82\n", 221 | "INFO:tensorflow:new_model.ckpt-100 is not 
in all_model_checkpoint_paths. Manually adding it.\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "\"\"\"\n", 227 | "Simple training routine\n", 228 | "\n", 229 | "Models in the paper were all trained for 600000 steps\n", 230 | "with batch size 8 and learning rate 10e-4.\n", 231 | "\"\"\"\n", 232 | "\n", 233 | "with tf.Session() as sess:\n", 234 | " sess.run(tf.global_variables_initializer())\n", 235 | " auditory_model.load_auditory_model_vars(sess)\n", 236 | " \n", 237 | " for step in range(101):\n", 238 | " _, step_loss = sess.run([train_op, loss])\n", 239 | " if step % 10 == 0:\n", 240 | " print(\"Loss after training step {:06d} = {:.02f}\".format(step, step_loss.mean()))\n", 241 | " \n", 242 | " transform_saver.save(\n", 243 | " sess,\n", 244 | " save_path='new_model.ckpt',\n", 245 | " global_step=step,\n", 246 | " write_meta_graph=False)\n" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.5.2" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 2 278 | } 279 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Mark R. Saddler, Andrew Francl, Jenelle Feather 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Speech Denoising with Auditory Models ([arXiv](https://arxiv.org/abs/2011.10706), [audio examples](https://mcdermottlab.mit.edu/denoising/demo.html)) 3 | This is a TensorFlow implementation of our [Speech Denoising with Auditory Models](https://arxiv.org/abs/2011.10706). 
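For orientation, the sketch below is not part of the repository; it shows roughly how the trained U-Net audio transform might be applied to a noisy waveform with this code base, based on the usage in the DEMO notebooks. It assumes the audio transform checkpoints described under "Trained Models" have been downloaded; the checkpoint prefix passed to `saver.restore` is a placeholder, and the DEMO notebooks remain the authoritative usage reference.

```python
import numpy as np
import tensorflow as tf

import util_audio_transform  # repository module; appends Wave-U-Net to sys.path internally

# Build the U-Net graph on a batch of noisy waveforms
# (models were trained on 2-second clips at 20 kHz, i.e. 40000 samples).
tensor_waveform_noisy = tf.placeholder(tf.float32, shape=[None, 40000])
tensor_waveform_denoised = util_audio_transform.build_unet(tensor_waveform_noisy)

# Restore trained U-Net weights (the U-Net variables live in the `separator` scope).
saver = tf.train.Saver(
    var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='separator'))

with tf.Session() as sess:
    # Hypothetical checkpoint prefix: substitute a downloaded audio transform checkpoint.
    saver.restore(sess, 'models/audio_transforms/<checkpoint_prefix>')
    waveform_noisy = np.zeros([1, 40000], dtype=np.float32)  # replace with real audio
    waveform_denoised = sess.run(
        tensor_waveform_denoised,
        feed_dict={tensor_waveform_noisy: waveform_noisy})
```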
4 | 5 | Contact: [Mark Saddler](mailto:msaddler@mit.edu) or [Andrew Francl](mailto:francl@mit.edu) 6 | 7 | 8 | ### Citation 9 | If you use our code for research, please cite our paper: 10 | Mark R. Saddler\*, Andrew Francl\*, Jenelle Feather, Kaizhi Qian, Yang Zhang, Josh H. McDermott (2021). Speech Denoising with Auditory Models. *Proc. Interspeech* 2021, 2681-2685. [arXiv:2011.10706](https://arxiv.org/abs/2011.10706). 11 | 12 | 13 | ### License 14 | The source code is published under the MIT license. See [LICENSE](./LICENSE.md) for details. In general, you can use the code for any purpose with proper attribution. If you do something interesting with the code, we'll be happy to know. Feel free to contact us. 15 | 16 | 17 | ### Requirements 18 | In order to speed setup and aid reproducibility we provide a Singularity container. This container holds all the libraries and dependencies needed to run the code and allows you to work in the same environment as was originally used. Please see the [Singularity Documentation](https://sylabs.io/guides/3.8/user-guide/) for more details. Download Singularity image: [tensorflow-1.13.0-denoising.simg](https://drive.google.com/file/d/1KFGMJnuX4l1KRQRRVnzbjXE6bjHA7Tjm/view?usp=sharing). 19 | 20 | 21 | ### Trained Models 22 | We provide model checkpoints for all of our trained audio transforms and deep feature recognition networks. Users must download the audio transform checkpoints to evaluate our denoising algorithms on their own audio. Both sets of checkpoints must be downloaded to run our [DEMO Jupyter notebook](./DEMO.ipynb). Download the entire `auditory-model-denoising/models` directory [here](https://drive.google.com/drive/folders/1HmXSCVOKQCq7G62rs9KE_jsvVO0UqclC?usp=sharing): 23 | - Recognition network checkpoints: [auditory-model-denoising/models/recognition_networks](https://drive.google.com/file/d/1v9dKlRCnMP7X9v5IFcg4H0U1bXottgDo/view?usp=sharing) 24 | - Auditory transform checkpoints: [auditory-model-denoising/models/audio_transforms](https://drive.google.com/file/d/1L21NqxN-nVSzpY9CtjtH-1zlKPhEQkow/view?usp=sharing) 25 | 26 | 27 | ### Quick Start 28 | We provide a [Jupyter notebook](./DEMO.ipynb) that (1) demos how to run our trained denoising models and (2) provides examples of how to compute the auditory model losses used to train the models. A [second notebook](./DEMO_train_audio_transform.ipynb) demos how to train a new audio transform using the auditory model losses. 29 | -------------------------------------------------------------------------------- /Wave-U-Net/InterpolationLayer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def learned_interpolation_layer(input, padding, level): 5 | ''' 6 | Implements a trainable upsampling layer by interpolation by a factor of two, from N samples to N*2 - 1. 7 | Interpolation of intermediate feature vectors v_1 and v_2 (of dimensionality F) is performed by 8 | w \cdot v_1 + (1-w) \cdot v_2, where \cdot is point-wise multiplication, and w an F-dimensional weight vector constrained to [0,1] 9 | :param input: Input features of shape [batch_size, 1, width, F] 10 | :param padding: 11 | :param level: 12 | :return: 13 | ''' 14 | assert(padding == "valid" or padding == "same") 15 | features = input.get_shape().as_list()[3] 16 | 17 | # Construct 2FxF weight matrix, where F is the number of feature channels in the feature map. 18 | # Matrix is constrained, made up out of two diagonal FxF matrices with diagonal weights w and 1-w. 
w is constrained to be in [0,1] # mioid 19 | weights = tf.get_variable("interp_" + str(level), shape=[features], dtype=tf.float32) 20 | weights_scaled = tf.nn.sigmoid(weights) # Constrain weights to [0,1] 21 | counter_weights = 1.0 - weights_scaled # Mirrored weights for the features from the other time step 22 | conv_weights = tf.expand_dims(tf.concat([tf.expand_dims(tf.diag(weights_scaled), axis=0), tf.expand_dims(tf.diag(counter_weights), axis=0)], axis=0), axis=0) 23 | intermediate_vals = tf.nn.conv2d(input, conv_weights, strides=[1,1,1,1], padding=padding.upper()) 24 | 25 | intermediate_vals = tf.transpose(intermediate_vals, [2, 0, 1, 3]) 26 | out = tf.transpose(input, [2, 0, 1, 3]) 27 | num_entries = out.get_shape().as_list()[0] 28 | out = tf.concat([out, intermediate_vals], axis=0) 29 | indices = list() 30 | 31 | # Interleave interpolated features with original ones, starting with the first original one 32 | num_outputs = (2*num_entries - 1) if padding == "valid" else 2*num_entries 33 | for idx in range(num_outputs): 34 | if idx % 2 == 0: 35 | indices.append(idx // 2) 36 | else: 37 | indices.append(num_entries + idx//2) 38 | out = tf.gather(out, indices) 39 | current_layer = tf.transpose(out, [1, 2, 0, 3]) 40 | return current_layer -------------------------------------------------------------------------------- /Wave-U-Net/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Daniel Stoller 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Wave-U-Net/OutputLayer.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import unet_utils 4 | 5 | def independent_outputs(featuremap, source_names, num_channels, filter_width, padding, activation): 6 | outputs = dict() 7 | for name in source_names: 8 | outputs[name] = tf.layers.conv1d(featuremap, num_channels, filter_width, activation=activation, padding=padding) 9 | return outputs 10 | 11 | def difference_output(input_mix, featuremap, source_names, num_channels, filter_width, padding, activation, training): 12 | outputs = dict() 13 | sum_source = 0 14 | for name in source_names[:-1]: 15 | out = tf.layers.conv1d(featuremap, num_channels, filter_width, activation=activation, padding=padding) 16 | outputs[name] = out 17 | sum_source = sum_source + out 18 | 19 | # Compute last source based on the others 20 | last_source = unet_utils.crop(input_mix, sum_source.get_shape().as_list()) - sum_source 21 | last_source = unet_utils.AudioClip(last_source, training) 22 | outputs[source_names[-1]] = last_source 23 | return outputs 24 | -------------------------------------------------------------------------------- /Wave-U-Net/README.md: -------------------------------------------------------------------------------- 1 | # Wave-U-Net 2 | Implementation of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation. 3 | 4 | Initial github repository copied from https://github.com/f90/Wave-U-Net (architecture was modified for compatibility). 5 | -------------------------------------------------------------------------------- /Wave-U-Net/UnetAudioSeparator.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | import unet_utils 5 | import InterpolationLayer 6 | import OutputLayer 7 | 8 | class UnetAudioSeparator: 9 | ''' 10 | U-Net separator network for singing voice separation. 
11 | Uses valid convolutions, so it predicts for the centre part of the input - only certain input and output shapes are therefore possible (see getpadding function) 12 | ''' 13 | 14 | def __init__(self, model_config): 15 | ''' 16 | Initialize U-net 17 | :param num_layers: Number of down- and upscaling layers in the network 18 | ''' 19 | self.num_layers = model_config.get("num_layers",12) 20 | self.num_initial_filters = model_config.get("num_initial_filters",24) 21 | self.filter_size = model_config.get("filter_size",15) 22 | self.merge_filter_size = model_config.get("merge_filter_size",5) 23 | self.input_filter_size = model_config.get("input_filter_size",15) 24 | self.output_filter_size = model_config.get("output_filter_size",1) 25 | self.upsampling = model_config.get("upsampling", 'learned') 26 | self.output_type = model_config.get("output_type", 'direct') 27 | self.context = model_config.get("context", False) 28 | self.padding = "valid" if model_config.get("context", False) else "same" 29 | self.source_names = model_config.get("source_names", ['enhancement']) 30 | self.num_channels = 1 if model_config.get("mono_downmix", True) else 2 31 | self.output_activation = model_config.get("output_activation", 'identity') 32 | 33 | def get_padding(self, shape): 34 | ''' 35 | Calculates the required amounts of padding along each axis of the input and output, so that the Unet works and has the given shape as output shape 36 | :param shape: Desired output shape 37 | :return: Input_shape, output_shape, where each is a list [batch_size, time_steps, channels] 38 | ''' 39 | 40 | if self.context: 41 | # Check if desired shape is possible as output shape - go from output shape towards lowest-res feature map 42 | rem = float(shape[1]) # Cut off batch size number and channel 43 | 44 | # Output filter size 45 | rem = rem - self.output_filter_size + 1 46 | 47 | # Upsampling blocks 48 | for i in range(self.num_layers): 49 | rem = rem + self.merge_filter_size - 1 50 | rem = (rem + 1.) 
/ 2.# out = in + in - 1 <=> in = (out+1)/ 51 | 52 | # Round resulting feature map dimensions up to nearest integer 53 | x = np.asarray(np.ceil(rem),dtype=np.int64) 54 | assert(x >= 2) 55 | 56 | # Compute input and output shapes based on lowest-res feature map 57 | output_shape = x 58 | input_shape = x 59 | 60 | # Extra conv 61 | input_shape = input_shape + self.filter_size - 1 62 | 63 | # Go from centre feature map through up- and downsampling blocks 64 | for i in range(self.num_layers): 65 | output_shape = 2*output_shape - 1 #Upsampling 66 | output_shape = output_shape - self.merge_filter_size + 1 # Conv 67 | 68 | input_shape = 2*input_shape - 1 # Decimation 69 | if i < self.num_layers - 1: 70 | input_shape = input_shape + self.filter_size - 1 # Conv 71 | else: 72 | input_shape = input_shape + self.input_filter_size - 1 73 | 74 | # Output filters 75 | output_shape = output_shape - self.output_filter_size + 1 76 | 77 | input_shape = np.concatenate([[shape[0]], [input_shape], [self.num_channels]]) 78 | output_shape = np.concatenate([[shape[0]], [output_shape], [self.num_channels]]) 79 | 80 | return input_shape, output_shape 81 | else: 82 | return [shape[0], shape[1], self.num_channels], [shape[0], shape[1], self.num_channels] 83 | 84 | def get_output(self, input, training, return_spectrogram=False, reuse=True): 85 | ''' 86 | Creates symbolic computation graph of the U-Net for a given input batch 87 | :param input: Input batch of mixtures, 3D tensor [batch_size, num_samples, num_channels] 88 | :param reuse: Whether to create new parameter variables or reuse existing ones 89 | :return: U-Net output: List of source estimates. Each item is a 3D tensor [batch_size, num_out_samples, num_channels] 90 | ''' 91 | with tf.variable_scope("separator", reuse=reuse): 92 | enc_outputs = list() 93 | current_layer = input 94 | 95 | # Down-convolution: Repeat strided conv 96 | for i in range(self.num_layers): 97 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * i), self.filter_size, strides=1, activation=unet_utils.LeakyReLU, padding=self.padding) # out = in - filter + 1 98 | enc_outputs.append(current_layer) 99 | current_layer = current_layer[:,::2,:] # Decimate by factor of 2 # out = (in-1)/2 + 1 100 | 101 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * self.num_layers),self.filter_size,activation=unet_utils.LeakyReLU,padding=self.padding) # One more conv here since we need to compute features after last decimation 102 | 103 | # Feature map here shall be X along one dimension 104 | 105 | # Upconvolution 106 | for i in range(self.num_layers): 107 | #UPSAMPLING 108 | current_layer = tf.expand_dims(current_layer, axis=1) 109 | if self.upsampling == 'learned': 110 | # Learned interpolation between two neighbouring time positions by using a convolution filter of width 2, and inserting the responses in the middle of the two respective inputs 111 | current_layer = InterpolationLayer.learned_interpolation_layer(current_layer, self.padding, i) 112 | else: 113 | if self.context: 114 | current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2] * 2 - 1], align_corners=True) 115 | else: 116 | current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2]*2]) # out = in + in - 1 117 | current_layer = tf.squeeze(current_layer, axis=1) 118 | # UPSAMPLING FINISHED 119 | assert(enc_outputs[-i-1].get_shape().as_list()[1] == 
current_layer.get_shape().as_list()[1] or self.context) #No cropping should be necessary unless we are using context 120 | current_layer = unet_utils.crop_and_concat(enc_outputs[-i-1], current_layer, match_feature_dim=False) 121 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * (self.num_layers - i - 1)), self.merge_filter_size, 122 | activation=unet_utils.LeakyReLU, 123 | padding=self.padding) # out = in - filter + 1 124 | 125 | current_layer = unet_utils.crop_and_concat(input, current_layer, match_feature_dim=False) 126 | 127 | # Output layer 128 | # Determine output activation function 129 | if self.output_activation == "identity": 130 | out_activation = tf.identity 131 | elif self.output_activation == "tanh": 132 | out_activation = tf.tanh 133 | elif self.output_activation == "linear": 134 | out_activation = lambda x: unet_utils.AudioClip(x, training) 135 | else: 136 | raise NotImplementedError 137 | 138 | if self.output_type == "direct": 139 | return OutputLayer.independent_outputs(current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation) 140 | elif self.output_type == "difference": 141 | cropped_input = unet_utils.crop(input,current_layer.get_shape().as_list(), match_feature_dim=False) 142 | return OutputLayer.difference_output(cropped_input, current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation, training) 143 | else: 144 | raise NotImplementedError 145 | -------------------------------------------------------------------------------- /Wave-U-Net/unet_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import librosa 4 | 5 | def getTrainableVariables(tag=""): 6 | return [v for v in tf.trainable_variables() if tag in v.name] 7 | 8 | def getNumParams(tensors): 9 | return np.sum([np.prod(t.get_shape().as_list()) for t in tensors]) 10 | 11 | def crop_and_concat(x1,x2, match_feature_dim=True): 12 | ''' 13 | Copy-and-crop operation for two feature maps of different size. 14 | Crops the first input x1 equally along its borders so that its shape is equal to 15 | the shape of the second input x2, then concatenates them along the feature channel axis. 16 | :param x1: First input that is cropped and combined with the second input 17 | :param x2: Second input 18 | :return: Combined feature map 19 | ''' 20 | if x2 is None: 21 | return x1 22 | 23 | x1 = crop(x1,x2.get_shape().as_list(), match_feature_dim) 24 | return tf.concat([x1, x2], axis=2) 25 | 26 | def random_amplify(sample): 27 | ''' 28 | Randomly amplifies or attenuates the input signal 29 | :return: Amplified signal 30 | ''' 31 | for key, val in sample.items(): 32 | if key != "mix": 33 | sample[key] = tf.random_uniform([], 0.7, 1.0) * val 34 | 35 | sample["mix"] = tf.add_n([val for key, val in sample.items() if key != "mix"]) 36 | return sample 37 | 38 | def crop_sample(sample, crop_frames): 39 | for key, val in sample.items(): 40 | if key != "mix" and crop_frames > 0: 41 | sample[key] = val[crop_frames:-crop_frames,:] 42 | return sample 43 | 44 | def pad_freqs(tensor, target_shape): 45 | ''' 46 | Pads the frequency axis of a 4D tensor of shape [batch_size, freqs, timeframes, channels] or 2D tensor [freqs, timeframes] with zeros 47 | so that it reaches the target shape. If the number of frequencies to pad is uneven, the rows are appended at the end. 
48 | :param tensor: Input tensor to pad with zeros along the frequency axis 49 | :param target_shape: Shape of tensor after zero-padding 50 | :return: Padded tensor 51 | ''' 52 | target_freqs = (target_shape[1] if len(target_shape) == 4 else target_shape[0]) #TODO 53 | if isinstance(tensor, tf.Tensor): 54 | input_shape = tensor.get_shape().as_list() 55 | else: 56 | input_shape = tensor.shape 57 | 58 | if len(input_shape) == 2: 59 | input_freqs = input_shape[0] 60 | else: 61 | input_freqs = input_shape[1] 62 | 63 | diff = target_freqs - input_freqs 64 | if diff % 2 == 0: 65 | pad = [(diff/2, diff/2)] 66 | else: 67 | pad = [(diff//2, diff//2 + 1)] # Add extra frequency bin at the end 68 | 69 | if len(target_shape) == 2: 70 | pad = pad + [(0,0)] 71 | else: 72 | pad = [(0,0)] + pad + [(0,0), (0,0)] 73 | 74 | if isinstance(tensor, tf.Tensor): 75 | return tf.pad(tensor, pad, mode='constant', constant_values=0.0) 76 | else: 77 | return np.pad(tensor, pad, mode='constant', constant_values=0.0) 78 | 79 | def LeakyReLU(x, alpha=0.2): 80 | return tf.maximum(alpha*x, x) 81 | 82 | def AudioClip(x, training): 83 | ''' 84 | Simply returns the input if training is set to True, otherwise clips the input to [-1,1] 85 | :param x: Input tensor (coming from last layer of neural network) 86 | :param training: Whether model is in training (True) or testing mode (False) 87 | :return: Output tensor (potentially clipped) 88 | ''' 89 | if training: 90 | return x 91 | else: 92 | return tf.maximum(tf.minimum(x, 1.0), -1.0) 93 | 94 | def resample(audio, orig_sr, new_sr): 95 | return librosa.resample(audio.T, orig_sr, new_sr).T 96 | 97 | def load(path, sr=22050, mono=True, offset=0.0, duration=None, dtype=np.float32): 98 | # ALWAYS output (n_frames, n_channels) audio 99 | y, orig_sr = librosa.load(path, sr, mono, offset, duration, dtype) 100 | if len(y.shape) == 1: 101 | y = np.expand_dims(y, axis=0) 102 | return y.T, orig_sr 103 | 104 | def crop(tensor, target_shape, match_feature_dim=True): 105 | ''' 106 | Crops a 3D tensor [batch_size, width, channels] along the width axes to a target shape. 107 | Performs a centre crop. If the dimension difference is uneven, crop last dimensions first. 108 | :param tensor: 4D tensor [batch_size, width, height, channels] that should be cropped. 109 | :param target_shape: Target shape (4D tensor) that the tensor should be cropped to 110 | :return: Cropped tensor 111 | ''' 112 | shape = np.array(tensor.get_shape().as_list()) 113 | # Handles case when batch size is not defined during graph construction 114 | if shape[0] is None and target_shape[0] is None: 115 | shape[0] = -1 116 | target_shape[0] = -1 117 | diff = shape - np.array(target_shape) 118 | assert(diff[0] == 0 and (diff[2] == 0 or not match_feature_dim))# Only width axis can differ 119 | if (diff[1] % 2 != 0): 120 | print("WARNING: Cropping with uneven number of extra entries on one side") 121 | assert diff[1] >= 0 # Only positive difference allowed 122 | if diff[1] == 0: 123 | return tensor 124 | crop_start = diff // 2 125 | crop_end = diff - crop_start 126 | 127 | return tensor[:,crop_start[1]:-crop_end[1],:] 128 | 129 | def spectrogramToAudioFile(magnitude, fftWindowSize, hopSize, phaseIterations=10, phase=None, length=None): 130 | ''' 131 | Computes an audio signal from the given magnitude spectrogram, and optionally an initial phase. 132 | Griffin-Lim is executed to recover/refine the given the phase from the magnitude spectrogram. 
133 | :param magnitude: Magnitudes to be converted to audio 134 | :param fftWindowSize: Size of FFT window used to create magnitudes 135 | :param hopSize: Hop size in frames used to create magnitudes 136 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase 137 | :param phase: If given, starts ISTFT with this particular phase matrix 138 | :param length: If given, audio signal is clipped/padded to this number of frames 139 | :return: 140 | ''' 141 | if phase is not None: 142 | if phaseIterations > 0: 143 | # Refine audio given initial phase with a number of iterations 144 | return reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations, phase, length) 145 | # reconstructing the new complex matrix 146 | stftMatrix = magnitude * np.exp(phase * 1j) # magnitude * e^(j*phase) 147 | audio = librosa.istft(stftMatrix, hop_length=hopSize, length=length) 148 | else: 149 | audio = reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations) 150 | return audio 151 | 152 | def reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations=10, initPhase=None, length=None): 153 | ''' 154 | Griffin-Lim algorithm for reconstructing the phase for a given magnitude spectrogram, optionally with a given 155 | intial phase. 156 | :param magnitude: Magnitudes to be converted to audio 157 | :param fftWindowSize: Size of FFT window used to create magnitudes 158 | :param hopSize: Hop size in frames used to create magnitudes 159 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase 160 | :param initPhase: If given, starts reconstruction with this particular phase matrix 161 | :param length: If given, audio signal is clipped/padded to this number of frames 162 | :return: 163 | ''' 164 | for i in range(phaseIterations): 165 | if i == 0: 166 | if initPhase is None: 167 | reconstruction = np.random.random_sample(magnitude.shape) + 1j * (2 * np.pi * np.random.random_sample(magnitude.shape) - np.pi) 168 | else: 169 | reconstruction = np.exp(initPhase * 1j) # e^(j*phase), so that angle => phase 170 | else: 171 | reconstruction = librosa.stft(audio, fftWindowSize, hopSize) 172 | spectrum = magnitude * np.exp(1j * np.angle(reconstruction)) 173 | if i == phaseIterations - 1: 174 | audio = librosa.istft(spectrum, hopSize, length=length) 175 | else: 176 | audio = librosa.istft(spectrum, hopSize) 177 | return audio 178 | -------------------------------------------------------------------------------- /audio/ex01_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex01_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex02_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex02_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex03_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex03_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex04_unprocessed_input.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex04_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex05_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex05_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex06_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex06_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex07_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex07_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex08_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex08_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex09_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex09_unprocessed_input.wav -------------------------------------------------------------------------------- /audio/ex10_unprocessed_input.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex10_unprocessed_input.wav -------------------------------------------------------------------------------- /data/toy_dataset_0000.tfrecords: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/data/toy_dataset_0000.tfrecords -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/.gitkeep -------------------------------------------------------------------------------- /models/audio_transforms/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/audio_transforms/.gitkeep -------------------------------------------------------------------------------- /models/recognition_networks/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/recognition_networks/.gitkeep -------------------------------------------------------------------------------- /models/recognition_networks/arch1.json: -------------------------------------------------------------------------------- 1 | [{"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_data_layer"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_0", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 32, "kernel_size": [2, 42], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_0"}}, {"layer_type": "hpool", "args": {"name": "pool_0", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_0"}},{"layer_type": "tf.layers.conv2d", "args": {"name": "conv_1", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 64, "kernel_size": [2, 18], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_1"}}, {"layer_type": "hpool", "args": {"name": "pool_1", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_1"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_2", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 128, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_2"}}, {"layer_type": "hpool", "args": {"name": "pool_2", "padding": "VALID_TIME", "strides": [1, 4], "pool_size": [1, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_2"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_3", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 256, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_3"}}, {"layer_type": "hpool", "args": {"name": "pool_3", "padding": "VALID_TIME", "strides": [1, 4], "pool_size": [1, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_3"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_4", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [8, 8], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_4"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_4"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_5", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_5"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_5"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_6", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [8, 8], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_6"}}, {"layer_type": "hpool", "args": {"name": "pool_4", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_6"}}, {"layer_type": "tf.layers.flatten", "args": {"name": "flatten_end_conv"}}, 
{"layer_type": "tf.layers.dense", "args": {"name": "fc_intermediate", "activation": null, "units": 512}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_fc_intermediate"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_fc_intermediate"}}, {"layer_type": "tf.layers.dropout", "args": {"name": "dropout", "rate": 0.5}}, {"layer_type": "fc_top_classification", "args": {"activation": null, "name": "fc_top"}}] 2 | -------------------------------------------------------------------------------- /models/recognition_networks/arch2.json: -------------------------------------------------------------------------------- 1 | [{"args": {"strides": [1, 1], "kernel_size": [2, 32], "filters": 32, "dilation_rate": [1, 1], "name": "conv_0", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_0"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [12, 64], "sqrt_window": false, "strides": [3, 16], "name": "pool_0", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_0"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 16], "filters": 128, "dilation_rate": [1, 1], "name": "conv_1", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_1"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 16], "sqrt_window": false, "strides": [1, 4], "name": "pool_1", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_1"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [1, 16], "filters": 64, "dilation_rate": [1, 1], "name": "conv_2", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_2"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [8, 16], "sqrt_window": false, "strides": [2, 4], "name": "pool_2", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_2"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 8], "filters": 512, "dilation_rate": [1, 1], "name": "conv_3", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_3"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 8], "sqrt_window": false, "strides": [1, 2], "name": "pool_3", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_3"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [3, 8], "filters": 512, "dilation_rate": [1, 1], "name": "conv_4", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_4"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_4", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_4"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 3], "filters": 512, "dilation_rate": [1, 1], "name": "conv_5", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_5"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_5", "padding": "VALID_TIME"}, 
"layer_type": "hpool"}, {"args": {"name": "batch_norm_5"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 4], "filters": 512, "dilation_rate": [1, 1], "name": "conv_6", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_6"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 8], "sqrt_window": false, "strides": [1, 2], "name": "pool_6", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_6"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [1, 3], "filters": 512, "dilation_rate": [1, 1], "name": "conv_7", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_7"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_7", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_7"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"name": "flatten_end_conv"}, "layer_type": "tf.layers.flatten"}, {"args": {"name": "fc_intermediate", "activation": null, "units": 1024}, "layer_type": "tf.layers.dense"}, {"args": {"name": "lrelu_fc_intermediate"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"name": "batch_norm_fc_intermediate"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"name": "dropout", "rate": 0.5}, "layer_type": "tf.layers.dropout"}, {"args": {"name": "fc_top", "activation": null}, "layer_type": "fc_top_classification"}] -------------------------------------------------------------------------------- /models/recognition_networks/arch3.json: -------------------------------------------------------------------------------- 1 | [{"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [3, 64], "padding": "VALID_TIME", "filters": 32, "activation": null, "name": "conv_0", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_0"}}, {"layer_type": "hpool", "args": {"name": "pool_0", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 32], "strides": [1, 8]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_0"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [2, 16], "padding": "VALID_TIME", "filters": 32, "activation": null, "name": "conv_1", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_1"}}, {"layer_type": "hpool", "args": {"name": "pool_1", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 32], "strides": [1, 8]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_1"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 8], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_2", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_2"}}, {"layer_type": "hpool", "args": {"name": "pool_2", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [8, 16], "strides": [2, 4]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_2"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [2, 8], "padding": "VALID_TIME", "filters": 512, "activation": null, "name": "conv_3", "strides": [1, 1]}}, {"layer_type": 
"tf.nn.leaky_relu", "args": {"name": "lrelu_3"}}, {"layer_type": "hpool", "args": {"name": "pool_3", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_3"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [3, 2], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_4", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_4"}}, {"layer_type": "hpool", "args": {"name": "pool_4", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_4"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 2], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_5", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_5"}}, {"layer_type": "hpool", "args": {"name": "pool_5", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 8], "strides": [1, 2]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_5"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 4], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_6", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_6"}}, {"layer_type": "hpool", "args": {"name": "pool_6", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_6"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 3], "padding": "VALID_TIME", "filters": 128, "activation": null, "name": "conv_7", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_7"}}, {"layer_type": "hpool", "args": {"name": "pool_7", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_7"}}, {"layer_type": "tf.layers.flatten", "args": {"name": "flatten_end_conv"}}, {"layer_type": "tf.layers.dropout", "args": {"rate": 0.5, "name": "dropout"}}, {"layer_type": "fc_top_classification", "args": {"name": "fc_top", "activation": null}}] -------------------------------------------------------------------------------- /models/recognition_networks/deep_feature_loss_weights.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch1_taskA": { 3 | "batch_norm_0": 7.721951988869115e-07, 4 | "batch_norm_1": 2.180155442434395e-06, 5 | "batch_norm_2": 4.35400902509891e-06, 6 | "batch_norm_3": 9.663577649529144e-06, 7 | "batch_norm_4": 7.037957926197431e-06, 8 | "batch_norm_5": 8.04823966794683e-06, 9 | "batch_norm_6": 9.601534802363368e-05 10 | }, 11 | "arch1_taskR": { 12 | "batch_norm_0": 4.5313309082274834e-06, 13 | "batch_norm_1": 2.2921854969402775e-06, 14 | "batch_norm_2": 9.444441099113525e-07, 15 | "batch_norm_3": 4.3454074465209316e-07, 16 | "batch_norm_4": 4.744116803227874e-07, 17 | "batch_norm_5": 7.252991823836001e-07, 18 | "batch_norm_6": 7.048356282228414e-07 19 | }, 20 | "arch1_taskW": { 21 | "batch_norm_0": 8.148634005683363e-07, 22 | "batch_norm_1": 2.655124304882748e-06, 23 | 
"batch_norm_2": 5.720402614649923e-06, 24 | "batch_norm_3": 1.7785463446355866e-05, 25 | "batch_norm_4": 1.1564212308251202e-05, 26 | "batch_norm_5": 1.3680587713195843e-05, 27 | "batch_norm_6": 0.00012747735009257673 28 | }, 29 | "arch2_taskA": { 30 | "batch_norm_0": 1.2787083010970974e-05, 31 | "batch_norm_1": 3.872243390175745e-06, 32 | "batch_norm_2": 4.0916914026823114e-05, 33 | "batch_norm_3": 1.0983149329860448e-05, 34 | "batch_norm_4": 1.6411619372836008e-05, 35 | "batch_norm_5": 1.9605537700159232e-05, 36 | "batch_norm_6": 8.89673192136142e-05, 37 | "batch_norm_7": 0.00015303936142223106 38 | }, 39 | "arch2_taskR": { 40 | "batch_norm_0": 0.0009033044077854935, 41 | "batch_norm_1": 0.001985916145766335, 42 | "batch_norm_2": 0.04740769988747281, 43 | "batch_norm_3": 0.049138172254876385, 44 | "batch_norm_4": 0.09488066449971078, 45 | "batch_norm_5": 0.14831164137971978, 46 | "batch_norm_6": 0.8964458129175272, 47 | "batch_norm_7": 2.0977237380635794 48 | }, 49 | "arch2_taskW": { 50 | "batch_norm_0": 1.7039983238137272e-05, 51 | "batch_norm_1": 4.049126793732267e-06, 52 | "batch_norm_2": 5.5138894571540054e-05, 53 | "batch_norm_3": 1.5994307965068624e-05, 54 | "batch_norm_4": 2.3465571033206468e-05, 55 | "batch_norm_5": 2.6932619594160208e-05, 56 | "batch_norm_6": 9.091449999742952e-05, 57 | "batch_norm_7": 0.0001247474154084437 58 | }, 59 | "arch3_taskA": { 60 | "batch_norm_0": 9.367178190155468e-07, 61 | "batch_norm_1": 4.047365064236076e-06, 62 | "batch_norm_2": 3.221161721690842e-06, 63 | "batch_norm_3": 1.6302867318651998e-06, 64 | "batch_norm_4": 3.209753982180463e-06, 65 | "batch_norm_5": 7.648978449406465e-06, 66 | "batch_norm_6": 9.518763618955694e-06, 67 | "batch_norm_7": 0.0001503918748524232 68 | }, 69 | "arch3_taskR": { 70 | "batch_norm_0": 8.944401019583153e-05, 71 | "batch_norm_1": 0.0015361094702168968, 72 | "batch_norm_2": 0.005594106590352144, 73 | "batch_norm_3": 0.004969826705029876, 74 | "batch_norm_4": 0.012617730321076999, 75 | "batch_norm_5": 0.04534875913067663, 76 | "batch_norm_6": 0.0649332731409872, 77 | "batch_norm_7": 0.19110162724547713 78 | }, 79 | "arch3_taskW": { 80 | "batch_norm_0": 8.376852943496161e-07, 81 | "batch_norm_1": 4.130054862512673e-06, 82 | "batch_norm_2": 4.30953902070214e-06, 83 | "batch_norm_3": 2.510501905339445e-06, 84 | "batch_norm_4": 5.482680435228746e-06, 85 | "batch_norm_5": 1.2116750065352435e-05, 86 | "batch_norm_6": 1.3421245766398092e-05, 87 | "batch_norm_7": 0.0001430609287507201 88 | } 89 | } -------------------------------------------------------------------------------- /util_audio_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | 5 | 6 | def tf_demean(x, axis=1): 7 | ''' 8 | Helper function to mean-subtract tensor. 9 | 10 | Args 11 | ---- 12 | x (tensor): tensor to be mean-subtracted 13 | axis (int): kwarg for tf.reduce_mean (axis along which to compute mean) 14 | 15 | Returns 16 | ------- 17 | x_demean (tensor): mean-subtracted tensor 18 | ''' 19 | x_demean = tf.math.subtract(x, tf.reduce_mean(x, axis=1, keepdims=True)) 20 | return x_demean 21 | 22 | 23 | def tf_rms(x, axis=1, keepdims=True): 24 | ''' 25 | Helper function to compute RMS amplitude of a tensor. 
26 | 27 | Args 28 | ---- 29 | x (tensor): tensor for which RMS amplitude should be computed 30 | axis (int): kwarg for tf.reduce_mean (axis along which to compute mean) 31 | keepdims (bool): kwarg for tf.reduce_mean (specify if mean should keep collapsed dimension) 32 | 33 | Returns 34 | ------- 35 | rms_x (tensor): root-mean-square amplitude of x 36 | ''' 37 | rms_x = tf.sqrt(tf.reduce_mean(tf.math.square(x), axis=axis, keepdims=keepdims)) 38 | return rms_x 39 | 40 | 41 | def tf_set_snr(signal, noise, snr): 42 | ''' 43 | Helper function to combine signal and noise tensors with specified SNR. 44 | 45 | Args 46 | ---- 47 | signal (tensor): signal tensor 48 | noise (tensor): noise tensor 49 | snr (tensor): desired signal-to-noise ratio in dB 50 | 51 | Returns 52 | ------- 53 | signal_in_noise (tensor): equal to signal + noise_scaled 54 | signal (tensor): mean-subtracted version of input signal tensor 55 | noise_scaled (tensor): mean-subtracted and scaled version of input noise tensor 56 | 57 | Raises 58 | ------ 59 | InvalidArgumentError: Raised when rms(signal) == 0 or rms(noise) == 0. 60 | Occurs if noise or signal input are all zeros, which is incompatible with set_snr implementation. 61 | ''' 62 | # Mean-subtract the provided signal and noise 63 | signal = tf_demean(signal, axis=1) 64 | noise = tf_demean(noise, axis=1) 65 | # Compute RMS amplitudes of provided signal and noise 66 | rms_signal = tf_rms(signal, axis=1, keepdims=True) 67 | rms_noise = tf_rms(noise, axis=1, keepdims=True) 68 | # Ensure neither signal nor noise has an RMS amplitude of zero 69 | msg = 'The rms({:s}) == 0. Results from {:s} input values all equal to zero' 70 | tf.debugging.assert_none_equal(rms_signal, tf.zeros_like(rms_signal), 71 | message=msg.format('signal','signal')).mark_used() 72 | tf.debugging.assert_none_equal(rms_noise, tf.zeros_like(rms_noise), 73 | message=msg.format('noise','noise')).mark_used() 74 | # Convert snr from dB to desired ratio of RMS(signal) to RMS(noise) 75 | rms_ratio = tf.math.pow(10.0, snr / 20.0) 76 | # Re-scale RMS of the noise such that signal + noise will have desired SNR 77 | noise_scale_factor = tf.math.divide(rms_signal, tf.math.multiply(rms_noise, rms_ratio)) 78 | noise_scaled = tf.math.multiply(noise_scale_factor, noise) 79 | signal_in_noise = tf.math.add(signal, noise_scaled) 80 | return signal_in_noise, signal, noise_scaled 81 | 82 | 83 | def tf_set_dbspl(x, dbspl): 84 | ''' 85 | Helper function to scale tensor to a specified sound pressure level 86 | in dB re 20e-6 Pa (dB SPL). 87 | 88 | Args 89 | ---- 90 | x (tensor): tensor to be scaled 91 | dbspl (tensor): desired sound pressure level in dB re 20e-6 Pa 92 | 93 | Returns 94 | ------- 95 | x (tensor): mean-subtracted and scaled tensor 96 | scale_factor (tensor): constant x is multiplied by to set dB SPL 97 | ''' 98 | x = tf_demean(x, axis=1) 99 | rms_new = 20e-6 * tf.math.pow(10.0, dbspl / 20.0) 100 | rms_old = tf_rms(x, axis=1, keepdims=True) 101 | scale_factor = rms_new / rms_old 102 | x = tf.math.multiply(scale_factor, x) 103 | return x, scale_factor 104 | -------------------------------------------------------------------------------- /util_audio_transform.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | 5 | sys.path.append('Wave-U-Net') 6 | import UnetAudioSeparator 7 | 8 | 9 | def build_unet(tensor_waveform, signal_rate=20000, UNET_PARAMS={}): 10 | ''' 11 | This function builds the tensorflow graph for the Wave-U-Net. 
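The scaling arithmetic in util_audio_preprocess.tf_set_snr above can be checked offline. A minimal NumPy sketch (illustrative only, not code from this repository), using the 2-second, 20 kHz clip length from the training data:

import numpy as np

def set_snr_numpy(signal, noise, snr_db):
    # Mean-subtract both waveforms (axis 1 = time, as in tf_demean)
    signal = signal - signal.mean(axis=1, keepdims=True)
    noise = noise - noise.mean(axis=1, keepdims=True)
    rms_signal = np.sqrt(np.mean(signal ** 2, axis=1, keepdims=True))
    rms_noise = np.sqrt(np.mean(noise ** 2, axis=1, keepdims=True))
    rms_ratio = 10.0 ** (snr_db / 20.0)  # desired rms(signal) / rms(noise)
    noise_scaled = (rms_signal / (rms_noise * rms_ratio)) * noise
    return signal + noise_scaled, signal, noise_scaled

# Quick check: the realized SNR matches the requested value
rng = np.random.RandomState(0)
sig, noi = rng.randn(2, 40000), rng.randn(2, 40000)
mix, sig_d, noi_s = set_snr_numpy(sig, noi, snr_db=-5.0)
realized_snr = 20.0 * np.log10(
    np.sqrt(np.mean(sig_d ** 2, axis=1)) / np.sqrt(np.mean(noi_s ** 2, axis=1)))
print(realized_snr)  # approximately [-5. -5.]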
12 | 13 | Args 14 | ---- 15 | tensor_waveform (tensor): input audio waveform (shape: [batch, time]) 16 | signal_rate (int): sampling rate of input waveform (Hz) 17 | UNET_PARAMS (dict): U-net configuration parameters 18 | 19 | Returns 20 | ------- 21 | tensor_waveform_unet (tensor): U-net transformed waveform (shape: [batch, time]) 22 | ''' 23 | padding = UNET_PARAMS.get('padding', [[0,0], [480,480]]) 24 | tensor_waveform_zero_padded = tf.pad(tensor_waveform, padding) 25 | tensor_waveform_expanded = tf.expand_dims(tensor_waveform_zero_padded,axis=2) 26 | unet_audio_separator = UnetAudioSeparator.UnetAudioSeparator(UNET_PARAMS) 27 | unet_audio_separator_output = unet_audio_separator.get_output( 28 | tensor_waveform_expanded, 29 | training=True, 30 | return_spectrogram=False, 31 | reuse=tf.AUTO_REUSE) 32 | tensor_waveform_unet = unet_audio_separator_output["enhancement"] 33 | tensor_waveform_unet = tensor_waveform_unet[:, padding[1][0]:-padding[1][1],:] 34 | tensor_waveform_unet = tf.squeeze(tensor_waveform_unet, axis=2) 35 | return tensor_waveform_unet 36 | -------------------------------------------------------------------------------- /util_auditory_model_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import glob 4 | import json 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | import util_recognition_network 9 | import util_cochlear_model 10 | 11 | 12 | class AuditoryModelLoss(): 13 | def __init__(self, 14 | dir_recognition_networks='models/recognition_networks', 15 | list_recognition_networks=None, 16 | fn_weights='deep_feature_loss_weights.json', 17 | config_cochlear_model={}, 18 | tensor_wave0=None, 19 | tensor_wave1=None): 20 | """ 21 | The AuditoryModelLoss class creates an object to measure distances between 22 | auditory model representations of sounds. 
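A small sketch of the pad-then-crop bookkeeping in util_audio_transform.build_unet above, assuming the default padding of 480 samples per side: the crop removes exactly the zero padding that was added, so the transformed waveform keeps the input length.

pad = 480          # default per-side padding from UNET_PARAMS.get('padding', [[0, 0], [480, 480]])
input_len = 40000  # 2-second clip at 20 kHz
padded_len = input_len + 2 * pad   # length seen by the Wave-U-Net
output_len = padded_len - 2 * pad  # length after the [:, pad:-pad, :] crop
assert output_len == input_len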
23 | 24 | Args 25 | ---- 26 | dir_recognition_networks (str): directory containing recognition network architectures 27 | and model checkpoints 28 | list_recognition_networks (list or None): list of recognition network keys specifying 29 | the deep feature loss (if None, use all checkpoints in `dir_recognition_networks`) 30 | fn_weights (str): filename for deep feature loss weights, which weight the contribution 31 | of each recognition network layer to the total deep feature loss weight 32 | config_cochlear_model (dict): parameters for util_cochlear_model.build_cochlear_model 33 | tensor_wave0 (tensor with shape [batch, 40000]): reference waveform tensor (for training) 34 | tensor_wave1 (tensor with shape [batch, 40000]): denoised waveform tensor (for training) 35 | """ 36 | if not os.path.isabs(fn_weights): 37 | fn_weights = os.path.join(dir_recognition_networks, fn_weights) 38 | with open(fn_weights, 'r') as f_weights: 39 | deep_feature_loss_weights = json.load(f_weights) 40 | if list_recognition_networks is None: 41 | print(("`list_recognition_networks` not specified --> " 42 | "searching for all checkpoints in {}".format(dir_recognition_networks))) 43 | list_fn_ckpt = glob.glob(os.path.join(dir_recognition_networks, '*index')) 44 | list_fn_ckpt = [fn_ckpt.replace('.index', '') for fn_ckpt in list_fn_ckpt] 45 | else: 46 | list_fn_ckpt = [] 47 | for network_key in list_recognition_networks: 48 | tmp = glob.glob(os.path.join(dir_recognition_networks, '{}*index'.format(network_key))) 49 | msg = "Failed to find exactly 1 checkpoint for recognition network {}".format(network_key) 50 | assert len(tmp) == 1, msg 51 | list_fn_ckpt.append(tmp[0].replace('.index', '')) 52 | print("{} recognition networks included for deep feature loss:".format(len(list_fn_ckpt))) 53 | config_recognition_networks = {} 54 | for fn_ckpt in list_fn_ckpt: 55 | network_key = os.path.basename(fn_ckpt).split('.')[0] 56 | if 'taskA' in network_key: 57 | n_classes_dict = {"task_audioset": 517} 58 | else: 59 | n_classes_dict = {"task_word": 794} 60 | config_recognition_networks[network_key] = { 61 | 'fn_ckpt': fn_ckpt, 62 | 'fn_arch': fn_ckpt[:fn_ckpt.rfind('_task')] + '.json', 63 | 'n_classes_dict': n_classes_dict, 64 | 'weights': deep_feature_loss_weights[network_key], 65 | } 66 | print('|__ {}: {}'.format(network_key, fn_ckpt)) 67 | self.config_recognition_networks = config_recognition_networks 68 | self.config_cochlear_model = config_cochlear_model 69 | self.tensor_wave0 = tensor_wave0 70 | self.tensor_wave1 = tensor_wave1 71 | self.build_auditory_model() 72 | self.sess = None 73 | self.vars_loaded = False 74 | return 75 | 76 | 77 | def l1_distance(self, feature0, feature1): 78 | """ 79 | Computes L1 distance between two features (preserving axis 0 for batch) 80 | """ 81 | axis = np.arange(1, len(feature0.get_shape().as_list())) 82 | return tf.reduce_sum(tf.math.abs(feature0 - feature1), axis=axis) 83 | 84 | 85 | def build_auditory_model(self, dtype=tf.float32): 86 | """ 87 | Constructs the full auditory model and losses. This function builds two 88 | waveform placeholders and constructs two identical auditory models operating 89 | on the two placeholders. Waveform, cochlear model, and deep feature losses are 90 | computed by measuring distances between identical stages of the two auditory 91 | models. 
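The deep feature loss assembled in build_auditory_model below is a weighted sum of per-layer L1 distances, with one weight per batch_norm layer taken from deep_feature_loss_weights.json. A minimal NumPy sketch of that weighting (the feature arrays and the two rounded weights are illustrative stand-ins, not repository outputs):

import numpy as np

def l1_distance(feature0, feature1):
    # L1 distance summed over all axes except the batch axis (axis 0)
    axes = tuple(range(1, feature0.ndim))
    return np.sum(np.abs(feature0 - feature1), axis=axes)

# Example weights, rounded from the arch1_taskA entry of deep_feature_loss_weights.json
layer_weights = {'batch_norm_0': 7.72e-07, 'batch_norm_1': 2.18e-06}
rng = np.random.RandomState(0)
features0 = {k: rng.randn(8, 100, 50, 32) for k in layer_weights}  # activations, branch 0
features1 = {k: rng.randn(8, 100, 50, 32) for k in layer_weights}  # activations, branch 1

loss_deep_features = np.zeros(8)  # one loss value per batch element
for layer_key in sorted(layer_weights.keys()):
    loss_deep_features += layer_weights[layer_key] * l1_distance(
        features0[layer_key], features1[layer_key])
print(loss_deep_features.shape)  # (8,)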
92 | """ 93 | # Build placeholders for two waveforms (if needed) and compute waveform loss 94 | if (self.tensor_wave0 is None) or (self.tensor_wave1 is None): 95 | self.tensor_wave0 = tf.placeholder(dtype, [None, 40000]) 96 | self.tensor_wave1 = tf.placeholder(dtype, [None, 40000]) 97 | print('Building waveform loss') 98 | self.loss_waveform = self.l1_distance(self.tensor_wave0, self.tensor_wave1) 99 | # Build cochlear model for each waveform and compute cochlear model loss 100 | print('Building cochlear model loss') 101 | tensor_coch0, _ = util_cochlear_model.build_cochlear_model( 102 | self.tensor_wave0, 103 | **self.config_cochlear_model) 104 | tensor_coch1, _ = util_cochlear_model.build_cochlear_model( 105 | self.tensor_wave1, 106 | **self.config_cochlear_model) 107 | self.loss_cochlear_model = self.l1_distance(tensor_coch0, tensor_coch1) 108 | # Build network(s) for each waveform and compute deep feature losses 109 | self.loss_deep_features_dict = {} 110 | self.loss_deep_features = tf.zeros([], dtype=dtype) 111 | for network_key in sorted(self.config_recognition_networks.keys()): 112 | print('Building deep feature loss (recognition network: {})'.format(network_key)) 113 | with open(self.config_recognition_networks[network_key]['fn_arch'], 'r') as f: 114 | list_layer_dict = json.load(f) 115 | # Build network for stimulus 0 116 | with tf.variable_scope(network_key + '0') as scope: 117 | _, tensors_network0 = util_recognition_network.build_network( 118 | tensor_coch0, 119 | list_layer_dict, 120 | n_classes_dict=self.config_recognition_networks[network_key]['n_classes_dict']) 121 | var_list = { 122 | v.name.replace(scope.name + '/', '').replace(':0', ''): v 123 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name) 124 | } 125 | self.config_recognition_networks[network_key]['saver0'] = tf.train.Saver( 126 | var_list=var_list, 127 | max_to_keep=0) 128 | # Build network for stimulus 1 129 | with tf.variable_scope(network_key + '1') as scope: 130 | _, tensors_network1 = util_recognition_network.build_network( 131 | tensor_coch1, 132 | list_layer_dict, 133 | n_classes_dict=self.config_recognition_networks[network_key]['n_classes_dict']) 134 | var_list = { 135 | v.name.replace(scope.name + '/', '').replace(':0', ''): v 136 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name) 137 | } 138 | self.config_recognition_networks[network_key]['saver1'] = tf.train.Saver( 139 | var_list=var_list, 140 | max_to_keep=0) 141 | # Compute deep feature losses (weighted sum across layers) 142 | self.loss_deep_features_dict[network_key] = tf.zeros([], dtype=dtype) 143 | layer_weights = self.config_recognition_networks[network_key]['weights'] 144 | for layer_key in sorted(layer_weights.keys()): 145 | tmp = self.l1_distance(tensors_network0[layer_key], tensors_network1[layer_key]) 146 | self.loss_deep_features_dict[network_key] += layer_weights[layer_key] * tmp 147 | self.loss_deep_features += self.loss_deep_features_dict[network_key] 148 | 149 | 150 | def load_auditory_model_vars(self, sess): 151 | """ 152 | Loads the same variables into the two copies of each recognition network 153 | from recognition network checkpoints. 
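A usage sketch for the AuditoryModelLoss class (illustrative; it is not copied from the demo notebooks and assumes trained recognition-network checkpoints have been placed in models/recognition_networks, which are not bundled with this repository):

import numpy as np
import tensorflow as tf
import util_auditory_model_loss

auditory_model_loss = util_auditory_model_loss.AuditoryModelLoss(
    dir_recognition_networks='models/recognition_networks',
    list_recognition_networks=None)  # None --> use every checkpoint found in the directory

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    auditory_model_loss.load_auditory_model_vars(sess)
    y0 = np.random.randn(2, 40000).astype(np.float32)  # reference waveforms
    y1 = np.random.randn(2, 40000).astype(np.float32)  # denoised waveforms
    print(auditory_model_loss.waveform_loss(y0, y1))        # shape [2]
    print(auditory_model_loss.cochlear_model_loss(y0, y1))  # shape [2]
    print(auditory_model_loss.deep_feature_loss(y0, y1))    # shape [2]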
154 | """ 155 | self.sess = sess 156 | for network_key in sorted(self.config_recognition_networks.keys()): 157 | fn_ckpt = self.config_recognition_networks[network_key]['fn_ckpt'] 158 | saver0 = self.config_recognition_networks[network_key]['saver0'] 159 | saver1 = self.config_recognition_networks[network_key]['saver1'] 160 | print('Loading `{}` variables from {}'.format(network_key, fn_ckpt)) 161 | saver0.restore(self.sess, fn_ckpt) 162 | saver1.restore(self.sess, fn_ckpt) 163 | self.vars_loaded = True 164 | 165 | 166 | def waveform_loss(self, y0, y1): 167 | """ 168 | Method to compute waveform loss between two waveforms under 169 | the constructed auditory model. 170 | 171 | Args 172 | ---- 173 | y0 (np.ndarray): waveform with shape [batch, 40000] 174 | y1 (np.ndarray): waveform with shape [batch, 40000] 175 | 176 | Returns 177 | ------- 178 | (np.ndarray): L1 waveform loss with shape [batch] 179 | """ 180 | assert (self.sess is not None) and (not self.sess._closed) 181 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1} 182 | return self.sess.run(self.loss_waveform, feed_dict=feed_dict) 183 | 184 | 185 | def cochlear_model_loss(self, y0, y1): 186 | """ 187 | Method to compute cochlear model loss between two waveforms 188 | under the constructed auditory model. 189 | 190 | Args 191 | ---- 192 | y0 (np.ndarray): waveform with shape [batch, 40000] 193 | y1 (np.ndarray): waveform with shape [batch, 40000] 194 | 195 | Returns 196 | ------- 197 | (np.ndarray): L1 cochlear model loss with shape [batch] 198 | """ 199 | assert (self.sess is not None) and (not self.sess._closed) 200 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1} 201 | return self.sess.run(self.loss_cochlear_model, feed_dict=feed_dict) 202 | 203 | 204 | def deep_feature_loss(self, y0, y1): 205 | """ 206 | Method to compute deep feature loss between two waveforms 207 | under the constructed auditory model. 208 | 209 | Args 210 | ---- 211 | y0 (np.ndarray): waveform with shape [batch, 40000] 212 | y1 (np.ndarray): waveform with shape [batch, 40000] 213 | 214 | Returns 215 | ------- 216 | (np.ndarray): L1 deep feature loss with shape [batch] 217 | """ 218 | assert (self.sess is not None) and (not self.sess._closed) 219 | if not self.vars_loaded: 220 | print(("WARNING: `deep_feature_loss` called before loading vars")) 221 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1} 222 | return self.sess.run(self.loss_deep_features, feed_dict=feed_dict) 223 | -------------------------------------------------------------------------------- /util_cochlear_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is closely based on pycochleagram and tfcochleagram, 3 | which have been previously released: 4 | 5 | https://github.com/mcdermottLab/pycochleagram 6 | https://github.com/jenellefeather/tfcochleagram 7 | 8 | Minor modifications have been made here to provide a single script 9 | containing all functions needed to build the cochlear model used 10 | in this project. 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | 17 | import os 18 | import sys 19 | import warnings 20 | import functools 21 | import numpy as np 22 | import tensorflow as tf 23 | import scipy.signal as signal 24 | import matplotlib.pyplot as plt 25 | 26 | 27 | def freq2erb(freq_hz): 28 | """Converts Hz to human-defined ERBs, using the formula of Glasberg and Moore. 
29 | 30 | Args: 31 | freq_hz (array_like): frequency to use for ERB. 32 | 33 | Returns: 34 | ndarray: **n_erb** -- Human-defined ERB representation of input. 35 | """ 36 | return 9.265 * np.log(1 + freq_hz / (24.7 * 9.265)) 37 | 38 | 39 | def erb2freq(n_erb): 40 | """Converts human ERBs to Hz, using the formula of Glasberg and Moore. 41 | Args: 42 | n_erb (array_like): Human-defined ERB to convert to frequency. 43 | Returns: 44 | ndarray: **freq_hz** -- Frequency representation of input. 45 | """ 46 | return 24.7 * 9.265 * (np.exp(n_erb / 9.265) - 1) 47 | 48 | 49 | def get_freq_rand_conversions(xp, seed=0, minval=0.0, maxval=1.0): 50 | """Generates freq2rand and rand2freq conversion functions. 51 | 52 | Args: 53 | xp (array_like): xvals for freq2rand linear interpolation. 54 | seed (int): numpy seed to generate yvals for linear interpolation. 55 | minval (float): yvals for linear interpolation are scaled to [minval, maxval]. 56 | maxval (float): yvals for linear interpolation are scaled to [minval, maxval]. 57 | 58 | Returns: 59 | freq2rand (function): converts Hz to random frequency scale 60 | rand2freq (function): converts random frequency scale to Hz 61 | """ 62 | np.random.seed(seed) 63 | yp = np.cumsum(np.random.poisson(size=xp.shape)) 64 | yp = ((maxval - minval) * (yp - yp.min())) / (yp.max() - yp.min()) + minval 65 | freq2rand = lambda x : np.interp(x, xp, yp) 66 | rand2freq = lambda y : np.interp(y, yp, xp) 67 | return freq2rand, rand2freq 68 | 69 | 70 | def make_cosine_filter(freqs, l, h, convert_to_erb=True): 71 | """Generate a half-cosine filter. Represents one subband of the cochleagram. 72 | A half-cosine filter is created using the values of freqs that are within the 73 | interval [l, h]. The half-cosine filter is centered at the center of this 74 | interval, i.e., (h - l) / 2. Values outside the valid interval [l, h] are 75 | discarded. So, if freqs = [1, 2, 3, ... 10], l = 4.5, h = 8, the cosine filter 76 | will only be defined on the domain [5, 6, 7] and the returned output will only 77 | contain 3 elements. 78 | Args: 79 | freqs (array_like): Array containing the domain of the filter, in ERB space; 80 | see convert_to_erb parameter below.. A single half-cosine 81 | filter will be defined only on the valid section of these values; 82 | specifically, the values between cutoffs ``l`` and ``h``. A half-cosine filter 83 | centered at (h - l ) / 2 is created on the interval [l, h]. 84 | l (float): The lower cutoff of the half-cosine filter in ERB space; see 85 | convert_to_erb parameter below. 86 | h (float): The upper cutoff of the half-cosine filter in ERB space; see 87 | convert_to_erb parameter below. 88 | convert_to_erb (bool, default=True): If this is True, the values in 89 | input arguments ``freqs``, ``l``, and ``h`` will be transformed from Hz to ERB 90 | space before creating the half-cosine filter. If this is False, the 91 | input arguments are assumed to be in ERB space. 92 | Returns: 93 | ndarray: **half_cos_filter** -- A half-cosine filter defined using elements of 94 | freqs within [l, h]. 
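A quick numerical check of the two conversions above (assuming it is run with this module's functions in scope): freq2erb and erb2freq are inverses, so a round trip recovers the input frequencies.

import numpy as np

freqs_hz = np.array([20.0, 1000.0, 8000.0])  # the model's default low and high limits, plus 1 kHz
n_erb = freq2erb(freqs_hz)                   # Hz --> ERB-number scale
print(n_erb)
print(np.allclose(erb2freq(n_erb), freqs_hz))  # True: erb2freq inverts freq2erb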
95 | """ 96 | if convert_to_erb: 97 | freqs_erb = freq2erb(freqs) 98 | l_erb = freq2erb(l) 99 | h_erb = freq2erb(h) 100 | else: 101 | freqs_erb = freqs 102 | l_erb = l 103 | h_erb = h 104 | 105 | avg_in_erb = (l_erb + h_erb) / 2 # center of filter 106 | rnge_in_erb = h_erb - l_erb # width of filter 107 | # return np.cos((freq2erb(freqs[a_l_ind:a_h_ind+1]) - avg)/rnge * np.pi) # h_ind+1 to include endpoint 108 | # return np.cos((freqs_erb[(freqs_erb >= l_erb) & (freqs_erb <= h_erb)]- avg_in_erb) / rnge_in_erb * np.pi) # map cutoffs to -pi/2, pi/2 interval 109 | return np.cos((freqs_erb[(freqs_erb > l_erb) & (freqs_erb < h_erb)]- avg_in_erb) / rnge_in_erb * np.pi) # map cutoffs to -pi/2, pi/2 interval 110 | 111 | 112 | def make_full_filter_set(filts, signal_length=None): 113 | """Create the full set of filters by extending the filterbank to negative FFT 114 | frequencies. 115 | Args: 116 | filts (array_like): Array containing the cochlear filterbank in frequency space, 117 | i.e., the output of make_cos_filters_nx. Each row of ``filts`` is a 118 | single filter, with columns indexing frequency. 119 | signal_length (int, optional): Length of the signal to be filtered with this filterbank. 120 | This should be equal to filter length * 2 - 1, i.e., 2*filts.shape[1] - 1, and if 121 | signal_length is None, this value will be computed with the above formula. 122 | This parameter might be deprecated later. 123 | 124 | Returns: 125 | ndarray: **full_filter_set** -- Array containing the complete filterbank in 126 | frequency space. This output can be directly applied to the frequency 127 | representation of a signal. 128 | """ 129 | if signal_length is None: 130 | signal_length = 2 * filts.shape[1] - 1 131 | 132 | # note that filters are currently such that each ROW is a filter and COLUMN idxs freq 133 | if np.remainder(signal_length, 2) == 0: # even -- don't take the DC & don't double sample nyquist 134 | neg_filts = np.flipud(filts[1:filts.shape[0] - 1, :]) 135 | else: # odd -- don't take the DC 136 | neg_filts = np.flipud(filts[1:filts.shape[0], :]) 137 | fft_filts = np.vstack((filts, neg_filts)) 138 | # we need to switch representation to apply filters to fft of the signal, not sure why, but do it here 139 | return fft_filts.T 140 | 141 | 142 | def make_cos_filters_nx(signal_length, sr, n, low_lim, hi_lim, sample_factor, 143 | padding_size=None, full_filter=True, strict=True, 144 | bandwidth_scale_factor=1.0, include_lowpass=True, 145 | include_highpass=True, filter_spacing='erb'): 146 | """Create cosine filters, oversampled by a factor provided by "sample_factor" 147 | Args: 148 | signal_length (int): Length of signal to be filtered with the generated 149 | filterbank. The signal length determines the length of the filters. 150 | sr (int): Sampling rate associated with the signal waveform. 151 | n (int): Number of filters (subbands) to be generated with standard 152 | sampling (i.e., using a sampling factor of 1). Note, the actual number of 153 | filters in the generated filterbank depends on the sampling factor, and 154 | may optionally include lowpass and highpass filters that allow for 155 | perfect reconstruction of the input signal (the exact number of lowpass 156 | and highpass filters is determined by the sampling factor). 
The 157 | number of filters in the generated filterbank is given below: 158 | +---------------+---------------+-+------------+---+---------------------+ 159 | | sample factor | n_out |=| bandpass |\ +| highpass + lowpass | 160 | +===============+===============+=+============+===+=====================+ 161 | | 1 | n+2 |=| n |\ +| 1 + 1 | 162 | +---------------+---------------+-+------------+---+---------------------+ 163 | | 2 | 2*n+1+4 |=| 2*n+1 |\ +| 2 + 2 | 164 | +---------------+---------------+-+------------+---+---------------------+ 165 | | 4 | 4*n+3+8 |=| 4*n+3 |\ +| 4 + 4 | 166 | +---------------+---------------+-+------------+---+---------------------+ 167 | | s | s*(n+1)-1+2*s |=| s*(n+1)-1 |\ +| s + s | 168 | +---------------+---------------+-+------------+---+---------------------+ 169 | low_lim (int): Lower limit of frequency range. Filters will not be defined 170 | below this limit. 171 | hi_lim (int): Upper limit of frequency range. Filters will not be defined 172 | above this limit. 173 | sample_factor (int): Positive integer that determines how densely ERB function 174 | will be sampled to create bandpass filters. 1 represents standard sampling; 175 | adjacent bandpass filters will overlap by 50%. 2 represents 2x overcomplete sampling; 176 | adjacent bandpass filters will overlap by 75%. 4 represents 4x overcomplete sampling; 177 | adjacent bandpass filters will overlap by 87.5%. 178 | padding_size (int, optional): If None (default), the signal will not be padded 179 | before filtering. Otherwise, the filters will be created assuming the 180 | waveform signal will be padded to length padding_size*signal_length. 181 | full_filter (bool, default=True): If True (default), the complete filter that 182 | is ready to apply to the signal is returned. If False, only the first 183 | half of the filter is returned (likely positive terms of FFT). 184 | strict (bool, default=True): If True (default), will throw an error if 185 | sample_factor is not a power of two. This facilitates comparison across 186 | sample_factors. Also, if True, will throw an error if provided hi_lim 187 | is greater than the Nyquist rate. 188 | bandwidth_scale_factor (float, default=1.0): scales the bandpass filter bandwidths. 189 | bandwidth_scale_factor=2.0 means half-cosine filters will be twice as wide. 190 | Note that values < 1 will cause frequency gaps between the filters. 191 | bandwidth_scale_factor requires sample_factor=1, include_lowpass=False, include_highpass=False. 192 | include_lowpass (bool, default=True): if set to False, lowpass filter will be discarded. 193 | include_highpass (bool, default=True): if set to False, highpass filter will be discarded. 194 | filter_spacing (str, default='erb'): Specifies the type of reference spacing for the 195 | half-cosine filters. Options include 'erb' and 'linear'. 196 | Returns: 197 | tuple: 198 | A tuple containing the output: 199 | * **filts** (*array*)-- The filterbank consisting of filters have 200 | cosine-shaped frequency responses, with center frequencies equally 201 | spaced from low_lim to hi_lim on a scale specified by filter_spacing 202 | * **center_freqs** (*array*) -- center frequencies of filterbank in filts 203 | * **freqs** (*array*) -- freq vector in Hz, same frequency dimension as filts 204 | Raises: 205 | ValueError: Various value errors for bad choices of sample_factor or frequency 206 | limits; see description for strict parameter. 207 | UserWarning: Raises warning if cochlear filters exceed the Nyquist 208 | limit or go below 0. 
209 | NotImplementedError: Raises error if specified filter_spacing is not implemented 210 | """ 211 | 212 | # Specifiy the type of filter spacing, if using linear filters instead 213 | if filter_spacing == 'erb': 214 | _freq2ref = freq2erb 215 | _ref2freq = erb2freq 216 | elif filter_spacing == 'erb_r': 217 | _freq2ref = lambda x: freq2erb(hi_lim) - freq2erb(hi_lim - x) 218 | _ref2freq = lambda x: hi_lim - erb2freq(freq2erb(hi_lim) - x) 219 | elif (filter_spacing == 'lin') or (filter_spacing == 'linear'): 220 | _freq2ref = lambda x: x 221 | _ref2freq = lambda x: x 222 | elif 'random' in filter_spacing: 223 | _freq2ref, _ref2freq = get_freq_rand_conversions( 224 | np.linspace(low_lim, hi_lim, n), 225 | seed=int(filter_spacing.split('-')[1].replace('seed', '')), 226 | minval=freq2erb(low_lim), 227 | maxval=freq2erb(hi_lim)) 228 | else: 229 | raise NotImplementedError('unrecognized spacing mode: %s' % filter_spacing) 230 | print('[make_cos_filters_nx] using filter_spacing=`{}`'.format(filter_spacing)) 231 | 232 | if not bandwidth_scale_factor == 1.0: 233 | assert sample_factor == 1, "bandwidth_scale_factor only supports sample_factor=1" 234 | assert include_lowpass == False, "bandwidth_scale_factor only supports include_lowpass=False" 235 | assert include_highpass == False, "bandwidth_scale_factor only supports include_highpass=False" 236 | 237 | if not isinstance(sample_factor, int): 238 | raise ValueError('sample_factor must be an integer, not %s' % type(sample_factor)) 239 | if sample_factor <= 0: 240 | raise ValueError('sample_factor must be positive') 241 | 242 | if sample_factor != 1 and np.remainder(sample_factor, 2) != 0: 243 | msg = 'sample_factor odd, and will change filter widths. Use even sample factors for comparison.' 244 | if strict: 245 | raise ValueError(msg) 246 | else: 247 | warnings.warn(msg, RuntimeWarning, stacklevel=2) 248 | 249 | if padding_size is not None and padding_size >= 1: 250 | signal_length += padding_size 251 | 252 | if np.remainder(signal_length, 2) == 0: # even length 253 | n_freqs = signal_length // 2 # .0 does not include DC, likely the sampling grid 254 | max_freq = sr / 2 # go all the way to nyquist 255 | else: # odd length 256 | n_freqs = (signal_length - 1) // 2 # .0 257 | max_freq = sr * (signal_length - 1) / 2 / signal_length # just under nyquist 258 | 259 | # verify the high limit is allowed by the sampling rate 260 | if hi_lim > sr / 2: 261 | hi_lim = max_freq 262 | msg = 'input arg "hi_lim" exceeds nyquist limit for max frequency; ignore with "strict=False"' 263 | if strict: 264 | raise ValueError(msg) 265 | else: 266 | warnings.warn(msg, RuntimeWarning, stacklevel=2) 267 | 268 | # changing the sampling density without changing the filter locations 269 | # (and, thereby changing their widths) requires that a certain number of filters 270 | # be used. 271 | n_filters = sample_factor * (n + 1) - 1 272 | n_lp_hp = 2 * sample_factor 273 | freqs = np.linspace(0, max_freq, n_freqs + 1) 274 | filts = np.zeros((n_freqs + 1, n_filters + n_lp_hp)) 275 | 276 | # cutoffs are evenly spaced on the scale specified by filter_spacing; for ERB scale, 277 | # interpolate linearly in erb space then convert back. 
278 | # Also return the actual spacing used to generate the sequence (in case numpy does 279 | # something weird) 280 | center_freqs, step_spacing = np.linspace(_freq2ref(low_lim), _freq2ref(hi_lim), n_filters + 2, retstep=True) # +2 for bin endpoints 281 | # we need to exclude the endpoints 282 | center_freqs = center_freqs[1:-1] 283 | 284 | freqs_ref = _freq2ref(freqs) 285 | for i in range(n_filters): 286 | i_offset = i + sample_factor 287 | l = center_freqs[i] - sample_factor * bandwidth_scale_factor * step_spacing 288 | h = center_freqs[i] + sample_factor * bandwidth_scale_factor * step_spacing 289 | if _ref2freq(h) > sr/2: 290 | cf = _ref2freq(center_freqs[i]) 291 | msg = "High ERB cutoff of filter with cf={:.2f}Hz exceeds {:.2f}Hz (Nyquist frequency)" 292 | warnings.warn(msg.format(cf, sr/2)) 293 | if _ref2freq(l) < 0: 294 | cf = _ref2freq(center_freqs[i]) 295 | msg = 'Low ERB cutoff of filter with cf={:.2f}Hz is not strictly positive' 296 | warnings.warn(msg.format(cf)) 297 | # the first sample_factor # of rows in filts will be lowpass filters 298 | filts[(freqs_ref > l) & (freqs_ref < h), i_offset] = make_cosine_filter(freqs_ref, l, h, convert_to_erb=False) 299 | 300 | # add lowpass and highpass filters (there will be sample_factor number of each) 301 | for i in range(sample_factor): 302 | # account for the fact that the first sample_factor # of filts are lowpass 303 | i_offset = i + sample_factor 304 | lp_h_ind = max(np.where(freqs < _ref2freq(center_freqs[i]))[0]) # lowpass filter goes up to peak of first cos filter 305 | lp_filt = np.sqrt(1 - np.power(filts[:lp_h_ind+1, i_offset], 2)) 306 | 307 | hp_l_ind = min(np.where(freqs > _ref2freq(center_freqs[-1-i]))[0]) # highpass filter goes down to peak of last cos filter 308 | hp_filt = np.sqrt(1 - np.power(filts[hp_l_ind:, -1-i_offset], 2)) 309 | 310 | filts[:lp_h_ind+1, i] = lp_filt 311 | filts[hp_l_ind:, -1-i] = hp_filt 312 | 313 | # get center freqs for lowpass and highpass filters 314 | cfs_low = np.copy(center_freqs[:sample_factor]) - sample_factor * step_spacing 315 | cfs_hi = np.copy(center_freqs[-sample_factor:]) + sample_factor * step_spacing 316 | center_freqs = np.concatenate((cfs_low, center_freqs, cfs_hi)) 317 | 318 | # ensure that squared freq response adds to one 319 | filts = filts / np.sqrt(sample_factor) 320 | 321 | # convert center freqs from ERB numbers to Hz 322 | center_freqs = _ref2freq(center_freqs) 323 | 324 | # rectify 325 | center_freqs[center_freqs < 0] = 1 326 | 327 | # discard highpass and lowpass filters, if requested 328 | if include_lowpass == False: 329 | filts = filts[:, sample_factor:] 330 | center_freqs = center_freqs[sample_factor:] 331 | if include_highpass == False: 332 | filts = filts[:, :-sample_factor] 333 | center_freqs = center_freqs[:-sample_factor] 334 | 335 | # make the full filter by adding negative components 336 | if full_filter: 337 | filts = make_full_filter_set(filts, signal_length) 338 | 339 | return filts, center_freqs, freqs 340 | 341 | 342 | def tflog10(x): 343 | """Implements log base 10 in tensorflow """ 344 | numerator = tf.log(x) 345 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) 346 | return numerator / denominator 347 | 348 | 349 | @tf.custom_gradient 350 | def stable_power_compression_norm_grad(x): 351 | """With this power compression function, the gradients from the power compression are not applied via backprop, we just pass the previous gradient onwards""" 352 | e = tf.nn.relu(x) # add relu to x to avoid NaN in loss 353 | p = tf.pow(e,0.3) 354 | def 
grad(dy): #try to check for nans before we clip the gradients. (use tf.where) 355 | return dy 356 | return p, grad 357 | 358 | 359 | @tf.custom_gradient 360 | def stable_power_compression(x): 361 | """Clip the gradients for the power compression and remove nans. Clipped values are (-1,1), so any cochleagram value below ~0.2 will be clipped.""" 362 | e = tf.nn.relu(x) # add relu to x to avoid NaN in loss 363 | p = tf.pow(e,0.3) 364 | def grad(dy): #try to check for nans before we clip the gradients. (use tf.where) 365 | g = 0.3 * pow(e,-0.7) 366 | is_nan_values = tf.is_nan(g) 367 | replace_nan_values = tf.ones(tf.shape(g), dtype=tf.float32)*1 368 | return dy * tf.where(is_nan_values,replace_nan_values,tf.clip_by_value(g, -1, 1)) 369 | return p, grad 370 | 371 | 372 | def cochleagram_graph(nets, SIGNAL_SIZE, SR, ENV_SR=200, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, compression='none', WINDOW_SIZE=1001, debug=False, subbands_ifft=False, pycoch_downsamp=False, linear_max=796.87416837456942, input_node='input_signal', mean_subtract=False, rms_normalize=False, SMOOTH_ABS = False, return_subbands_only=False, include_all_keys=False, rectify_and_lowpass_subbands=False, pad_factor=None, return_coch_params=False, rFFT=False, linear_params=None, custom_filts=None, custom_compression_op=None, erb_filter_kwargs={}, reshape_kell2018=False, include_subbands_noise=False, subbands_noise_mean=0., subbands_noise_stddev=0., rate_level_kwargs={}, preprocess_kwargs={}): 373 | """ 374 | Creates a tensorflow cochleagram graph using the pycochleagram erb filters to create the cochleagram with the tensorflow functions. 375 | Parameters 376 | ---------- 377 | nets : dictionary 378 | dictionary containing parts of the cochleagram graph. At a minumum, nets['input_signal'] (or equivilant) should be defined containing a placeholder (if just constructing cochleagrams) or a variable (if optimizing over the cochleagrams), and can have a batch size>1. 379 | SIGNAL_SIZE : int 380 | the length of the audio signal used for the cochleagram graph 381 | SR : int 382 | raw sampling rate in Hz for the audio. 383 | ENV_SR : int 384 | the sampling rate for the cochleagram after downsampling 385 | LOW_LIM : int 386 | Lower frequency limits for the filters. 387 | HIGH_LIM : int 388 | Higher frequency limits for the filters. 389 | N : int 390 | Number of filters to uniquely span the frequency space 391 | SAMPLE_FACTOR : int 392 | number of times to overcomplete the filters. 393 | compression : string. see include_compression for compression options 394 | determine compression type to use in the cochleagram graph. If return_subbands is true, compress the rectified subbands 395 | WINDOW_SIZE : int 396 | the size of a window to use for the downsampling filter 397 | debug : boolean 398 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True (default False). 399 | subbands_ifft : boolean 400 | If true, adds the ifft of the subbands to nets 401 | input_node : string 402 | Name of the top level of nets, this is the input into the cochleagram graph. 403 | mean_subtract : boolean 404 | If true, subtracts the mean of the waveform (explicitly removes the DC offset) 405 | rms_normalize : Boolean # ONLY USE WHEN GENERATING COCHLEAGRAMS 406 | If true, divides the input signal by its RMS value, such that the RMS value of the sound going into the cochleagram generation is equal to 1. 
This option should be false if inverting cochleagrams, as it can cause problems with the gradients 407 | linear_max : float 408 | If default value, use 796.87416837456942, which is the 5th percentile from the speech dataset when it is rms normalized to a value of 1. This value is only used if the compression is 'linearbelow1', 'linearbelow1sqrt', 'stable_point3' 409 | SMOOTH_ABS : Boolean 410 | If True, uses a smoother version of the absolute value for the hilbert transform sqrt(10^-3 + real(env) + imag(env)) 411 | return_subbands_only : Boolean 412 | If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph 413 | include_all_keys : Boolean 414 | If True, returns all of the cochleagram and subbands processing keys in the dictionary 415 | rectify_and_lowpass_subbands : Boolean 416 | If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only) 417 | pad_factor : int 418 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length) 419 | return_coch_params : Boolean 420 | If True, returns the cochleagram generation parameters in addition to nets 421 | rFFT : Boolean 422 | If True, builds the graph using rFFT and irFFT operations whenever possible 423 | linear_params : list of floats 424 | used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel. 425 | custom_filts : None, or numpy array 426 | if not None, a numpy array containing the filters to use for the cochleagram generation. If none, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain th full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS] 427 | custom_compression_op : None or tensorflow partial function 428 | if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments 429 | erb_filter_kwargs : dictionary 430 | contains additional arguments with filter parameters to use with erb.make_erb_cos_filters 431 | reshape_kell2018 : boolean (False) 432 | if true, reshapes the output cochleagram to be 256x256 as used by kell2018 433 | include_subbands_noise : boolean (False) 434 | if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True) 435 | subbands_noise_mean : float 436 | sets mean of subbands white noise if include_subbands_noise == True 437 | subbands_noise_stddev : float 438 | sets standard deviation of subbands white noise if include_subbands_noise == True 439 | rate_level_kwargs : dictionary 440 | contains keyword arguments for AN_rate_level_function (used if compression == 'rate_level') 441 | preprocess_kwargs : dictionary 442 | contains keyword arguments for preprocess_input function (used to randomize input dB SPL) 443 | 444 | Returns 445 | ------- 446 | nets : dictionary 447 | a dictionary containing the parts of the cochleagram graph. 
Top node in this graph is nets['output_tfcoch_graph'] 448 | COCH_PARAMS : dictionary (Optional) 449 | a dictionary containing all of the input parameters into the function 450 | """ 451 | if return_coch_params: 452 | COCH_PARAMS = locals() 453 | COCH_PARAMS.pop('nets') 454 | 455 | # run preprocessing operations on the input (ie rms normalization, convert to complex) 456 | nets = preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT, **preprocess_kwargs) 457 | 458 | # fft of the input 459 | nets = fft_of_input(nets, pad_factor,debug, rFFT) 460 | 461 | # Make a wrapper for the compression function so it can be applied to the cochleagram and the subbands 462 | compression_function = functools.partial(include_compression, compression=compression, linear_max=linear_max, linear_params=linear_params, rate_level_kwargs=rate_level_kwargs, custom_compression_op=custom_compression_op) 463 | 464 | # make cochlear filters and compute the cochlear subbands 465 | nets = extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev) 466 | 467 | # Build the rest of the graph for the downsampled cochleagram, if we are returning the cochleagram or if we want to build the whole graph anyway. 468 | if (not return_subbands_only) or include_all_keys: 469 | # hilbert transform on subband fft 470 | nets = hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT) 471 | 472 | # absolute value of the envelopes (and expand to one channel) 473 | nets = abs_envelopes(nets, SMOOTH_ABS) 474 | 475 | # downsample and rectified nonlinearity 476 | nets = downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp) 477 | 478 | # compress cochleagram 479 | nets = compression_function(nets, input_node_name='cochleagram_no_compression', output_node_name='cochleagram') 480 | 481 | if reshape_kell2018: 482 | nets, output_node_name_coch = reshape_coch_kell_2018(nets) 483 | else: 484 | output_node_name_coch = 'cochleagram' 485 | 486 | if return_subbands_only: 487 | nets['output_tfcoch_graph'] = nets['subbands_time_processed'] 488 | else: 489 | nets['output_tfcoch_graph'] = nets[output_node_name_coch] 490 | 491 | # return 492 | if return_coch_params: 493 | return nets, COCH_PARAMS 494 | else: 495 | return nets 496 | 497 | 498 | def preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT, 499 | set_dBSPL=False, dBSPL_range=[60., 60.]): 500 | """ 501 | Does preprocessing on the input (rms and converting to complex number) 502 | Parameters 503 | ---------- 504 | nets : dictionary 505 | dictionary containing parts of the cochleagram graph. should already contain input_node 506 | input_node : string 507 | Name of the top level of nets, this is the input into the cochleagram graph. 
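A usage sketch for cochleagram_graph above (illustrative, not taken from the repository's demo notebooks); it assumes the helper functions defined elsewhere in this module (make_filts_tensor, make_downsample_filt_tensor, and the compression ops) are in scope, as they are when util_cochlear_model is imported. With the defaults N=40 and SAMPLE_FACTOR=4, the filterbank contains 4*40+3 = 163 bandpass filters plus 4 lowpass and 4 highpass filters:

import tensorflow as tf

signal_size, sr = 40000, 20000  # 2-second clips at 20 kHz
nets = {'input_signal': tf.placeholder(tf.float32, [None, signal_size])}
nets = cochleagram_graph(
    nets,
    signal_size,
    sr,
    rFFT=True,                    # build the graph with rFFT ops where possible
    compression='stable_point3')  # one of the options handled by include_compression
tensor_cochleagram = nets['output_tfcoch_graph']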
508 | mean_subtract : boolean 509 | If true, subtracts the mean of the waveform (explicitly removes the DC offset) 510 | rms_normalize : Boolean # TODO: incorporate stable gradient code for RMS 511 | If true, divides the input signal by its RMS value, such that the RMS value of the sound going 512 | rFFT : Boolean 513 | If true, preprocess input for using the rFFT operations 514 | set_dBSPL : Boolean 515 | If true, re-scale input waveform to dB SPL sampled uniformly from dBSPL_range 516 | dBSPL_range : list 517 | Range of sound presentation levels in units of dB re 20e-6 Pa ([minval, maxval]) 518 | Returns 519 | ------- 520 | nets : dictionary 521 | updated dictionary containing parts of the cochleagram graph. 522 | """ 523 | 524 | if rFFT: 525 | if SIGNAL_SIZE%2!=0: 526 | print('rFFT is only tested with even length signals. Change your input length.') 527 | return 528 | 529 | processed_input_node = input_node 530 | 531 | if mean_subtract: 532 | processed_input_node = processed_input_node + '_mean_subtract' 533 | nets[processed_input_node] = nets[input_node] - tf.reshape(tf.reduce_mean(nets[input_node],1),(-1,1)) 534 | input_node = processed_input_node 535 | 536 | if rms_normalize: # TODO: incoporate stable RMS normalization 537 | processed_input_node = processed_input_node + '_rms_normalized' 538 | nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), 1)) 539 | nets[processed_input_node] = tf.identity(nets[input_node]/tf.reshape(nets['rms_input'],(-1,1)),'rms_normalized_input') 540 | input_node = processed_input_node 541 | 542 | if set_dBSPL: # NOTE: unstable if RMS of input is zero 543 | processed_input_node = processed_input_node + '_set_dBSPL' 544 | assert rms_normalize == False, "rms_normalize must be False if set_dBSPL=True" 545 | assert len(dBSPL_range) == 2, "dBSPL_range must be specified as [minval, maxval]" 546 | nets['dBSPL_set'] = tf.random.uniform([tf.shape(nets[input_node])[0], 1], 547 | minval=dBSPL_range[0], maxval=dBSPL_range[1], 548 | dtype=nets[input_node].dtype, name='sample_dBSPL_set') 549 | nets['rms_set'] = 20e-6 * tf.math.pow(10., nets['dBSPL_set'] / 20.) 550 | nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), axis=1, keepdims=True)) 551 | nets[processed_input_node] = tf.math.multiply(nets['rms_set'] / nets['rms_input'], nets[input_node], 552 | name='scale_input_to_dBSPL_set') 553 | input_node = processed_input_node 554 | 555 | if not rFFT: 556 | nets['input_signal_i'] = nets[input_node]*0.0 557 | nets['input_signal_complex'] = tf.complex(nets[input_node], nets['input_signal_i'], name='input_complex') 558 | else: 559 | nets['input_real'] = nets[input_node] 560 | return nets 561 | 562 | 563 | def fft_of_input(nets, pad_factor, debug, rFFT): 564 | """ 565 | Computs the fft of the signal and adds appropriate padding 566 | 567 | Parameters 568 | ---------- 569 | nets : dictionary 570 | dictionary containing parts of the cochleagram graph. 'subbands' are used for the hilbert transform 571 | pad_factor : int 572 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length) 573 | debug : boolean 574 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True. 
575 | rFFT : Boolean 576 | If true, cochleagram graph is constructed using rFFT wherever possible 577 | Returns 578 | ------- 579 | nets : dictionary 580 | updated dictionary containing parts of the cochleagram graph with the rFFT of the input 581 | """ 582 | # fft of the input 583 | if not rFFT: 584 | if pad_factor is not None: 585 | nets['input_signal_complex'] = tf.concat([nets['input_signal_complex'], tf.zeros([nets['input_signal_complex'].get_shape()[0], nets['input_signal_complex'].get_shape()[1]*(pad_factor-1)], dtype=tf.complex64)], axis=1) 586 | nets['fft_input'] = tf.fft(nets['input_signal_complex'],name='fft_of_input') 587 | else: 588 | nets['fft_input'] = tf.spectral.rfft(nets['input_real'],name='fft_of_input') # Since the DFT of a real signal is Hermitian-symmetric, RFFT only returns the fft_length / 2 + 1 unique components of the FFT: the zero-frequency term, followed by the fft_length / 2 positive-frequency terms. 589 | 590 | nets['fft_input'] = tf.expand_dims(nets['fft_input'], 1, name='exd_fft_of_input') 591 | 592 | if debug: # return the real and imaginary parts of the fft separately 593 | nets['fft_input_r'] = tf.real(nets['fft_input']) 594 | nets['fft_input_i'] = tf.imag(nets['fft_input']) 595 | return nets 596 | 597 | 598 | def extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev): 599 | """ 600 | Computes the cochlear subbands from the fft of the input signal 601 | Parameters 602 | ---------- 603 | nets : dictionary 604 | dictionary containing parts of the cochleagram graph. 'fft_input' is multiplied by the cochlear filters 605 | SIGNAL_SIZE : int 606 | the length of the audio signal used for the cochleagram graph 607 | SR : int 608 | raw sampling rate in Hz for the audio. 609 | LOW_LIM : int 610 | Lower frequency limits for the filters. 611 | HIGH_LIM : int 612 | Higher frequency limits for the filters. 613 | N : int 614 | Number of filters to uniquely span the frequency space 615 | SAMPLE_FACTOR : int 616 | number of times to overcomplete the filters. 617 | N : int 618 | Number of filters to uniquely span the frequency space 619 | SAMPLE_FACTOR : int 620 | number of times to overcomplete the filters. 621 | pad_factor : int 622 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length) 623 | debug : boolean 624 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal 625 | subbands_ifft : boolean 626 | If true, adds the ifft of the subbands to nets 627 | return_subbands_only : Boolean 628 | If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph 629 | rectify_and_lowpass_subbands : Boolean 630 | If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only) 631 | rFFT : Boolean 632 | If true, cochleagram graph is constructed using rFFT wherever possible 633 | custom_filts : None, or numpy array 634 | if not None, a numpy array containing the filters to use for the cochleagram generation. If none, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. 
If using rFFT, should contain th full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS] 635 | erb_filter_kwargs : dictionary 636 | contains additional arguments with filter parameters to use with erb.make_erb_cos_filters 637 | include_all_keys : Boolean 638 | If True, includes the time subbands and the cochleagram in the dictionary keys 639 | compression_function : function 640 | A partial function that takes in nets and the input and output names to apply compression 641 | include_subbands_noise : boolean (False) 642 | if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True) 643 | subbands_noise_mean : float 644 | sets mean of subbands white noise if include_subbands_noise == True 645 | subbands_noise_stddev : float 646 | sets standard deviation of subbands white noise if include_subbands_noise == True 647 | Returns 648 | ------- 649 | nets : dictionary 650 | updated dictionary containing parts of the cochleagram graph. 651 | """ 652 | 653 | # make the erb filters tensor 654 | nets['filts_tensor'] = make_filts_tensor(SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, use_rFFT=rFFT, pad_factor=pad_factor, custom_filts=custom_filts, erb_filter_kwargs=erb_filter_kwargs) 655 | 656 | # make subbands by multiplying filts with fft of input 657 | nets['subbands'] = tf.multiply(nets['filts_tensor'],nets['fft_input'],name='mul_subbands') 658 | if debug: # return the real and imaginary parts of the subbands separately -- use if matching to their output 659 | nets['subbands_r'] = tf.real(nets['subbands']) 660 | nets['subbands_i'] = tf.imag(nets['subbands']) 661 | 662 | # TODO: with using subbands_ifft is redundant. 663 | # make the time subband operations if we are returning the subbands or if we want to include all of the keys in the graph 664 | if subbands_ifft or return_subbands_only or include_all_keys: 665 | if not rFFT: 666 | nets['subbands_ifft'] = tf.real(tf.ifft(nets['subbands'],name='ifft_subbands'),name='ifft_subbands_r') 667 | else: 668 | nets['subbands_ifft'] = tf.spectral.irfft(nets['subbands'],name='ifft_subbands') 669 | if return_subbands_only or include_all_keys: 670 | nets['subbands_time'] = nets['subbands_ifft'] 671 | if rectify_and_lowpass_subbands: # TODO: the subband operations are hard coded in? 672 | nets['subbands_time_relu'] = tf.nn.relu(nets['subbands_time'], name='rectified_subbands') 673 | nets['subbands_time_lowpassed'] = hanning_pooling_1d_no_depthwise(nets['subbands_time_relu'], downsample=2, length_of_window=2*4, make_plots=False, data_format='NCW', normalize=True, sqrt_window=False) 674 | 675 | # TODO: noise is only added in the case when we are calcalculating the time subbands, but we might want something similar for the cochleagram 676 | if return_subbands_only or include_all_keys: 677 | # Compress subbands if specified and add noise. 
678 | nets = compression_function(nets, input_node_name='subbands_time_lowpassed', output_node_name='subbands_time_lowpassed_compressed') 679 | if include_subbands_noise: 680 | nets = add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise') 681 | nets['subbands_time_lowpassed_compressed_with_noise'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed_with_noise'],-1) 682 | nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed_with_noise'] 683 | else: 684 | nets['subbands_time_lowpassed_compressed'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed'],-1) 685 | nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed'] 686 | 687 | return nets 688 | 689 | 690 | def hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT): 691 | """ 692 | Performs the hilbert transform from the subband FFT -- gets ifft using only the real parts of the signal 693 | Parameters 694 | ---------- 695 | nets : dictionary 696 | dictionary containing parts of the cochleagram graph. 'subbands' are used for the hilbert transform 697 | SR : int 698 | raw sampling rate in Hz for the audio. 699 | SIGNAL_SIZE : int 700 | the length of the audio signal used for the cochleagram graph 701 | pad_factor : int 702 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length) 703 | debug : boolean 704 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True. 705 | rFFT : Boolean 706 | If true, cochleagram graph is constructed using rFFT wherever possible 707 | """ 708 | 709 | if not rFFT: 710 | # make the step tensor for the hilbert transform (only keep the real components) 711 | if pad_factor is not None: 712 | freq_signal = np.fft.fftfreq(SIGNAL_SIZE*pad_factor, 1./SR) 713 | else: 714 | freq_signal = np.fft.fftfreq(SIGNAL_SIZE,1./SR) 715 | nets['step_tensor'] = make_step_tensor(freq_signal) 716 | 717 | # envelopes in frequency domain -- hilbert transform of the subbands 718 | nets['envelopes_freq'] = tf.multiply(nets['subbands'],nets['step_tensor'],name='env_freq') 719 | else: 720 | # make the padding to turn rFFT into a step function 721 | num_filts = nets['filts_tensor'].get_shape().as_list()[1] 722 | # num_batch = nets['subbands'].get_shape().as_list()[0] 723 | num_batch = tf.shape(nets['subbands'])[0] 724 | # TODO: this also might be a problem when we have pad_factor > 1 725 | print(num_batch) 726 | print(num_filts) 727 | print(int(SIGNAL_SIZE/2)-1) 728 | nets['hilbert_padding'] = tf.zeros([num_batch,num_filts,int(SIGNAL_SIZE/2)-1], tf.complex64) 729 | nets['envelopes_freq'] = tf.concat([nets['subbands'],nets['hilbert_padding']],2,name='env_freq') 730 | 731 | if debug: # return real and imaginary parts separately 732 | nets['envelopes_freq_r'] = tf.real(nets['envelopes_freq']) 733 | nets['envelopes_freq_i'] = tf.imag(nets['envelopes_freq']) 734 | 735 | # fft of the envelopes. 736 | nets['envelopes_time'] = tf.ifft(nets['envelopes_freq'],name='ifft_envelopes') 737 | 738 | if not rFFT: # TODO: was this a bug in pycochleagram where the pad factor doesn't actually work? 
739 |         if pad_factor is not None:
740 |             nets['envelopes_time'] = nets['envelopes_time'][:,:,:SIGNAL_SIZE]
741 | 
742 |     if debug: # return real and imaginary parts separately
743 |         nets['envelopes_time_r'] = tf.real(nets['envelopes_time'])
744 |         nets['envelopes_time_i'] = tf.imag(nets['envelopes_time'])
745 | 
746 |     return nets
747 | 
748 | 
749 | def abs_envelopes(nets, SMOOTH_ABS):
750 |     """
751 |     Absolute value of the analytic (Hilbert) signal envelopes, expanded to one channel.
752 | 
753 |     Parameters
754 |     ----------
755 |     nets : dictionary
756 |         dictionary containing the cochleagram graph. The absolute value will be applied to 'envelopes_time'
757 |     SMOOTH_ABS : Boolean
758 |         If True, uses a smoother version of the absolute value for the hilbert transform: sqrt(1e-10 + real(env)^2 + imag(env)^2)
759 |     Returns
760 |     -------
761 |     nets : dictionary
762 |         dictionary containing the updated cochleagram graph
763 |     """
764 | 
765 |     if SMOOTH_ABS:
766 |         nets['envelopes_abs'] = tf.sqrt(1e-10 + tf.square(tf.real(nets['envelopes_time'])) + tf.square(tf.imag(nets['envelopes_time'])))
767 |     else:
768 |         nets['envelopes_abs'] = tf.abs(nets['envelopes_time'], name='complex_abs_envelopes')
769 |     nets['envelopes_abs'] = tf.expand_dims(nets['envelopes_abs'],3, name='exd_abs_real_envelopes')
770 |     return nets
771 | 
772 | 
773 | def downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp):
774 |     """
775 |     Downsamples the cochleagram and then rectifies the output (in case the downsampling results in small negative numbers)
776 |     Parameters
777 |     ----------
778 |     nets : dictionary
779 |         dictionary containing the cochleagram graph. Downsampling will be applied to 'envelopes_abs'
780 |     SR : int
781 |         raw sampling rate of the audio signal
782 |     ENV_SR : int
783 |         end sampling rate of the envelopes
784 |     WINDOW_SIZE : int
785 |         the size of the downsampling window (should be large enough to go to zero on the edges).
786 |     pycoch_downsamp : Boolean
787 |         If True, uses a slightly different downsampling function
788 |     Returns
789 |     -------
790 |     nets : dictionary
791 |         dictionary containing parts of the cochleagram graph with added nodes for the downsampled subbands
792 |     """
793 |     # The stride for the downsample; assumes SR/ENV_SR is an integer.
794 |     DOWNSAMPLE = int(SR / ENV_SR)
795 |     if not ENV_SR == SR:
796 |         # make the downsample tensor
797 |         nets['downsample_filt_tensor'] = make_downsample_filt_tensor(SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp=pycoch_downsamp)
798 |         nets['cochleagram_preRELU'] = tf.nn.conv2d(nets['envelopes_abs'], nets['downsample_filt_tensor'], [1, 1, DOWNSAMPLE, 1], 'SAME',name='conv2d_cochleagram_raw')
799 |     else:
800 |         nets['cochleagram_preRELU'] = nets['envelopes_abs']
801 |     nets['cochleagram_no_compression'] = tf.nn.relu(nets['cochleagram_preRELU'], name='coch_no_compression')
802 | 
803 |     return nets
804 | 
805 | 
806 | def include_compression(nets, compression='none', linear_max=796.87416837456942, input_node_name='cochleagram_no_compression', output_node_name='cochleagram', linear_params=None, rate_level_kwargs={}, custom_compression_op=None):
807 |     """
808 |     Chooses the compression operation to use and adds the appropriate nodes to nets
809 |     Parameters
810 |     ----------
811 |     nets : dictionary
812 |         dictionary containing parts of the cochleagram graph.
Compression will be applied to input_node_name 813 | compression : string 814 | type of compression to perform 815 | linear_max : float 816 | used for the linearbelow compression operations (compression is linear below a value and compressed above it) 817 | input_node_name : string 818 | name in nets to apply the compression 819 | output_node_name : string 820 | name in nets that will be used for the following operation (default is cochleagram, but if returning subbands than it can be chaged) 821 | linear_params : list of floats 822 | used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel. 823 | custom_compression_op : None or tensorflow partial function 824 | if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments 825 | Returns 826 | ------- 827 | nets : dictionary 828 | dictionary containing parts of the cochleagram graph with added nodes for the compressed cochleagram 829 | """ 830 | # compression of the cochleagram 831 | if compression=='quarter': 832 | nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name], name=output_node_name)) 833 | elif compression=='quarter_plus': 834 | nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name]+1e-01, name=output_node_name)) 835 | elif compression=='point3': 836 | nets[output_node_name] = tf.pow(nets[input_node_name],0.3, name=output_node_name) 837 | elif compression=='stable_point3': 838 | nets[output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name=output_node_name) 839 | elif compression=='stable_point3_norm_grads': 840 | nets[output_node_name] = tf.identity(stable_power_compression_norm_grad(nets[input_node_name]*linear_max),name=output_node_name) 841 | elif compression=='linearbelow1': 842 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.pow(nets[input_node_name]*linear_max,0.3), name=output_node_name) 843 | elif compression=='stable_linearbelow1': 844 | nets['stable_power_compressed_%s'%output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name='stable_power_compressed_%s'%output_node_name) 845 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, nets['stable_power_compressed_%s'%output_node_name], name=output_node_name) 846 | elif compression=='linearbelow1sqrt': 847 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.sqrt(nets[input_node_name]*linear_max), name=output_node_name) 848 | elif compression=='quarter_clipped': 849 | nets[output_node_name] = tf.sqrt(tf.sqrt(tf.maximum(nets[input_node_name],1e-01), name=output_node_name)) 850 | elif compression=='none': 851 | nets[output_node_name] = nets[input_node_name] 852 | elif compression=='sqrt': 853 | nets[output_node_name] = tf.sqrt(nets[input_node_name], name=output_node_name) 854 | elif compression=='dB': # NOTE: this compression does not work well for the backwards pass, results in nans 855 | nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name])/tf.reduce_max(nets[input_node_name]) 856 | nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60) 857 | elif compression=='dB_plus': # NOTE: this compression does not work well for the backwards pass, 
results in nans
858 |         nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name]+1)/tf.reduce_max(nets[input_node_name]+1)
859 |         nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60, name=output_node_name)
860 |     elif compression=='linear':
861 |         assert (type(linear_params)==list) and len(linear_params)==2, "Specifying linear compression but not specifying the compression parameters in linear_params=[m, b]"
862 |         nets[output_node_name] = linear_params[0]*nets[input_node_name] + linear_params[1]
863 |     elif compression=='rate_level':
864 |         nets[output_node_name] = AN_rate_level_function(nets[input_node_name], name=output_node_name, **rate_level_kwargs)
865 |     elif compression=='custom':
866 |         nets[output_node_name] = custom_compression_op(nets[input_node_name], name=output_node_name)
867 | 
868 |     return nets
869 | 
870 | 
871 | def make_step_tensor(freq_signal):
872 |     """
873 |     Make the step tensor used to calculate the analytic envelopes.
874 |     Parameters
875 |     ----------
876 |     freq_signal : array
877 |         numpy array containing the frequencies of the audio signal (as calculated by np.fft.fftfreq).
878 |     Returns
879 |     -------
880 |     step_tensor : tensorflow tensor
881 |         tensorflow tensor of shape [1, 1, len(freq_signal)] containing a step function where positive frequencies are 2 and zero/negative frequencies are 0.
882 |     """
883 |     step_func = (freq_signal>=0).astype(int)*2 # positive frequencies are doubled for the analytic signal
884 |     step_func[freq_signal==0] = 0 # https://en.wikipedia.org/wiki/Analytic_signal (the DC component is zeroed; this should not substantially matter)
885 |     step_tensor = tf.constant(step_func, dtype=tf.complex64)
886 |     step_tensor = tf.expand_dims(step_tensor, 0)
887 |     step_tensor = tf.expand_dims(step_tensor, 1)
888 |     return step_tensor
889 | 
890 | 
891 | def make_filts_tensor(SIGNAL_SIZE, SR=16000, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, use_rFFT=False, pad_factor=None, custom_filts=None, erb_filter_kwargs={}):
892 |     """
893 |     Use pycochleagram to make the filters with the specified parameters (make_erb_cos_filters_nx), then load them into a tensorflow tensor to be used in the tensorflow cochleagram graph.
894 |     Parameters
895 |     ----------
896 |     SIGNAL_SIZE: int
897 |         length of the audio signal to convert, and the size of cochleagram filters to make.
898 |     SR : int
899 |         raw sampling rate in Hz for the audio.
900 |     LOW_LIM : int
901 |         Lower frequency limit for the filters.
902 |     HIGH_LIM : int
903 |         Upper frequency limit for the filters.
904 |     N : int
905 |         Number of filters to uniquely span the frequency space
906 |     SAMPLE_FACTOR : int
907 |         number of times to overcomplete the filters.
908 |     use_rFFT : Boolean
909 |         if True, only returns the first half of the filters, corresponding to the positive frequency components.
910 |     custom_filts : None, or numpy array
911 |         if not None, a numpy array containing the filters to use for the cochleagram generation. If None, uses make_erb_cos_filters_nx from pycochleagram to construct the filterbank.
        If using rFFT, should contain the full filters, shape [NUMBER_OF_FILTERS, SIGNAL_SIZE]
912 |     erb_filter_kwargs : dictionary
913 |         contains additional keyword arguments (filter parameters) passed to make_erb_cos_filters_nx
914 |     Returns
915 |     -------
916 |     filts_tensor : tensorflow tensor, complex
917 |         tensorflow tensor of shape [1, NUMBER_OF_FILTERS, SIGNAL_SIZE] containing the erb filters created by make_erb_cos_filters_nx in pycochleagram
918 |     """
919 |     if pad_factor:
920 |         padding_size = (pad_factor-1)*SIGNAL_SIZE
921 |     else:
922 |         padding_size=None
923 | 
924 |     if custom_filts is None:
925 |         # make the filters
926 |         filts, hz_cutoffs, freqs = make_erb_cos_filters_nx(SIGNAL_SIZE, SR, N, LOW_LIM, HIGH_LIM, SAMPLE_FACTOR, padding_size=padding_size, **erb_filter_kwargs) #TODO: decide if we want to change the pad_factor and full_filter arguments.
927 |     else: # TODO: ADD CHECKS TO MAKE SURE THAT THESE MATCH UP WITH THE INPUT SIGNAL
928 |         assert custom_filts.shape[1] == SIGNAL_SIZE, "CUSTOM FILTER SHAPE DOES NOT MATCH THE INPUT AUDIO SHAPE"
929 |         filts = custom_filts
930 | 
931 |     if not use_rFFT:
932 |         filts_tensor = tf.constant(filts, tf.complex64)
933 |     else: # TODO: I believe that this is where the pad factor problem comes in! We are only using part of the signal here.
934 |         filts_tensor = tf.constant(filts[:,0:(int(SIGNAL_SIZE/2)+1)], tf.complex64)
935 | 
936 |     filts_tensor = tf.expand_dims(filts_tensor, 0)
937 | 
938 |     return filts_tensor
939 | 
940 | 
941 | def make_downsample_filt_tensor(SR=16000, ENV_SR=200, WINDOW_SIZE=1001, pycoch_downsamp=False):
942 |     """
943 |     Make the sinc filter that will be used to downsample the cochleagram
944 |     Parameters
945 |     ----------
946 |     SR : int
947 |         raw sampling rate of the audio signal
948 |     ENV_SR : int
949 |         end sampling rate of the envelopes
950 |     WINDOW_SIZE : int
951 |         the size of the downsampling window (should be large enough to go to zero on the edges).
952 |     pycoch_downsamp : Boolean
953 |         If True, uses a slightly different downsampling function
954 |     Returns
955 |     -------
956 |     downsample_filt_tensor : tensorflow tensor, tf.float32
957 |         a tensor of shape [1, WINDOW_SIZE, 1, 1] containing the Kaiser-windowed sinc filter that is applied while downsampling the cochleagram
958 |     """
959 |     DOWNSAMPLE = SR/ENV_SR
960 |     if not pycoch_downsamp:
961 |         downsample_filter_times = np.arange(-WINDOW_SIZE/2,int(WINDOW_SIZE/2))
962 |         downsample_filter_response_orig = np.sinc(downsample_filter_times/DOWNSAMPLE)/DOWNSAMPLE
963 |         downsample_filter_window = signal.kaiser(WINDOW_SIZE, 5)
964 |         downsample_filter_response = downsample_filter_window * downsample_filter_response_orig
965 |     else:
966 |         max_rate = DOWNSAMPLE
967 |         f_c = 1. / max_rate # cutoff of FIR filter (rel. to Nyquist)
968 |         half_len = int(10 * max_rate) # reasonable cutoff for our sinc-like function
969 |         if max_rate!=1:
970 |             downsample_filter_response = signal.firwin(2 * half_len + 1, f_c, window=('kaiser', 5.0))
971 |         else: # just in case we aren't downsampling -- use a unit impulse
972 |             downsample_filter_response = np.zeros(2 * half_len + 1)
973 |             downsample_filter_response[half_len + 1] = 1
974 | 
975 |         # Zero-pad our filter to put the output samples at the center
976 |         # n_pre_pad = int((DOWNSAMPLE - half_len % DOWNSAMPLE))
977 |         # n_post_pad = 0
978 |         # n_pre_remove = (half_len + n_pre_pad) // DOWNSAMPLE
979 |         # We should rarely need to do this given our filter lengths...
980 | # while _output_len(len(h) + n_pre_pad + n_post_pad, x.shape[axis], 981 | # up, down) < n_out + n_pre_remove: 982 | # n_post_pad += 1 983 | # downsample_filter_response = np.concatenate((np.zeros(n_pre_pad), downsample_filter_response, np.zeros(n_post_pad))) 984 | 985 | downsample_filt_tensor = tf.constant(downsample_filter_response, tf.float32) 986 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 0) 987 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 2) 988 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 3) 989 | 990 | return downsample_filt_tensor 991 | 992 | 993 | def add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise'): 994 | # Add white noise variable with the same size to the rectified and compressed subbands 995 | nets['neural_noise'] = tf.random.normal(tf.shape(nets[input_node_name]), mean=subbands_noise_mean, 996 | stddev=subbands_noise_stddev, dtype=nets[input_node_name].dtype) 997 | nets[output_node_name] = tf.nn.relu(tf.math.add(nets[input_node_name], nets['neural_noise'])) 998 | return nets 999 | 1000 | 1001 | def reshape_coch_kell_2018(nets): 1002 | """ 1003 | Wrapper to reshape the cochleagram to 256x256 similar to that used in kell2018. 1004 | Note that this function relies on tf.image.resize_images which can have unexpected behavior... use with caution. 1005 | nets : dictionary 1006 | dictionary containing parts of the cochleagram graph. should already contain cochleagram 1007 | """ 1008 | print('### WARNING: tf.image.resize_images is not trusted, use caution ###') 1009 | nets['min_cochleagram'] = tf.reduce_min(nets['cochleagram']) 1010 | nets['max_cochleagram'] = tf.reduce_max(nets['cochleagram']) 1011 | # it is possible that this scaling is going to mess up the gradients for the waveform generation 1012 | nets['scaled_cochleagram'] = 255*(1-((nets['max_cochleagram']-nets['cochleagram'])/(nets['max_cochleagram']-nets['min_cochleagram']))) 1013 | nets['reshaped_cochleagram'] = tf.image.resize_images(nets['scaled_cochleagram'],[256,256], align_corners=False, preserve_aspect_ratio=False) 1014 | return nets, 'reshaped_cochleagram' 1015 | 1016 | 1017 | def convert_Pa_to_dBSPL(pa): 1018 | """ Converts units of Pa to dB re 20e-6 Pa (dB SPL) """ 1019 | return 20. * np.log10(pa / 20e-6) 1020 | 1021 | 1022 | def convert_dBSPL_to_Pa(dbspl): 1023 | """ Converts units of dB re 20e-6 Pa (dB SPL) to Pa """ 1024 | return 20e-6 * np.power(10., dbspl / 20.) 1025 | 1026 | 1027 | def AN_rate_level_function(tensor_subbands, name='rate_level_fcn', rate_spont=70., rate_max=250., 1028 | rate_normalize=True, beta=3., halfmax_dBSPL=20.): 1029 | """ 1030 | Function implements the auditory nerve rate-level function described by Peter Heil 1031 | and colleagues (2011, J. Neurosci.): the "amplitude-additivity model". 
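    As implemented below, the rate-level mapping is
        rate(P) = rate_max / (1 + ((rate_max - rate_spont) / rate_spont) * (P / P_0 + 1) ** (-beta)),
    where P is the subband amplitude in Pa and
        P_0 = halfmax_Pa / (((rate_max + rate_spont) / rate_spont) ** (1 / beta) - 1),
    with halfmax_Pa the Pa equivalent of halfmax_dBSPL. If rate_normalize is True, the output is
    further rescaled as (rate - rate_spont) / (rate_max - rate_spont).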
1032 | 1033 | Args 1034 | ---- 1035 | tensor_subbands (tensor): shape must be [batch, freq, time, (channel)], units are Pa 1036 | name (str): name for the tensorflow operation 1037 | rate_spont (float): spontaneous spiking rate (spikes/s) 1038 | rate_max (float): maximum spiking rate (spikes/s) 1039 | rate_normalize (bool): if True, output will be re-scaled between 0 and 1 1040 | beta (float or list): determines the steepness of rate-level function (dimensionless) 1041 | halfmax_dBSPL (float or list): determines threshold of rate-level function (units dB SPL) 1042 | 1043 | Returns 1044 | ------- 1045 | tensor_rates (tensor): same shape as tensor_subbands, units are spikes/s or normalized 1046 | """ 1047 | # Check arguments and compute shape for frequency-channel-specific parameters 1048 | assert rate_spont > 0, "rate_spont must be greater than zero to avoid division by zero" 1049 | if len(tensor_subbands.shape) == 3: 1050 | freq_specific_shape = [tensor_subbands.shape[1], 1] 1051 | elif len(tensor_subbands.shape) == 4: 1052 | freq_specific_shape = [tensor_subbands.shape[1], 1, 1] 1053 | else: 1054 | raise ValueError("tensor_subbands must have shape [batch, freq, time, (channel)]") 1055 | # Convert beta to tensor (can be a single value or frequency channel specific) 1056 | beta = np.array(beta).reshape([-1]) 1057 | assert_msg = "beta must be one value or a list of length {}".format(tensor_subbands.shape[1]) 1058 | assert len(beta) == 1 or len(beta) == tensor_subbands.shape[1], assert_msg 1059 | beta_vals = tf.constant(beta, 1060 | dtype=tensor_subbands.dtype, 1061 | shape=freq_specific_shape) 1062 | # Convert halfmax_dBSPL to tensor (can be a single value or frequency channel specific) 1063 | halfmax_dBSPL = np.array(halfmax_dBSPL).reshape([-1]) 1064 | assert_msg = "halfmax_dBSPL must be one value or a list of length {}".format(tensor_subbands.shape[1]) 1065 | assert len(halfmax_dBSPL) == 1 or len(halfmax_dBSPL) == tensor_subbands.shape[1], assert_msg 1066 | P_halfmax = tf.constant(convert_dBSPL_to_Pa(halfmax_dBSPL), 1067 | dtype=tensor_subbands.dtype, 1068 | shape=freq_specific_shape) 1069 | # Convert rate_spont and rate_max to tf.constants (single values) 1070 | R_spont = tf.constant(rate_spont, dtype=tensor_subbands.dtype, shape=[]) 1071 | R_max = tf.constant(rate_max, dtype=tensor_subbands.dtype, shape=[]) 1072 | # Implementation analogous to equation (8) from Heil et al. (2011, J. Neurosci.) 1073 | P_0 = P_halfmax / (tf.pow((R_max + R_spont) / R_spont, 1/beta_vals) - 1) 1074 | R_func = lambda P: R_max / (1 + ((R_max - R_spont) / R_spont) * tf.pow(P / P_0 + 1, -beta_vals)) 1075 | tensor_rates = tf.map_fn(R_func, tensor_subbands, name=name) 1076 | # If rate_normalize is True, re-scale spiking rates to fall between 0 and 1 1077 | if rate_normalize: 1078 | tensor_rates = (tensor_rates - R_spont) / (R_max - R_spont) 1079 | return tensor_rates 1080 | 1081 | 1082 | def make_hanning_kernel_1d(downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True): 1083 | """ 1084 | Make the symmetric 1d hanning kernel to use for the pooling filters 1085 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles 1086 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles 1087 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles 1088 | We want to reduce the frequencies above the nyquist by at least 20dB. 
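    For example, make_hanning_kernel_1d(downsample=2, length_of_window=8) returns a length-8
    window with zero-valued endpoints; with normalize=True the window is scaled to sum to 1.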
1089 | Parameters 1090 | ---------- 1091 | downsample : int 1092 | proportion downsampling 1093 | length_of_window : int 1094 | how large of a window to use 1095 | make_plots: boolean 1096 | make plots of the filters 1097 | normalize : boolean 1098 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original. 1099 | sqrt_window : boolean 1100 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False 1101 | Returns 1102 | ------- 1103 | one_dimensional_kernel : numpy array 1104 | hanning kernel in 1d to use as a kernel for filtering 1105 | """ 1106 | 1107 | window = 0.5 * (1 - np.cos(2.0 * np.pi * (np.arange(length_of_window)) / (length_of_window - 1))) 1108 | if sqrt_window: 1109 | one_dimensional_kernel = np.sqrt(window) 1110 | else: 1111 | one_dimensional_kernel = window 1112 | 1113 | if normalize: 1114 | one_dimensional_kernel = one_dimensional_kernel/sum(one_dimensional_kernel) 1115 | window = one_dimensional_kernel 1116 | 1117 | if make_plots: 1118 | A = np.fft.fft(window, 2048) / (len(window) / 2.0) 1119 | freq = np.linspace(-0.5, 0.5, len(A)) 1120 | response = 20.0 * np.log10(np.abs(np.fft.fftshift(A / abs(A).max()))) 1121 | 1122 | nyquist = 1 / (2 * downsample) 1123 | ny_idx = np.where(np.abs(freq - nyquist) == np.abs(freq - nyquist).min())[0][0] 1124 | print(['Frequency response at ' + 'nyquist (%.3f Hz)'%nyquist + ' is ' + '%d'%response[ny_idx]]) 1125 | plt.figure() 1126 | plt.plot(window) 1127 | plt.title(r"Hanning window") 1128 | plt.ylabel("Amplitude") 1129 | plt.xlabel("Sample") 1130 | plt.figure() 1131 | plt.plot(freq, response) 1132 | plt.axis([-0.5, 0.5, -120, 0]) 1133 | plt.title(r"Frequency response of the Hanning window") 1134 | plt.ylabel("Normalized magnitude [dB]") 1135 | plt.xlabel("Normalized frequency [cycles per sample]") 1136 | 1137 | return one_dimensional_kernel 1138 | 1139 | 1140 | def make_hanning_kernel_tensor_1d(n_channels, downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True): 1141 | """ 1142 | Make a tensor containing the symmetric 1d hanning kernel to use for the pooling filters 1143 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles 1144 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles 1145 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles 1146 | We want to reduce the frequencies above the nyquist by at least 20dB. 1147 | Parameters 1148 | ---------- 1149 | n_channels : int 1150 | number of channels to copy the kernel into 1151 | downsample : int 1152 | proportion downsampling 1153 | length_of_window : int 1154 | how large of a window to use 1155 | make_plots: boolean 1156 | make plots of the filters 1157 | normalize : boolean 1158 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original. 
1159 | sqrt_window : boolean 1160 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False 1161 | Returns 1162 | ------- 1163 | hanning_tensor : tensorflow tensor 1164 | tensorflow tensor containing the hanning tensor with size [1 length_of_window n_channels 1] 1165 | """ 1166 | hanning_kernel = make_hanning_kernel_1d(downsample=downsample,length_of_window=length_of_window,make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window) 1167 | hanning_kernel = np.expand_dims(np.dstack([hanning_kernel.astype(np.float32)]*n_channels),axis=3) 1168 | hanning_tensor = tf.constant(hanning_kernel) 1169 | return hanning_tensor 1170 | 1171 | 1172 | def hanning_pooling_1d(input_tensor, downsample=2, length_of_window=8, make_plots=False, data_format='NWC', normalize=False, sqrt_window=True): 1173 | """ 1174 | Parameters 1175 | ---------- 1176 | input_tensor : tensorflow tensor 1177 | tensor on which we will apply the hanning pooling operation 1178 | downsample : int 1179 | proportion downsampling 1180 | length_of_window : int 1181 | how large of a window to use 1182 | make_plots: boolean 1183 | make plots of the filters 1184 | data_format : 'NWC' or 'NCW' 1185 | Defaults to "NWC", the data is stored in the order of [batch, in_width, in_channels]. 1186 | The "NCW" format stores data as [batch, in_channels, in_width]. 1187 | normalize : boolean 1188 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original. 1189 | sqrt_window : boolean 1190 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False 1191 | Returns 1192 | ------- 1193 | output_tensor : tensorflow tensor 1194 | tensorflow tensor containing the downsampled input_tensor of shape corresponding to data_format 1195 | """ 1196 | 1197 | if data_format=='NWC': 1198 | n_channels = input_tensor.get_shape().as_list()[2] 1199 | elif data_format=='NCW': 1200 | batch_size, n_channels, in_width = input_tensor.get_shape().as_list() 1201 | input_tensor = tf.transpose(input_tensor, [0, 2, 1]) # reshape to [batch_size, in_wdith, in_channels] 1202 | 1203 | input_tensor = tf.expand_dims(input_tensor,1) # reshape to [batch_size, 1, in_width, in_channels] 1204 | h_tensor = make_hanning_kernel_tensor_1d(n_channels, downsample=downsample, length_of_window=length_of_window, make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window) 1205 | 1206 | output_tensor = tf.nn.depthwise_conv2d(input_tensor, h_tensor, strides=[1, downsample, downsample, 1], padding='SAME', name='hpooling') 1207 | 1208 | output_tensor = tf.squeeze(output_tensor, name='squeeze_output') 1209 | if data_format=='NWC': 1210 | return output_tensor 1211 | elif data_format=='NCW': 1212 | return tf.transpose(output_tensor, [0, 2, 1]) # reshape to [batch_size, in_channels, out_width] 1213 | 1214 | 1215 | def make_hanning_kernel_tensor_1d_no_depthwise(n_channels, downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True): 1216 | """ 1217 | Make a tensor containing the symmetric 1d hanning kernel to use for the pooling filters 1218 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles 1219 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles 1220 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles 1221 | We want to reduce the frequencies above the nyquist by at least 20dB. 
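    Unlike make_hanning_kernel_tensor_1d, the returned kernel is arranged as a channel-diagonal
    [length_of_window, n_channels, n_channels] array so it can be applied with a plain tf.nn.conv1d
    (no depthwise convolution), as done in hanning_pooling_1d_no_depthwise below.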
1222 | Parameters 1223 | ---------- 1224 | n_channels : int 1225 | number of channels to copy the kernel into 1226 | downsample : int 1227 | proportion downsampling 1228 | length_of_window : int 1229 | how large of a window to use 1230 | make_plots: boolean 1231 | make plots of the filters 1232 | normalize : boolean 1233 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original. 1234 | sqrt_window : boolean 1235 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False 1236 | Returns 1237 | ------- 1238 | hanning_tensor : tensorflow tensor 1239 | tensorflow tensor containing the hanning tensor with size [length_of_window, num_channels, num_channels] 1240 | """ 1241 | hanning_kernel = make_hanning_kernel_1d(downsample=downsample,length_of_window=length_of_window,make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window).astype(np.float32) 1242 | hanning_kernel_expanded = np.expand_dims(hanning_kernel,0) * np.expand_dims(np.eye(n_channels),3).astype(np.float32) # [n_channels, n_channels, filter_width] 1243 | hanning_tensor = tf.constant(hanning_kernel_expanded) # [length_of_window, num_channels, num_channels] 1244 | hanning_tensor = tf.transpose(hanning_tensor, [2, 0, 1]) 1245 | return hanning_tensor 1246 | 1247 | 1248 | def hanning_pooling_1d_no_depthwise(input_tensor, downsample=2, length_of_window=8, make_plots=False, data_format='NWC', normalize=False, sqrt_window=True): 1249 | """ 1250 | Parameters 1251 | ---------- 1252 | input_tensor : tensorflow tensor 1253 | tensor on which we will apply the hanning pooling operation 1254 | downsample : int 1255 | proportion downsampling 1256 | length_of_window : int 1257 | how large of a window to use 1258 | make_plots: boolean 1259 | make plots of the filters 1260 | data_format : 'NWC' or 'NCW' 1261 | Defaults to "NWC", the data is stored in the order of [batch, in_width, in_channels]. 1262 | The "NCW" format stores data as [batch, in_channels, in_width]. 1263 | normalize : boolean 1264 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original. 
1265 | make_hanning_kernel_tensor_1d_no_depthwise 1266 | sqrt_window : boolean 1267 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False 1268 | Returns 1269 | ------- 1270 | output_tensor : tensorflow tensor 1271 | tensorflow tensor containing the downsampled input_tensor of shape corresponding to data_format 1272 | """ 1273 | 1274 | if data_format=='NWC': 1275 | n_channels = input_tensor.get_shape().as_list()[2] 1276 | elif data_format=='NCW': 1277 | batch_size, n_channels, in_width = input_tensor.get_shape().as_list() 1278 | input_tensor = tf.transpose(input_tensor, [0, 2, 1]) # reshape to [batch_size, in_wdith, in_channels] 1279 | 1280 | h_tensor = make_hanning_kernel_tensor_1d_no_depthwise(n_channels, downsample=downsample, length_of_window=length_of_window, make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window) 1281 | 1282 | output_tensor = tf.nn.conv1d(input_tensor, h_tensor, stride=downsample, padding='SAME', name='hpooling') 1283 | 1284 | if data_format=='NWC': 1285 | return output_tensor 1286 | elif data_format=='NCW': 1287 | return tf.transpose(output_tensor, [0, 2, 1]) # reshape to [batch_size, in_channels, out_width] 1288 | 1289 | 1290 | def build_cochlear_model(tensor_waveform, 1291 | signal_rate=20000, 1292 | filter_type='half-cosine', 1293 | filter_spacing='erb', 1294 | HIGH_LIM=8000, 1295 | LOW_LIM=20, 1296 | N=40, 1297 | SAMPLE_FACTOR=1, 1298 | bandwidth_scale_factor=1.0, 1299 | compression='stable_point3', 1300 | include_highpass=False, 1301 | include_lowpass=False, 1302 | linear_max=1.0, 1303 | rFFT=True, 1304 | rectify_and_lowpass_subbands=True, 1305 | return_subbands_only=True, 1306 | **kwargs): 1307 | """ 1308 | This function serves as a wrapper for `tfcochleagram_graph` and builds the cochlear model graph. 1309 | * * * * * * Default arguments are set to those used to train recognition networks * * * * * * 1310 | 1311 | Parameters 1312 | ---------- 1313 | tensor_waveform (tensor): input signal waveform (with shape [batch, time]) 1314 | signal_rate (int): sampling rate of signal waveform in Hz 1315 | filter_type (str): type of cochlear filters to build ('half-cosine') 1316 | filter_spacing (str, default='erb'): Specifies the type of reference spacing for the 1317 | half-cosine filters. Options include 'erb' and 'linear'. 1318 | HIGH_LIM (float): high frequency cutoff of filterbank (only used for 'half-cosine') 1319 | LOW_LIM (float): low frequency cutoff of filterbank (only used for 'half-cosine') 1320 | N (int): number of cochlear bandpass filters 1321 | SAMPLE_FACTOR (int): specifies how densely to sample cochlea (only used for 'half-cosine') 1322 | bandwidth_scale_factor (float): factor by which to symmetrically scale the filter bandwidths 1323 | bandwidth_scale_factor=2.0 means filters will be twice as wide. 1324 | Note that values < 1 will cause frequency gaps between the filters. 
1325 | include_highpass (bool): determines if filterbank includes highpass filter(s) (only used for 'half-cosine') 1326 | include_lowpass (bool): determines if filterbank includes lowpass filter(s) (only used for 'half-cosine') 1327 | linear_max (float): used for the linearbelow compression operations 1328 | (compression is linear below a value and compressed above it) 1329 | rFFT (bool): If True, builds the graph using rFFT and irFFT operations whenever possible 1330 | rectify_and_lowpass_subbands (bool): If True, rectifies and lowpass-filters subbands before returning 1331 | return_subbands_only (bool): If True, returns subbands before taking the hilbert envelope as the output node 1332 | kwargs (dict): additional keyword arguments passed directly to tfcochleagram_graph 1333 | 1334 | Returns 1335 | ------- 1336 | tensor_cochlear_representation (tensor): output cochlear representation 1337 | coch_container (dict): dictionary containing cochlear model stages 1338 | """ 1339 | signal_length = tensor_waveform.get_shape().as_list()[-1] 1340 | 1341 | if filter_type == 'half-cosine': 1342 | assert HIGH_LIM <= signal_rate/2, "cochlear filterbank high_lim is above Nyquist frequency" 1343 | filts, center_freqs, freqs = make_cos_filters_nx( 1344 | signal_length, 1345 | signal_rate, 1346 | N, 1347 | LOW_LIM, 1348 | HIGH_LIM, 1349 | SAMPLE_FACTOR, 1350 | padding_size=None, 1351 | full_filter=True, 1352 | strict=True, 1353 | bandwidth_scale_factor=bandwidth_scale_factor, 1354 | include_lowpass=include_lowpass, 1355 | include_highpass=include_highpass, 1356 | filter_spacing=filter_spacing) 1357 | assert filts.shape[1] == signal_length, "filter array shape must match signal length" 1358 | else: 1359 | raise ValueError('Specified filter_type {} is not supported'.format(filter_type)) 1360 | 1361 | coch_container = {'input_signal': tensor_waveform} 1362 | coch_container = cochleagram_graph( 1363 | coch_container, 1364 | signal_length, 1365 | signal_rate, 1366 | LOW_LIM=LOW_LIM, 1367 | HIGH_LIM=HIGH_LIM, 1368 | N=N, 1369 | SAMPLE_FACTOR=SAMPLE_FACTOR, 1370 | custom_filts=filts, 1371 | linear_max=linear_max, 1372 | rFFT=rFFT, 1373 | rectify_and_lowpass_subbands=rectify_and_lowpass_subbands, 1374 | return_subbands_only=return_subbands_only, 1375 | **kwargs) 1376 | 1377 | tensor_cochlear_representation = coch_container['output_tfcoch_graph'] 1378 | return tensor_cochlear_representation, coch_container 1379 | -------------------------------------------------------------------------------- /util_recognition_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def build_network(tensor_input, list_layer_dict, n_classes_dict={}): 8 | """ 9 | Build tensorflow graph for a feedforward neural network given an 10 | input tensor and list of layer descriptions 11 | """ 12 | tensor_output = tensor_input 13 | tensors_dict = {} 14 | for layer_dict in list_layer_dict: 15 | layer_type = layer_dict['layer_type'] 16 | if 'batch_normalization' in layer_type: 17 | layer = tf.keras.layers.BatchNormalization(**layer_dict['args']) 18 | elif 'conv2d' in layer_type: 19 | layer = PaddedConv2D(**layer_dict['args']) 20 | elif 'dense' in layer_type: 21 | layer = tf.keras.layers.Dense(**layer_dict['args']) 22 | elif 'dropout' in layer_type: 23 | layer = tf.keras.layers.Dropout(**layer_dict['args']) 24 | elif 'hpool' in layer_type: 25 | layer = HanningPooling(**layer_dict['args']) 26 | elif 'flatten' in 
layer_type: 27 | layer = tf.keras.layers.Flatten(**layer_dict['args']) 28 | elif 'leaky_relu' in layer_type: 29 | layer = tf.keras.layers.LeakyReLU(**layer_dict['args']) 30 | elif 'relu' in layer_type: 31 | layer = tf.keras.layers.ReLU(**layer_dict['args']) 32 | elif 'fc_top' in layer_type: 33 | layer = DenseTaskHeads( 34 | n_classes_dict=n_classes_dict, 35 | **layer_dict['args']) 36 | else: 37 | msg = "layer_type={} not recognized".format(layer_type) 38 | raise NotImplementedError(msg) 39 | tensor_output = layer(tensor_output) 40 | if layer_dict.get('args', {}).get('name', None) is not None: 41 | tensors_dict[layer_dict['args']['name']] = tensor_output 42 | return tensor_output, tensors_dict 43 | 44 | 45 | def same_pad_along_axis(tensor_input, 46 | kernel_length=1, 47 | stride_length=1, 48 | axis=1, 49 | **kwargs): 50 | """ 51 | Adds 'SAME' padding to only specified axis of tensor_input 52 | for 2D convolution 53 | """ 54 | x = tensor_input.shape.as_list()[axis] 55 | if x % stride_length == 0: 56 | p = kernel_length - stride_length 57 | else: 58 | p = kernel_length - (x % stride_length) 59 | p = tf.math.maximum(p, 0) 60 | paddings = [(0, 0)] * len(tensor_input.shape) 61 | paddings[axis] = (p // 2, p - p // 2) 62 | return tf.pad(tensor_input, paddings, **kwargs) 63 | 64 | 65 | def PaddedConv2D(filters, 66 | kernel_size, 67 | strides=(1, 1), 68 | padding='VALID', 69 | **kwargs): 70 | """ 71 | Wrapper function around tf.keras.layers.Conv2D to support 72 | custom padding options 73 | """ 74 | if padding.upper() == 'VALID_TIME': 75 | pad_function = lambda x : same_pad_along_axis( 76 | x, 77 | kernel_length=kernel_size[0], 78 | stride_length=strides[0], 79 | axis=1) 80 | padding = 'VALID' 81 | else: 82 | pad_function = lambda x: x 83 | def layer(tensor_input): 84 | conv2d_layer = tf.keras.layers.Conv2D( 85 | filters, 86 | kernel_size, 87 | strides=strides, 88 | padding=padding, 89 | **kwargs) 90 | return conv2d_layer(pad_function(tensor_input)) 91 | return layer 92 | 93 | 94 | def HanningPooling(strides=2, 95 | pool_size=8, 96 | padding='SAME', 97 | sqrt_window=False, 98 | normalize=False, 99 | name=None): 100 | """ 101 | Weighted average pooling layer with Hanning window applied via 102 | 2D convolution (with identity transform as depthwise component) 103 | """ 104 | if isinstance(strides, int): 105 | strides = (strides, strides) 106 | if isinstance(pool_size, int): 107 | pool_size = (pool_size, pool_size) 108 | assert len(strides) == 2, "HanningPooling expects 2D args" 109 | assert len(pool_size) == 2, "HanningPooling expects 2D args" 110 | 111 | (dim0, dim1) = pool_size 112 | if dim0 == 1: 113 | win0 = np.ones(dim0) 114 | else: 115 | win0 = (1 - np.cos(2 * np.pi * np.arange(dim0) / (dim0 - 1))) / 2 116 | if dim1 == 1: 117 | win1 = np.ones(dim1) 118 | else: 119 | win1 = (1 - np.cos(2 * np.pi * np.arange(dim1) / (dim1 - 1))) / 2 120 | hanning_window = np.outer(win0, win1) 121 | if sqrt_window: 122 | hanning_window = np.sqrt(hanning_window) 123 | if normalize: 124 | hanning_window = hanning_window / hanning_window.sum() 125 | 126 | if padding.upper() == 'VALID_TIME': 127 | pad_function = lambda x : same_pad_along_axis( 128 | x, 129 | kernel_length=pool_size[0], 130 | stride_length=strides[0], 131 | axis=1) 132 | padding = 'VALID' 133 | else: 134 | pad_function = lambda x: x 135 | 136 | def layer(tensor_input): 137 | tensor_hanning_window = tf.constant( 138 | hanning_window[:, :, np.newaxis, np.newaxis], 139 | dtype=tensor_input.dtype, 140 | name="{}_hanning_window".format(name)) 141 | 
tensor_eye = tf.eye( 142 | num_rows=tensor_input.shape.as_list()[-1], 143 | num_columns=None, 144 | batch_shape=[1, 1], 145 | dtype=tensor_input.dtype, 146 | name=None) 147 | tensor_output = tf.nn.convolution( 148 | pad_function(tensor_input), 149 | tensor_hanning_window * tensor_eye, 150 | strides=strides, 151 | padding=padding, 152 | data_format=None, 153 | name=name) 154 | return tensor_output 155 | 156 | return layer 157 | 158 | 159 | def DenseTaskHeads(n_classes_dict={}, name='logits', **kwargs): 160 | """ 161 | Dense layer for each task head specified in n_classes_dict 162 | """ 163 | def layer(tensor_input): 164 | tensors_logits = {} 165 | for key in sorted(n_classes_dict.keys()): 166 | if len(n_classes_dict.keys()) > 1: 167 | classification_name = '{}_{}'.format(name, key) 168 | else: 169 | classification_name = name 170 | classification_layer = tf.keras.layers.Dense( 171 | units=n_classes_dict[key], 172 | name=classification_name, 173 | **kwargs) 174 | tensors_logits[key] = classification_layer(tensor_input) 175 | return tensors_logits 176 | return layer 177 | --------------------------------------------------------------------------------
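The following is a minimal usage sketch (not part of the code release) showing how the two modules above fit together. It assumes the TFv1 graph/session workflow used throughout this repository, that the arch*.json files under models/recognition_networks/ store the list_layer_dict consumed by build_network, and that the task-head size in n_classes_dict is a placeholder:

import json
import tensorflow as tf

import util_cochlear_model
import util_recognition_network

# Waveform placeholder: batch of 2-second clips at 20 kHz (40000 samples)
tensor_waveform = tf.placeholder(tf.float32, shape=[None, 40000])

# Cochlear model front end (defaults match those used to train the recognition networks)
tensor_coch, coch_container = util_cochlear_model.build_cochlear_model(
    tensor_waveform, signal_rate=20000)

# Recognition network built from a JSON layer list (assumed format) with a placeholder task head
with open('models/recognition_networks/arch1.json', 'r') as f:
    list_layer_dict = json.load(f)
n_classes_dict = {'word': 800}  # placeholder number of word classes
tensor_logits, tensors_dict = util_recognition_network.build_network(
    tensor_coch, list_layer_dict, n_classes_dict=n_classes_dict)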