├── .gitignore
├── DEMO.ipynb
├── DEMO_train_audio_transform.ipynb
├── LICENSE.md
├── README.md
├── Wave-U-Net
│   ├── InterpolationLayer.py
│   ├── LICENSE
│   ├── OutputLayer.py
│   ├── README.md
│   ├── UnetAudioSeparator.py
│   └── unet_utils.py
├── audio
│   ├── ex01_unprocessed_input.wav
│   ├── ex02_unprocessed_input.wav
│   ├── ex03_unprocessed_input.wav
│   ├── ex04_unprocessed_input.wav
│   ├── ex05_unprocessed_input.wav
│   ├── ex06_unprocessed_input.wav
│   ├── ex07_unprocessed_input.wav
│   ├── ex08_unprocessed_input.wav
│   ├── ex09_unprocessed_input.wav
│   └── ex10_unprocessed_input.wav
├── data
│   ├── mturk_experiment_1_results_interspeech2021.csv
│   ├── mturk_experiment_2_results_interspeech2021.csv
│   └── toy_dataset_0000.tfrecords
├── models
│   ├── .gitkeep
│   ├── audio_transforms
│   │   └── .gitkeep
│   └── recognition_networks
│       ├── .gitkeep
│       ├── arch1.json
│       ├── arch2.json
│       ├── arch3.json
│       └── deep_feature_loss_weights.json
├── util_audio_preprocess.py
├── util_audio_transform.py
├── util_auditory_model_loss.py
├── util_cochlear_model.py
└── util_recognition_network.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # SLURM
107 | *.out
108 |
109 | *.wav
110 | *.ckpt-*
111 |
--------------------------------------------------------------------------------
/DEMO_train_audio_transform.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import sys\n",
11 | "import glob\n",
12 | "import numpy as np\n",
13 | "import tensorflow as tf\n",
14 | "import importlib\n",
15 | "import IPython.display as ipd\n",
16 | "\n",
17 | "import util_audio_preprocess\n",
18 | "import util_audio_transform\n",
19 | "import util_auditory_model_loss\n",
20 | "import util_cochlear_model\n",
21 | "import util_recognition_network\n"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [
29 | {
30 | "name": "stdout",
31 | "output_type": "stream",
32 | "text": [
33 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/control_flow_ops.py:3632: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
34 | "Instructions for updating:\n",
35 | "Colocations handled automatically by placer.\n"
36 | ]
37 | }
38 | ],
39 | "source": [
40 | "\"\"\"\n",
41 | "Data inputput pipeline\n",
42 | "\n",
43 | "Recommended training data format is tfrecords files containing foreground\n",
44 | "(speech) and background (noise) waveforms stored separately. Model was\n",
45 | "originally trained with 2-second audio clips (20 kHz sampling rate).\n",
46 | "\n",
47 | "NOTE: the full dataset used to train models in the paper is not included in\n",
48 | "this code release as it was compiled from previously published datasets\n",
49 | "(speech from Wall Street Journal and Spoken Wikipedia Corpora; background\n",
50 | "noise from Audioset). Data available upon request to authors.\n",
51 | "\"\"\"\n",
52 | "\n",
53 | "filenames = glob.glob('data/toy_dataset*.tfrecords')\n",
54 | "batch_size = 8\n",
55 | "feature_description = {\n",
56 | " 'background/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),\n",
57 | " 'foreground/signal': tf.io.FixedLenFeature([], tf.string, default_value=None),\n",
58 | "}\n",
59 | "bytes_description = {\n",
60 | " 'background/signal': {'dtype': tf.float32, 'shape': [40000]}, \n",
61 | " 'foreground/signal': {'dtype': tf.float32, 'shape': [40000]},\n",
62 | "}\n",
63 | "\n",
64 | "\n",
65 | "def parse_tfrecord(tfrecord):\n",
66 | " tfrecord = tf.parse_single_example(tfrecord, features=feature_description)\n",
67 | " for key in bytes_description.keys():\n",
68 | " tfrecord[key] = tf.decode_raw(tfrecord[key], bytes_description[key]['dtype'])\n",
69 | " tfrecord[key] = tf.reshape(tfrecord[key], bytes_description[key]['shape'])\n",
70 | " return tfrecord\n",
71 | "\n",
72 | "\n",
73 | "def preprocess_audio_batch(batch):\n",
74 | " \"\"\"\n",
75 | " Function combines foreground (speech) and background (noise) audio\n",
76 | " signals with signal-to-noise ratios drawn uniformly between -20 and\n",
77 | " +10 dB. The returned dictionary contains the noisy speech signal,\n",
78 | " the clean speech signal, and the SNR.\n",
79 | " \"\"\"\n",
80 | " foreground_signal = batch['foreground/signal']\n",
81 | " background_signal = batch['background/signal']\n",
82 | " snr = tf.random.uniform(\n",
83 | " [tf.shape(foreground_signal)[0], 1],\n",
84 | " minval=-20.0,\n",
85 | " maxval=10.0,\n",
86 | " dtype=foreground_signal.dtype)\n",
87 | " signal_in_noise, signal, noise_scaled = util_audio_preprocess.tf_set_snr(\n",
88 | " foreground_signal,\n",
89 | " background_signal,\n",
90 | " snr)\n",
91 | " batch = {\n",
92 | " 'snr': snr,\n",
93 | " 'waveform_noisy': signal_in_noise,\n",
94 | " 'waveform_clean': signal,\n",
95 | " }\n",
96 | " return batch\n",
97 | "\n",
98 | "\n",
99 | "tf.reset_default_graph()\n",
100 | "tf.random.set_random_seed(0)\n",
101 | "\n",
102 | "dataset = tf.data.TFRecordDataset(filenames=filenames, compression_type='GZIP')\n",
103 | "dataset = dataset.map(parse_tfrecord)\n",
104 | "dataset = dataset.batch(batch_size)\n",
105 | "dataset = dataset.map(preprocess_audio_batch)\n",
106 | "dataset = dataset.prefetch(buffer_size=4)\n",
107 | "dataset = dataset.shuffle(buffer_size=32)\n",
108 | "dataset = dataset.repeat(count=None)\n",
109 | "\n",
110 | "iterator = dataset.make_one_shot_iterator()\n",
111 | "input_tensor_dict = iterator.get_next()\n"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 3,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "name": "stdout",
121 | "output_type": "stream",
122 | "text": [
123 | "WARNING:tensorflow:From Wave-U-Net/UnetAudioSeparator.py:97: conv1d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n",
124 | "Instructions for updating:\n",
125 | "Use keras.layers.conv1d instead.\n",
126 | "1 recognition networks included for deep feature loss:\n",
127 | "|__ arch1_taskA: models/recognition_networks/arch1_taskA.ckpt-550000\n",
128 | "Building waveform loss\n",
129 | "Building cochlear model loss\n",
130 | "[make_cos_filters_nx] using filter_spacing=`erb`\n",
131 | "[make_cos_filters_nx] using filter_spacing=`erb`\n",
132 | "Building deep feature loss (recognition network: arch1_taskA)\n",
133 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
134 | "Instructions for updating:\n",
135 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
136 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
137 | "Instructions for updating:\n",
138 | "Use tf.cast instead.\n",
139 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/signal/fft_ops.py:315: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
140 | "Instructions for updating:\n",
141 | "Use tf.cast instead.\n"
142 | ]
143 | },
144 | {
145 | "name": "stderr",
146 | "output_type": "stream",
147 | "text": [
148 | "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gradients_impl.py:110: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n",
149 | " \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n"
150 | ]
151 | }
152 | ],
153 | "source": [
154 | "\"\"\"\n",
155 | "Model training graph\n",
156 | "\n",
157 | "Components:\n",
158 | "1. U-Net audio transform (`util_audio_transform.build_unet`)\n",
159 | "2. Auditory model loss function (`util_auditory_model_loss.AuditoryModelLoss`)\n",
160 | "3. Tensorflow optimizer to train U-Net weights\n",
161 | "\"\"\"\n",
162 | "\n",
163 | "### U-Net audio transform\n",
164 | "tensor_waveform_noisy = input_tensor_dict['waveform_noisy']\n",
165 | "tensor_waveform_clean = input_tensor_dict['waveform_clean']\n",
166 | "tensor_waveform_denoised = util_audio_transform.build_unet(tensor_waveform_noisy)\n",
167 | "\n",
168 | "### Build auditory model loss function (specify recognition networks\n",
169 | "### to include in the deep feature loss)\n",
170 | "list_recognition_networks = [\n",
171 | " 'arch1_taskA',\n",
172 | "# 'arch2_taskA',\n",
173 | "# 'arch3_taskA',\n",
174 | "]\n",
175 | "auditory_model = util_auditory_model_loss.AuditoryModelLoss(\n",
176 | " list_recognition_networks=list_recognition_networks,\n",
177 | " tensor_wave0=tensor_waveform_clean,\n",
178 | " tensor_wave1=tensor_waveform_denoised)\n",
179 | "\n",
180 | "transform_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='separator')\n",
181 | "transform_saver = tf.train.Saver(var_list=transform_var_list, max_to_keep=0)\n",
182 | "\n",
183 | "### Specify loss function (waveform, cochlear model, or deep features)\n",
184 | "### and build optimizer object + training operation\n",
185 | "loss = auditory_model.loss_cochlear_model # <-- cochlear model loss is lightweight and works well\n",
186 | "# loss = auditory_model.loss_deep_features\n",
187 | "# loss = auditory_model.loss_waveform\n",
188 | "optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)\n",
189 | "train_op = optimizer.minimize(\n",
190 | " loss=loss,\n",
191 | " global_step=None,\n",
192 | " var_list=transform_var_list)\n"
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": 4,
198 | "metadata": {},
199 | "outputs": [
200 | {
201 | "name": "stdout",
202 | "output_type": "stream",
203 | "text": [
204 | "Loading `arch1_taskA` variables from models/recognition_networks/arch1_taskA.ckpt-550000\n",
205 | "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
206 | "Instructions for updating:\n",
207 | "Use standard file APIs to check for files with this prefix.\n",
208 | "INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000\n",
209 | "INFO:tensorflow:Restoring parameters from models/recognition_networks/arch1_taskA.ckpt-550000\n",
210 | "Loss after training step 000000 = 759.71\n",
211 | "Loss after training step 000010 = 594.60\n",
212 | "Loss after training step 000020 = 547.26\n",
213 | "Loss after training step 000030 = 454.63\n",
214 | "Loss after training step 000040 = 478.92\n",
215 | "Loss after training step 000050 = 450.72\n",
216 | "Loss after training step 000060 = 395.44\n",
217 | "Loss after training step 000070 = 427.17\n",
218 | "Loss after training step 000080 = 456.67\n",
219 | "Loss after training step 000090 = 438.67\n",
220 | "Loss after training step 000100 = 431.82\n",
221 | "INFO:tensorflow:new_model.ckpt-100 is not in all_model_checkpoint_paths. Manually adding it.\n"
222 | ]
223 | }
224 | ],
225 | "source": [
226 | "\"\"\"\n",
227 | "Simple training routine\n",
228 | "\n",
229 | "Models in the paper were all trained for 600000 steps\n",
230 | "with batch size 8 and learning rate 10e-4.\n",
231 | "\"\"\"\n",
232 | "\n",
233 | "with tf.Session() as sess:\n",
234 | " sess.run(tf.global_variables_initializer())\n",
235 | " auditory_model.load_auditory_model_vars(sess)\n",
236 | " \n",
237 | " for step in range(101):\n",
238 | " _, step_loss = sess.run([train_op, loss])\n",
239 | " if step % 10 == 0:\n",
240 | " print(\"Loss after training step {:06d} = {:.02f}\".format(step, step_loss.mean()))\n",
241 | " \n",
242 | " transform_saver.save(\n",
243 | " sess,\n",
244 | " save_path='new_model.ckpt',\n",
245 | " global_step=step,\n",
246 | " write_meta_graph=False)\n"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": []
255 | }
256 | ],
257 | "metadata": {
258 | "kernelspec": {
259 | "display_name": "Python 3",
260 | "language": "python",
261 | "name": "python3"
262 | },
263 | "language_info": {
264 | "codemirror_mode": {
265 | "name": "ipython",
266 | "version": 3
267 | },
268 | "file_extension": ".py",
269 | "mimetype": "text/x-python",
270 | "name": "python",
271 | "nbconvert_exporter": "python",
272 | "pygments_lexer": "ipython3",
273 | "version": "3.5.2"
274 | }
275 | },
276 | "nbformat": 4,
277 | "nbformat_minor": 2
278 | }
279 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Mark R. Saddler, Andrew Francl, Jenelle Feather
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Speech Denoising with Auditory Models ([arXiv](https://arxiv.org/abs/2011.10706), [audio examples](https://mcdermottlab.mit.edu/denoising/demo.html))
3 | This is a TensorFlow implementation of our paper, [Speech Denoising with Auditory Models](https://arxiv.org/abs/2011.10706).
4 |
5 | Contact: [Mark Saddler](mailto:msaddler@mit.edu) or [Andrew Francl](mailto:francl@mit.edu)
6 |
7 |
8 | ### Citation
9 | If you use our code for research, please cite our paper:
10 | Mark R. Saddler\*, Andrew Francl\*, Jenelle Feather, Kaizhi Qian, Yang Zhang, Josh H. McDermott (2021). Speech Denoising with Auditory Models. *Proc. Interspeech* 2021, 2681-2685. [arXiv:2011.10706](https://arxiv.org/abs/2011.10706).
11 |
12 |
13 | ### License
14 | The source code is published under the MIT license. See [LICENSE](./LICENSE.md) for details. In general, you can use the code for any purpose with proper attribution. If you do something interesting with the code, we'll be happy to know. Feel free to contact us.
15 |
16 |
17 | ### Requirements
18 | To speed setup and aid reproducibility, we provide a Singularity container. The container holds all of the libraries and dependencies needed to run the code and allows you to work in the same environment as was originally used. Please see the [Singularity Documentation](https://sylabs.io/guides/3.8/user-guide/) for more details. Download the Singularity image: [tensorflow-1.13.0-denoising.simg](https://drive.google.com/file/d/1KFGMJnuX4l1KRQRRVnzbjXE6bjHA7Tjm/view?usp=sharing).
19 |
20 |
21 | ### Trained Models
22 | We provide model checkpoints for all of our trained audio transforms and deep feature recognition networks. Users must download the audio transform checkpoints to evaluate our denoising algorithms on their own audio. Both sets of checkpoints must be downloaded to run our [DEMO Jupyter notebook](./DEMO.ipynb). Download the entire `auditory-model-denoising/models` directory [here](https://drive.google.com/drive/folders/1HmXSCVOKQCq7G62rs9KE_jsvVO0UqclC?usp=sharing):
23 | - Recognition network checkpoints: [auditory-model-denoising/models/recognition_networks](https://drive.google.com/file/d/1v9dKlRCnMP7X9v5IFcg4H0U1bXottgDo/view?usp=sharing)
24 | - Audio transform checkpoints: [auditory-model-denoising/models/audio_transforms](https://drive.google.com/file/d/1L21NqxN-nVSzpY9CtjtH-1zlKPhEQkow/view?usp=sharing)
25 |
26 |
27 | ### Quick Start
28 | We provide a [Jupyter notebook](./DEMO.ipynb) that (1) demos how to run our trained denoising models and (2) provides examples of how to compute the auditory model losses used to train the models. A [second notebook](./DEMO_train_audio_transform.ipynb) demos how to train a new audio transform using the auditory model losses.
29 |
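30 | 
31 | For orientation, here is a minimal sketch (condensed from [DEMO_train_audio_transform.ipynb](./DEMO_train_audio_transform.ipynb)) of how the auditory model losses are constructed for a pair of waveforms. The placeholder shapes below are assumptions (2-second clips at 20 kHz); see the notebooks for the full, tested code.
32 | 
33 | ```python
34 | import tensorflow as tf
35 | import util_auditory_model_loss
36 | 
37 | # Waveform pair to compare (e.g., clean reference vs. denoised output)
38 | tensor_waveform_clean = tf.placeholder(tf.float32, shape=[None, 40000])
39 | tensor_waveform_denoised = tf.placeholder(tf.float32, shape=[None, 40000])
40 | 
41 | # Build the waveform, cochlear model, and deep feature losses
42 | auditory_model = util_auditory_model_loss.AuditoryModelLoss(
43 |     list_recognition_networks=['arch1_taskA'],
44 |     tensor_wave0=tensor_waveform_clean,
45 |     tensor_wave1=tensor_waveform_denoised)
46 | 
47 | with tf.Session() as sess:
48 |     sess.run(tf.global_variables_initializer())
49 |     # Restores the downloaded recognition network checkpoints used by the deep feature loss
50 |     auditory_model.load_auditory_model_vars(sess)
51 |     # auditory_model.loss_waveform, .loss_cochlear_model, and .loss_deep_features
52 |     # can now be evaluated or minimized with an optimizer.
53 | ```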
--------------------------------------------------------------------------------
/Wave-U-Net/InterpolationLayer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def learned_interpolation_layer(input, padding, level):
5 | '''
6 | Implements a trainable upsampling layer that interpolates by a factor of two, from N samples to N*2 - 1.
7 | Interpolation of intermediate feature vectors v_1 and v_2 (of dimensionality F) is performed by
8 | w \cdot v_1 + (1-w) \cdot v_2, where \cdot is point-wise multiplication, and w is an F-dimensional weight vector constrained to [0,1]
9 | :param input: Input features of shape [batch_size, 1, width, F]
10 | :param padding: "valid" or "same"
11 | :param level: Index of the upsampling layer (used to name the weight variable)
12 | :return: Upsampled feature map of shape [batch_size, 1, new_width, F]
13 | '''
14 | assert(padding == "valid" or padding == "same")
15 | features = input.get_shape().as_list()[3]
16 |
17 | # Construct 2FxF weight matrix, where F is the number of feature channels in the feature map.
18 | # Matrix is constrained, made up of two diagonal FxF matrices with diagonal weights w and 1-w. w is constrained to be in [0,1] via a sigmoid.
19 | weights = tf.get_variable("interp_" + str(level), shape=[features], dtype=tf.float32)
20 | weights_scaled = tf.nn.sigmoid(weights) # Constrain weights to [0,1]
21 | counter_weights = 1.0 - weights_scaled # Mirrored weights for the features from the other time step
22 | conv_weights = tf.expand_dims(tf.concat([tf.expand_dims(tf.diag(weights_scaled), axis=0), tf.expand_dims(tf.diag(counter_weights), axis=0)], axis=0), axis=0)
23 | intermediate_vals = tf.nn.conv2d(input, conv_weights, strides=[1,1,1,1], padding=padding.upper())
24 |
25 | intermediate_vals = tf.transpose(intermediate_vals, [2, 0, 1, 3])
26 | out = tf.transpose(input, [2, 0, 1, 3])
27 | num_entries = out.get_shape().as_list()[0]
28 | out = tf.concat([out, intermediate_vals], axis=0)
29 | indices = list()
30 |
31 | # Interleave interpolated features with original ones, starting with the first original one
32 | num_outputs = (2*num_entries - 1) if padding == "valid" else 2*num_entries
33 | for idx in range(num_outputs):
34 | if idx % 2 == 0:
35 | indices.append(idx // 2)
36 | else:
37 | indices.append(num_entries + idx//2)
38 | out = tf.gather(out, indices)
39 | current_layer = tf.transpose(out, [1, 2, 0, 3])
40 | return current_layer
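41 | 
42 | 
43 | if __name__ == "__main__":
44 |     # Illustrative smoke test (an editorial addition, not part of the original Wave-U-Net code):
45 |     # with "valid" padding, a feature map of width N is upsampled to width 2*N - 1.
46 |     import numpy as np
47 |     x = tf.placeholder(tf.float32, shape=[1, 1, 8, 4])  # [batch_size, 1, width, F]
48 |     y = learned_interpolation_layer(x, "valid", level=0)
49 |     with tf.Session() as sess:
50 |         sess.run(tf.global_variables_initializer())
51 |         out = sess.run(y, feed_dict={x: np.random.rand(1, 1, 8, 4).astype(np.float32)})
52 |         print(out.shape)  # expected: (1, 1, 15, 4)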
--------------------------------------------------------------------------------
/Wave-U-Net/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Daniel Stoller
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Wave-U-Net/OutputLayer.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import unet_utils
4 |
5 | def independent_outputs(featuremap, source_names, num_channels, filter_width, padding, activation):
6 | outputs = dict()
7 | for name in source_names:
8 | outputs[name] = tf.layers.conv1d(featuremap, num_channels, filter_width, activation=activation, padding=padding)
9 | return outputs
10 |
11 | def difference_output(input_mix, featuremap, source_names, num_channels, filter_width, padding, activation, training):
12 | outputs = dict()
13 | sum_source = 0
14 | for name in source_names[:-1]:
15 | out = tf.layers.conv1d(featuremap, num_channels, filter_width, activation=activation, padding=padding)
16 | outputs[name] = out
17 | sum_source = sum_source + out
18 |
19 | # Compute last source based on the others
20 | last_source = unet_utils.crop(input_mix, sum_source.get_shape().as_list()) - sum_source
21 | last_source = unet_utils.AudioClip(last_source, training)
22 | outputs[source_names[-1]] = last_source
23 | return outputs
24 |
--------------------------------------------------------------------------------
/Wave-U-Net/README.md:
--------------------------------------------------------------------------------
1 | # Wave-U-Net
2 | Implementation of the [Wave-U-Net](https://arxiv.org/abs/1806.03185) for audio source separation.
3 |
4 | Code initially copied from the original GitHub repository, https://github.com/f90/Wave-U-Net (the architecture was modified for compatibility).
5 |
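6 | A minimal sketch of how the separator defined here can be instantiated directly is shown below. The configuration values are simply the defaults from `UnetAudioSeparator.py`, and the input length is an assumption chosen to be divisible by 2^num_layers so the skip connections line up without cropping; in this repository the network is normally built via `util_audio_transform.build_unet` (see the demo notebooks).
7 | 
8 | ```python
9 | import tensorflow as tf
10 | import UnetAudioSeparator
11 | 
12 | model_config = {
13 |     "num_layers": 12,
14 |     "num_initial_filters": 24,
15 |     "filter_size": 15,
16 |     "merge_filter_size": 5,
17 |     "upsampling": "learned",
18 |     "output_type": "direct",
19 |     "context": False,          # "same" padding: output length equals input length
20 |     "source_names": ["enhancement"],
21 |     "mono_downmix": True,
22 |     "output_activation": "identity",
23 | }
24 | 
25 | separator = UnetAudioSeparator.UnetAudioSeparator(model_config)
26 | waveform = tf.placeholder(tf.float32, shape=[None, 40960, 1])  # [batch, samples, channels]
27 | outputs = separator.get_output(waveform, training=True, reuse=False)
28 | enhanced = outputs["enhancement"]  # denoised waveform estimate
29 | ```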
--------------------------------------------------------------------------------
/Wave-U-Net/UnetAudioSeparator.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | import unet_utils
5 | import InterpolationLayer
6 | import OutputLayer
7 |
8 | class UnetAudioSeparator:
9 | '''
10 | U-Net separator network for singing voice separation.
11 | Uses valid convolutions, so it predicts for the centre part of the input - only certain input and output shapes are therefore possible (see get_padding function)
12 | '''
13 |
14 | def __init__(self, model_config):
15 | '''
16 | Initialize U-net
17 | :param model_config: Dictionary of configuration options (number of layers, filter sizes, etc.; defaults are filled in below)
18 | '''
19 | self.num_layers = model_config.get("num_layers",12)
20 | self.num_initial_filters = model_config.get("num_initial_filters",24)
21 | self.filter_size = model_config.get("filter_size",15)
22 | self.merge_filter_size = model_config.get("merge_filter_size",5)
23 | self.input_filter_size = model_config.get("input_filter_size",15)
24 | self.output_filter_size = model_config.get("output_filter_size",1)
25 | self.upsampling = model_config.get("upsampling", 'learned')
26 | self.output_type = model_config.get("output_type", 'direct')
27 | self.context = model_config.get("context", False)
28 | self.padding = "valid" if model_config.get("context", False) else "same"
29 | self.source_names = model_config.get("source_names", ['enhancement'])
30 | self.num_channels = 1 if model_config.get("mono_downmix", True) else 2
31 | self.output_activation = model_config.get("output_activation", 'identity')
32 |
33 | def get_padding(self, shape):
34 | '''
35 | Calculates the required amounts of padding along each axis of the input and output, so that the Unet works and has the given shape as output shape
36 | :param shape: Desired output shape
37 | :return: Input_shape, output_shape, where each is a list [batch_size, time_steps, channels]
38 | '''
39 |
40 | if self.context:
41 | # Check if desired shape is possible as output shape - go from output shape towards lowest-res feature map
42 | rem = float(shape[1]) # Cut off batch size number and channel
43 |
44 | # Output filter size
45 | rem = rem - self.output_filter_size + 1
46 |
47 | # Upsampling blocks
48 | for i in range(self.num_layers):
49 | rem = rem + self.merge_filter_size - 1
50 | rem = (rem + 1.) / 2.  # out = 2*in - 1 <=> in = (out + 1)/2
51 |
52 | # Round resulting feature map dimensions up to nearest integer
53 | x = np.asarray(np.ceil(rem),dtype=np.int64)
54 | assert(x >= 2)
55 |
56 | # Compute input and output shapes based on lowest-res feature map
57 | output_shape = x
58 | input_shape = x
59 |
60 | # Extra conv
61 | input_shape = input_shape + self.filter_size - 1
62 |
63 | # Go from centre feature map through up- and downsampling blocks
64 | for i in range(self.num_layers):
65 | output_shape = 2*output_shape - 1 #Upsampling
66 | output_shape = output_shape - self.merge_filter_size + 1 # Conv
67 |
68 | input_shape = 2*input_shape - 1 # Decimation
69 | if i < self.num_layers - 1:
70 | input_shape = input_shape + self.filter_size - 1 # Conv
71 | else:
72 | input_shape = input_shape + self.input_filter_size - 1
73 |
74 | # Output filters
75 | output_shape = output_shape - self.output_filter_size + 1
76 |
77 | input_shape = np.concatenate([[shape[0]], [input_shape], [self.num_channels]])
78 | output_shape = np.concatenate([[shape[0]], [output_shape], [self.num_channels]])
79 |
80 | return input_shape, output_shape
81 | else:
82 | return [shape[0], shape[1], self.num_channels], [shape[0], shape[1], self.num_channels]
83 |
84 | def get_output(self, input, training, return_spectrogram=False, reuse=True):
85 | '''
86 | Creates symbolic computation graph of the U-Net for a given input batch
87 | :param input: Input batch of mixtures, 3D tensor [batch_size, num_samples, num_channels]
88 | :param reuse: Whether to create new parameter variables or reuse existing ones
89 | :return: U-Net output: List of source estimates. Each item is a 3D tensor [batch_size, num_out_samples, num_channels]
90 | '''
91 | with tf.variable_scope("separator", reuse=reuse):
92 | enc_outputs = list()
93 | current_layer = input
94 |
95 | # Down-convolution: Repeat strided conv
96 | for i in range(self.num_layers):
97 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * i), self.filter_size, strides=1, activation=unet_utils.LeakyReLU, padding=self.padding) # out = in - filter + 1
98 | enc_outputs.append(current_layer)
99 | current_layer = current_layer[:,::2,:] # Decimate by factor of 2 # out = (in-1)/2 + 1
100 |
101 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * self.num_layers),self.filter_size,activation=unet_utils.LeakyReLU,padding=self.padding) # One more conv here since we need to compute features after last decimation
102 |
103 | # Feature map here shall be X along one dimension
104 |
105 | # Upconvolution
106 | for i in range(self.num_layers):
107 | #UPSAMPLING
108 | current_layer = tf.expand_dims(current_layer, axis=1)
109 | if self.upsampling == 'learned':
110 | # Learned interpolation between two neighbouring time positions by using a convolution filter of width 2, and inserting the responses in the middle of the two respective inputs
111 | current_layer = InterpolationLayer.learned_interpolation_layer(current_layer, self.padding, i)
112 | else:
113 | if self.context:
114 | current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2] * 2 - 1], align_corners=True)
115 | else:
116 | current_layer = tf.image.resize_bilinear(current_layer, [1, current_layer.get_shape().as_list()[2]*2]) # out = in + in - 1
117 | current_layer = tf.squeeze(current_layer, axis=1)
118 | # UPSAMPLING FINISHED
119 | assert(enc_outputs[-i-1].get_shape().as_list()[1] == current_layer.get_shape().as_list()[1] or self.context) #No cropping should be necessary unless we are using context
120 | current_layer = unet_utils.crop_and_concat(enc_outputs[-i-1], current_layer, match_feature_dim=False)
121 | current_layer = tf.layers.conv1d(current_layer, self.num_initial_filters + (self.num_initial_filters * (self.num_layers - i - 1)), self.merge_filter_size,
122 | activation=unet_utils.LeakyReLU,
123 | padding=self.padding) # out = in - filter + 1
124 |
125 | current_layer = unet_utils.crop_and_concat(input, current_layer, match_feature_dim=False)
126 |
127 | # Output layer
128 | # Determine output activation function
129 | if self.output_activation == "identity":
130 | out_activation = tf.identity
131 | elif self.output_activation == "tanh":
132 | out_activation = tf.tanh
133 | elif self.output_activation == "linear":
134 | out_activation = lambda x: unet_utils.AudioClip(x, training)
135 | else:
136 | raise NotImplementedError
137 |
138 | if self.output_type == "direct":
139 | return OutputLayer.independent_outputs(current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation)
140 | elif self.output_type == "difference":
141 | cropped_input = unet_utils.crop(input,current_layer.get_shape().as_list(), match_feature_dim=False)
142 | return OutputLayer.difference_output(cropped_input, current_layer, self.source_names, self.num_channels, self.output_filter_size, self.padding, out_activation, training)
143 | else:
144 | raise NotImplementedError
145 |
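146 | 
147 | 
148 | if __name__ == "__main__":
149 |     # Illustrative check (an editorial addition, not part of the original Wave-U-Net code):
150 |     # with context=True the network uses valid convolutions, so only certain shapes are
151 |     # possible, and get_padding reports the input length required to obtain (at least)
152 |     # the requested output length.
153 |     separator = UnetAudioSeparator({"context": True})
154 |     input_shape, output_shape = separator.get_padding([1, 16384, 1])
155 |     print("input shape:", input_shape, "-> output shape:", output_shape)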
--------------------------------------------------------------------------------
/Wave-U-Net/unet_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import librosa
4 |
5 | def getTrainableVariables(tag=""):
6 | return [v for v in tf.trainable_variables() if tag in v.name]
7 |
8 | def getNumParams(tensors):
9 | return np.sum([np.prod(t.get_shape().as_list()) for t in tensors])
10 |
11 | def crop_and_concat(x1,x2, match_feature_dim=True):
12 | '''
13 | Copy-and-crop operation for two feature maps of different size.
14 | Crops the first input x1 equally along its borders so that its shape is equal to
15 | the shape of the second input x2, then concatenates them along the feature channel axis.
16 | :param x1: First input that is cropped and combined with the second input
17 | :param x2: Second input
18 | :return: Combined feature map
19 | '''
20 | if x2 is None:
21 | return x1
22 |
23 | x1 = crop(x1,x2.get_shape().as_list(), match_feature_dim)
24 | return tf.concat([x1, x2], axis=2)
25 |
26 | def random_amplify(sample):
27 | '''
28 | Randomly amplifies or attenuates the input signal
29 | :return: Amplified signal
30 | '''
31 | for key, val in sample.items():
32 | if key != "mix":
33 | sample[key] = tf.random_uniform([], 0.7, 1.0) * val
34 |
35 | sample["mix"] = tf.add_n([val for key, val in sample.items() if key != "mix"])
36 | return sample
37 |
38 | def crop_sample(sample, crop_frames):
39 | for key, val in sample.items():
40 | if key != "mix" and crop_frames > 0:
41 | sample[key] = val[crop_frames:-crop_frames,:]
42 | return sample
43 |
44 | def pad_freqs(tensor, target_shape):
45 | '''
46 | Pads the frequency axis of a 4D tensor of shape [batch_size, freqs, timeframes, channels] or 2D tensor [freqs, timeframes] with zeros
47 | so that it reaches the target shape. If the number of frequencies to pad is uneven, the rows are appended at the end.
48 | :param tensor: Input tensor to pad with zeros along the frequency axis
49 | :param target_shape: Shape of tensor after zero-padding
50 | :return: Padded tensor
51 | '''
52 | target_freqs = (target_shape[1] if len(target_shape) == 4 else target_shape[0]) #TODO
53 | if isinstance(tensor, tf.Tensor):
54 | input_shape = tensor.get_shape().as_list()
55 | else:
56 | input_shape = tensor.shape
57 |
58 | if len(input_shape) == 2:
59 | input_freqs = input_shape[0]
60 | else:
61 | input_freqs = input_shape[1]
62 |
63 | diff = target_freqs - input_freqs
64 | if diff % 2 == 0:
65 | pad = [(diff//2, diff//2)]  # Integer division (diff is even here)
66 | else:
67 | pad = [(diff//2, diff//2 + 1)] # Add extra frequency bin at the end
68 |
69 | if len(target_shape) == 2:
70 | pad = pad + [(0,0)]
71 | else:
72 | pad = [(0,0)] + pad + [(0,0), (0,0)]
73 |
74 | if isinstance(tensor, tf.Tensor):
75 | return tf.pad(tensor, pad, mode='constant', constant_values=0.0)
76 | else:
77 | return np.pad(tensor, pad, mode='constant', constant_values=0.0)
78 |
79 | def LeakyReLU(x, alpha=0.2):
80 | return tf.maximum(alpha*x, x)
81 |
82 | def AudioClip(x, training):
83 | '''
84 | Simply returns the input if training is set to True, otherwise clips the input to [-1,1]
85 | :param x: Input tensor (coming from last layer of neural network)
86 | :param training: Whether model is in training (True) or testing mode (False)
87 | :return: Output tensor (potentially clipped)
88 | '''
89 | if training:
90 | return x
91 | else:
92 | return tf.maximum(tf.minimum(x, 1.0), -1.0)
93 |
94 | def resample(audio, orig_sr, new_sr):
95 | return librosa.resample(audio.T, orig_sr, new_sr).T
96 |
97 | def load(path, sr=22050, mono=True, offset=0.0, duration=None, dtype=np.float32):
98 | # ALWAYS output (n_frames, n_channels) audio
99 | y, orig_sr = librosa.load(path, sr, mono, offset, duration, dtype)
100 | if len(y.shape) == 1:
101 | y = np.expand_dims(y, axis=0)
102 | return y.T, orig_sr
103 |
104 | def crop(tensor, target_shape, match_feature_dim=True):
105 | '''
106 | Crops a 3D tensor [batch_size, width, channels] along the width axis to a target shape.
107 | Performs a centre crop. If the size difference is uneven, the extra entry is cropped from the end.
108 | :param tensor: 3D tensor [batch_size, width, channels] that should be cropped.
109 | :param target_shape: Target shape (3D) that the tensor should be cropped to
110 | :return: Cropped tensor
111 | '''
112 | shape = np.array(tensor.get_shape().as_list())
113 | # Handles case when batch size is not defined during graph construction
114 | if shape[0] is None and target_shape[0] is None:
115 | shape[0] = -1
116 | target_shape[0] = -1
117 | diff = shape - np.array(target_shape)
118 | assert(diff[0] == 0 and (diff[2] == 0 or not match_feature_dim))# Only width axis can differ
119 | if (diff[1] % 2 != 0):
120 | print("WARNING: Cropping with uneven number of extra entries on one side")
121 | assert diff[1] >= 0 # Only positive difference allowed
122 | if diff[1] == 0:
123 | return tensor
124 | crop_start = diff // 2
125 | crop_end = diff - crop_start
126 |
127 | return tensor[:,crop_start[1]:-crop_end[1],:]
128 |
129 | def spectrogramToAudioFile(magnitude, fftWindowSize, hopSize, phaseIterations=10, phase=None, length=None):
130 | '''
131 | Computes an audio signal from the given magnitude spectrogram, and optionally an initial phase.
132 | Griffin-Lim is executed to recover/refine the phase from the magnitude spectrogram.
133 | :param magnitude: Magnitudes to be converted to audio
134 | :param fftWindowSize: Size of FFT window used to create magnitudes
135 | :param hopSize: Hop size in frames used to create magnitudes
136 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase
137 | :param phase: If given, starts ISTFT with this particular phase matrix
138 | :param length: If given, audio signal is clipped/padded to this number of frames
139 | :return:
140 | '''
141 | if phase is not None:
142 | if phaseIterations > 0:
143 | # Refine audio given initial phase with a number of iterations
144 | return reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations, phase, length)
145 | # reconstructing the new complex matrix
146 | stftMatrix = magnitude * np.exp(phase * 1j) # magnitude * e^(j*phase)
147 | audio = librosa.istft(stftMatrix, hop_length=hopSize, length=length)
148 | else:
149 | audio = reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations)
150 | return audio
151 |
152 | def reconPhase(magnitude, fftWindowSize, hopSize, phaseIterations=10, initPhase=None, length=None):
153 | '''
154 | Griffin-Lim algorithm for reconstructing the phase for a given magnitude spectrogram, optionally with a given
155 | initial phase.
156 | :param magnitude: Magnitudes to be converted to audio
157 | :param fftWindowSize: Size of FFT window used to create magnitudes
158 | :param hopSize: Hop size in frames used to create magnitudes
159 | :param phaseIterations: Number of Griffin-Lim iterations to recover phase
160 | :param initPhase: If given, starts reconstruction with this particular phase matrix
161 | :param length: If given, audio signal is clipped/padded to this number of frames
162 | :return:
163 | '''
164 | for i in range(phaseIterations):
165 | if i == 0:
166 | if initPhase is None:
167 | reconstruction = np.random.random_sample(magnitude.shape) + 1j * (2 * np.pi * np.random.random_sample(magnitude.shape) - np.pi)
168 | else:
169 | reconstruction = np.exp(initPhase * 1j) # e^(j*phase), so that angle => phase
170 | else:
171 | reconstruction = librosa.stft(audio, fftWindowSize, hopSize)
172 | spectrum = magnitude * np.exp(1j * np.angle(reconstruction))
173 | if i == phaseIterations - 1:
174 | audio = librosa.istft(spectrum, hopSize, length=length)
175 | else:
176 | audio = librosa.istft(spectrum, hopSize)
177 | return audio
178 |
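179 | 
180 | if __name__ == "__main__":
181 |     # Illustrative smoke test (an editorial addition, not part of the original Wave-U-Net code):
182 |     # run a couple of Griffin-Lim iterations on the magnitude spectrogram of a random signal
183 |     # and confirm that a time-domain signal of the expected length comes back.
184 |     fft_window_size, hop_size, n_samples = 512, 128, 4096
185 |     magnitude = np.abs(librosa.stft(np.random.randn(n_samples), fft_window_size, hop_size))
186 |     audio = reconPhase(magnitude, fft_window_size, hop_size, phaseIterations=2)
187 |     print(audio.shape)  # expected: (4096,) with these settings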
--------------------------------------------------------------------------------
/audio/ex01_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex01_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex02_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex02_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex03_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex03_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex04_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex04_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex05_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex05_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex06_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex06_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex07_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex07_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex08_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex08_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex09_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex09_unprocessed_input.wav
--------------------------------------------------------------------------------
/audio/ex10_unprocessed_input.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/audio/ex10_unprocessed_input.wav
--------------------------------------------------------------------------------
/data/toy_dataset_0000.tfrecords:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/data/toy_dataset_0000.tfrecords
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/.gitkeep
--------------------------------------------------------------------------------
/models/audio_transforms/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/audio_transforms/.gitkeep
--------------------------------------------------------------------------------
/models/recognition_networks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/msaddler/auditory-model-denoising/ea35806b60848b4848cb9761afa5906f32f90cb5/models/recognition_networks/.gitkeep
--------------------------------------------------------------------------------
/models/recognition_networks/arch1.json:
--------------------------------------------------------------------------------
1 | [{"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_data_layer"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_0", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 32, "kernel_size": [2, 42], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_0"}}, {"layer_type": "hpool", "args": {"name": "pool_0", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_0"}},{"layer_type": "tf.layers.conv2d", "args": {"name": "conv_1", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 64, "kernel_size": [2, 18], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_1"}}, {"layer_type": "hpool", "args": {"name": "pool_1", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_1"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_2", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 128, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_2"}}, {"layer_type": "hpool", "args": {"name": "pool_2", "padding": "VALID_TIME", "strides": [1, 4], "pool_size": [1, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_2"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_3", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 256, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_3"}}, {"layer_type": "hpool", "args": {"name": "pool_3", "padding": "VALID_TIME", "strides": [1, 4], "pool_size": [1, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_3"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_4", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [8, 8], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_4"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_4"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_5", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [6, 6], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_5"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_5"}}, {"layer_type": "tf.layers.conv2d", "args": {"name": "conv_6", "dilation_rate": [1, 1], "strides": [1, 1], "filters": 512, "kernel_size": [8, 8], "activation": null, "padding": "VALID_TIME"}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_6"}}, {"layer_type": "hpool", "args": {"name": "pool_4", "padding": "VALID_TIME", "strides": [2, 4], "pool_size": [8, 16], "sqrt_window":true}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_6"}}, {"layer_type": "tf.layers.flatten", "args": {"name": "flatten_end_conv"}}, {"layer_type": "tf.layers.dense", "args": {"name": "fc_intermediate", "activation": null, "units": 512}}, {"layer_type": "tf.nn.relu", "args": {"name": "relu_fc_intermediate"}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_fc_intermediate"}}, {"layer_type": "tf.layers.dropout", "args": {"name": "dropout", "rate": 0.5}}, 
{"layer_type": "fc_top_classification", "args": {"activation": null, "name": "fc_top"}}]
2 |
--------------------------------------------------------------------------------
/models/recognition_networks/arch2.json:
--------------------------------------------------------------------------------
1 | [{"args": {"strides": [1, 1], "kernel_size": [2, 32], "filters": 32, "dilation_rate": [1, 1], "name": "conv_0", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_0"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [12, 64], "sqrt_window": false, "strides": [3, 16], "name": "pool_0", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_0"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 16], "filters": 128, "dilation_rate": [1, 1], "name": "conv_1", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_1"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 16], "sqrt_window": false, "strides": [1, 4], "name": "pool_1", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_1"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [1, 16], "filters": 64, "dilation_rate": [1, 1], "name": "conv_2", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_2"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [8, 16], "sqrt_window": false, "strides": [2, 4], "name": "pool_2", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_2"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 8], "filters": 512, "dilation_rate": [1, 1], "name": "conv_3", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_3"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 8], "sqrt_window": false, "strides": [1, 2], "name": "pool_3", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_3"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [3, 8], "filters": 512, "dilation_rate": [1, 1], "name": "conv_4", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_4"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_4", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_4"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 3], "filters": 512, "dilation_rate": [1, 1], "name": "conv_5", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_5"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_5", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_5"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [2, 4], "filters": 512, "dilation_rate": [1, 1], "name": "conv_6", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_6"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 8], "sqrt_window": false, "strides": [1, 2], "name": "pool_6", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_6"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"strides": [1, 1], "kernel_size": [1, 
3], "filters": 512, "dilation_rate": [1, 1], "name": "conv_7", "padding": "VALID_TIME", "activation": null}, "layer_type": "tf.layers.conv2d"}, {"args": {"name": "lrelu_7"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"normalize": true, "pool_size": [1, 1], "sqrt_window": false, "strides": [1, 1], "name": "pool_7", "padding": "VALID_TIME"}, "layer_type": "hpool"}, {"args": {"name": "batch_norm_7"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"name": "flatten_end_conv"}, "layer_type": "tf.layers.flatten"}, {"args": {"name": "fc_intermediate", "activation": null, "units": 1024}, "layer_type": "tf.layers.dense"}, {"args": {"name": "lrelu_fc_intermediate"}, "layer_type": "tf.nn.leaky_relu"}, {"args": {"name": "batch_norm_fc_intermediate"}, "layer_type": "tf.layers.batch_normalization"}, {"args": {"name": "dropout", "rate": 0.5}, "layer_type": "tf.layers.dropout"}, {"args": {"name": "fc_top", "activation": null}, "layer_type": "fc_top_classification"}]
--------------------------------------------------------------------------------
/models/recognition_networks/arch3.json:
--------------------------------------------------------------------------------
1 | [{"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [3, 64], "padding": "VALID_TIME", "filters": 32, "activation": null, "name": "conv_0", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_0"}}, {"layer_type": "hpool", "args": {"name": "pool_0", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 32], "strides": [1, 8]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_0"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [2, 16], "padding": "VALID_TIME", "filters": 32, "activation": null, "name": "conv_1", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_1"}}, {"layer_type": "hpool", "args": {"name": "pool_1", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 32], "strides": [1, 8]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_1"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 8], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_2", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_2"}}, {"layer_type": "hpool", "args": {"name": "pool_2", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [8, 16], "strides": [2, 4]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_2"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [2, 8], "padding": "VALID_TIME", "filters": 512, "activation": null, "name": "conv_3", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_3"}}, {"layer_type": "hpool", "args": {"name": "pool_3", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_3"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [3, 2], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_4", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_4"}}, {"layer_type": "hpool", "args": {"name": "pool_4", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_4"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 2], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_5", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_5"}}, {"layer_type": "hpool", "args": {"name": "pool_5", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 8], "strides": [1, 2]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_5"}}, {"layer_type": "tf.layers.conv2d", "args": {"dilation_rate": [1, 1], "kernel_size": [1, 4], "padding": "VALID_TIME", "filters": 256, "activation": null, "name": "conv_6", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_6"}}, {"layer_type": "hpool", "args": {"name": "pool_6", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_6"}}, {"layer_type": "tf.layers.conv2d", "args": 
{"dilation_rate": [1, 1], "kernel_size": [1, 3], "padding": "VALID_TIME", "filters": 128, "activation": null, "name": "conv_7", "strides": [1, 1]}}, {"layer_type": "tf.nn.leaky_relu", "args": {"name": "lrelu_7"}}, {"layer_type": "hpool", "args": {"name": "pool_7", "normalize": true, "padding": "VALID_TIME", "sqrt_window": false, "pool_size": [1, 1], "strides": [1, 1]}}, {"layer_type": "tf.layers.batch_normalization", "args": {"name": "batch_norm_7"}}, {"layer_type": "tf.layers.flatten", "args": {"name": "flatten_end_conv"}}, {"layer_type": "tf.layers.dropout", "args": {"rate": 0.5, "name": "dropout"}}, {"layer_type": "fc_top_classification", "args": {"name": "fc_top", "activation": null}}]
--------------------------------------------------------------------------------
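The recognition network architecture files are flat JSON lists of layer specifications ("layer_type" plus the keyword arguments passed to that op); util_auditory_model_loss.py (below) loads them and hands the list to util_recognition_network.build_network. A small sketch of loading and listing the layers of one such file, assuming arch1.json follows the same format as the file above:

import json

# Hypothetical inspection of one architecture file; each entry specifies the op
# ("layer_type") and the keyword arguments passed to it ("args").
with open('models/recognition_networks/arch1.json', 'r') as f:
    list_layer_dict = json.load(f)

for layer in list_layer_dict:
    print(layer['layer_type'], layer['args'].get('name'))
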
/models/recognition_networks/deep_feature_loss_weights.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch1_taskA": {
3 | "batch_norm_0": 7.721951988869115e-07,
4 | "batch_norm_1": 2.180155442434395e-06,
5 | "batch_norm_2": 4.35400902509891e-06,
6 | "batch_norm_3": 9.663577649529144e-06,
7 | "batch_norm_4": 7.037957926197431e-06,
8 | "batch_norm_5": 8.04823966794683e-06,
9 | "batch_norm_6": 9.601534802363368e-05
10 | },
11 | "arch1_taskR": {
12 | "batch_norm_0": 4.5313309082274834e-06,
13 | "batch_norm_1": 2.2921854969402775e-06,
14 | "batch_norm_2": 9.444441099113525e-07,
15 | "batch_norm_3": 4.3454074465209316e-07,
16 | "batch_norm_4": 4.744116803227874e-07,
17 | "batch_norm_5": 7.252991823836001e-07,
18 | "batch_norm_6": 7.048356282228414e-07
19 | },
20 | "arch1_taskW": {
21 | "batch_norm_0": 8.148634005683363e-07,
22 | "batch_norm_1": 2.655124304882748e-06,
23 | "batch_norm_2": 5.720402614649923e-06,
24 | "batch_norm_3": 1.7785463446355866e-05,
25 | "batch_norm_4": 1.1564212308251202e-05,
26 | "batch_norm_5": 1.3680587713195843e-05,
27 | "batch_norm_6": 0.00012747735009257673
28 | },
29 | "arch2_taskA": {
30 | "batch_norm_0": 1.2787083010970974e-05,
31 | "batch_norm_1": 3.872243390175745e-06,
32 | "batch_norm_2": 4.0916914026823114e-05,
33 | "batch_norm_3": 1.0983149329860448e-05,
34 | "batch_norm_4": 1.6411619372836008e-05,
35 | "batch_norm_5": 1.9605537700159232e-05,
36 | "batch_norm_6": 8.89673192136142e-05,
37 | "batch_norm_7": 0.00015303936142223106
38 | },
39 | "arch2_taskR": {
40 | "batch_norm_0": 0.0009033044077854935,
41 | "batch_norm_1": 0.001985916145766335,
42 | "batch_norm_2": 0.04740769988747281,
43 | "batch_norm_3": 0.049138172254876385,
44 | "batch_norm_4": 0.09488066449971078,
45 | "batch_norm_5": 0.14831164137971978,
46 | "batch_norm_6": 0.8964458129175272,
47 | "batch_norm_7": 2.0977237380635794
48 | },
49 | "arch2_taskW": {
50 | "batch_norm_0": 1.7039983238137272e-05,
51 | "batch_norm_1": 4.049126793732267e-06,
52 | "batch_norm_2": 5.5138894571540054e-05,
53 | "batch_norm_3": 1.5994307965068624e-05,
54 | "batch_norm_4": 2.3465571033206468e-05,
55 | "batch_norm_5": 2.6932619594160208e-05,
56 | "batch_norm_6": 9.091449999742952e-05,
57 | "batch_norm_7": 0.0001247474154084437
58 | },
59 | "arch3_taskA": {
60 | "batch_norm_0": 9.367178190155468e-07,
61 | "batch_norm_1": 4.047365064236076e-06,
62 | "batch_norm_2": 3.221161721690842e-06,
63 | "batch_norm_3": 1.6302867318651998e-06,
64 | "batch_norm_4": 3.209753982180463e-06,
65 | "batch_norm_5": 7.648978449406465e-06,
66 | "batch_norm_6": 9.518763618955694e-06,
67 | "batch_norm_7": 0.0001503918748524232
68 | },
69 | "arch3_taskR": {
70 | "batch_norm_0": 8.944401019583153e-05,
71 | "batch_norm_1": 0.0015361094702168968,
72 | "batch_norm_2": 0.005594106590352144,
73 | "batch_norm_3": 0.004969826705029876,
74 | "batch_norm_4": 0.012617730321076999,
75 | "batch_norm_5": 0.04534875913067663,
76 | "batch_norm_6": 0.0649332731409872,
77 | "batch_norm_7": 0.19110162724547713
78 | },
79 | "arch3_taskW": {
80 | "batch_norm_0": 8.376852943496161e-07,
81 | "batch_norm_1": 4.130054862512673e-06,
82 | "batch_norm_2": 4.30953902070214e-06,
83 | "batch_norm_3": 2.510501905339445e-06,
84 | "batch_norm_4": 5.482680435228746e-06,
85 | "batch_norm_5": 1.2116750065352435e-05,
86 | "batch_norm_6": 1.3421245766398092e-05,
87 | "batch_norm_7": 0.0001430609287507201
88 | }
89 | }
--------------------------------------------------------------------------------
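These per-layer weights scale the L1 distances between recognition network activations when AuditoryModelLoss (util_auditory_model_loss.py, below) sums them into a deep feature loss. A minimal NumPy sketch of that weighting, using random arrays as stand-ins for the layer activations:

import json

import numpy as np

with open('models/recognition_networks/deep_feature_loss_weights.json', 'r') as f:
    layer_weights = json.load(f)['arch1_taskW']

# Hypothetical activations for two sounds at each weighted layer (batch size 2).
features0 = {k: np.random.randn(2, 10, 10, 4) for k in layer_weights}
features1 = {k: np.random.randn(2, 10, 10, 4) for k in layer_weights}

# Weighted sum of per-layer L1 distances (batch axis preserved), mirroring the
# loop over `layer_weights` in AuditoryModelLoss.build_auditory_model.
loss_deep_features = np.zeros(2)
for layer_key, weight in sorted(layer_weights.items()):
    l1 = np.sum(np.abs(features0[layer_key] - features1[layer_key]), axis=(1, 2, 3))
    loss_deep_features += weight * l1
print(loss_deep_features.shape)  # (2,)
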
/util_audio_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 |
5 |
6 | def tf_demean(x, axis=1):
7 | '''
8 | Helper function to mean-subtract tensor.
9 |
10 | Args
11 | ----
12 | x (tensor): tensor to be mean-subtracted
13 | axis (int): kwarg for tf.reduce_mean (axis along which to compute mean)
14 |
15 | Returns
16 | -------
17 | x_demean (tensor): mean-subtracted tensor
18 | '''
19 | x_demean = tf.math.subtract(x, tf.reduce_mean(x, axis=axis, keepdims=True))
20 | return x_demean
21 |
22 |
23 | def tf_rms(x, axis=1, keepdims=True):
24 | '''
25 | Helper function to compute RMS amplitude of a tensor.
26 |
27 | Args
28 | ----
29 | x (tensor): tensor for which RMS amplitude should be computed
30 | axis (int): kwarg for tf.reduce_mean (axis along which to compute mean)
31 | keepdims (bool): kwarg for tf.reduce_mean (specify if mean should keep collapsed dimension)
32 |
33 | Returns
34 | -------
35 | rms_x (tensor): root-mean-square amplitude of x
36 | '''
37 | rms_x = tf.sqrt(tf.reduce_mean(tf.math.square(x), axis=axis, keepdims=keepdims))
38 | return rms_x
39 |
40 |
41 | def tf_set_snr(signal, noise, snr):
42 | '''
43 | Helper function to combine signal and noise tensors with specified SNR.
44 |
45 | Args
46 | ----
47 | signal (tensor): signal tensor
48 | noise (tensor): noise tensor
49 | snr (tensor): desired signal-to-noise ratio in dB
50 |
51 | Returns
52 | -------
53 | signal_in_noise (tensor): equal to signal + noise_scaled
54 | signal (tensor): mean-subtracted version of input signal tensor
55 | noise_scaled (tensor): mean-subtracted and scaled version of input noise tensor
56 |
57 | Raises
58 | ------
59 | InvalidArgumentError: Raised when rms(signal) == 0 or rms(noise) == 0.
60 | Occurs if the noise or signal input is all zeros, which is incompatible with the tf_set_snr implementation.
61 | '''
62 | # Mean-subtract the provided signal and noise
63 | signal = tf_demean(signal, axis=1)
64 | noise = tf_demean(noise, axis=1)
65 | # Compute RMS amplitudes of provided signal and noise
66 | rms_signal = tf_rms(signal, axis=1, keepdims=True)
67 | rms_noise = tf_rms(noise, axis=1, keepdims=True)
68 | # Ensure neither signal nor noise has an RMS amplitude of zero
69 | msg = 'The rms({:s}) == 0. Results from {:s} input values all equal to zero'
70 | tf.debugging.assert_none_equal(rms_signal, tf.zeros_like(rms_signal),
71 | message=msg.format('signal','signal')).mark_used()
72 | tf.debugging.assert_none_equal(rms_noise, tf.zeros_like(rms_noise),
73 | message=msg.format('noise','noise')).mark_used()
74 | # Convert snr from dB to desired ratio of RMS(signal) to RMS(noise)
75 | rms_ratio = tf.math.pow(10.0, snr / 20.0)
76 | # Re-scale RMS of the noise such that signal + noise will have desired SNR
77 | noise_scale_factor = tf.math.divide(rms_signal, tf.math.multiply(rms_noise, rms_ratio))
78 | noise_scaled = tf.math.multiply(noise_scale_factor, noise)
79 | signal_in_noise = tf.math.add(signal, noise_scaled)
80 | return signal_in_noise, signal, noise_scaled
81 |
82 |
83 | def tf_set_dbspl(x, dbspl):
84 | '''
85 | Helper function to scale tensor to a specified sound pressure level
86 | in dB re 20e-6 Pa (dB SPL).
87 |
88 | Args
89 | ----
90 | x (tensor): tensor to be scaled
91 | dbspl (tensor): desired sound pressure level in dB re 20e-6 Pa
92 |
93 | Returns
94 | -------
95 | x (tensor): mean-subtracted and scaled tensor
96 | scale_factor (tensor): constant x is multiplied by to set dB SPL
97 | '''
98 | x = tf_demean(x, axis=1)
99 | rms_new = 20e-6 * tf.math.pow(10.0, dbspl / 20.0)
100 | rms_old = tf_rms(x, axis=1, keepdims=True)
101 | scale_factor = rms_new / rms_old
102 | x = tf.math.multiply(scale_factor, x)
103 | return x, scale_factor
104 |
--------------------------------------------------------------------------------
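A minimal usage sketch of the helpers above, assuming TensorFlow 1.x graph mode (as used throughout this repository) and 40000-sample waveforms (2 s at the 20 kHz rate used elsewhere in the repository): mix a hypothetical signal and noise batch at +3 dB SNR, then scale the mixture to 60 dB SPL.

import numpy as np
import tensorflow as tf

import util_audio_preprocess

signal = tf.placeholder(tf.float32, [None, 40000])
noise = tf.placeholder(tf.float32, [None, 40000])
# Combine signal and noise at +3 dB SNR, then set the mixture to 60 dB SPL.
mixture, _, _ = util_audio_preprocess.tf_set_snr(signal, noise, snr=3.0)
mixture_set_dbspl, _ = util_audio_preprocess.tf_set_dbspl(mixture, dbspl=60.0)

with tf.Session() as sess:
    out = sess.run(mixture_set_dbspl, feed_dict={
        signal: np.random.randn(2, 40000),
        noise: np.random.randn(2, 40000),
    })
    print(out.shape)  # (2, 40000)
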
/util_audio_transform.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 |
5 | sys.path.append('Wave-U-Net')
6 | import UnetAudioSeparator
7 |
8 |
9 | def build_unet(tensor_waveform, signal_rate=20000, UNET_PARAMS={}):
10 | '''
11 | This function builds the tensorflow graph for the Wave-U-Net.
12 |
13 | Args
14 | ----
15 | tensor_waveform (tensor): input audio waveform (shape: [batch, time])
16 | signal_rate (int): sampling rate of input waveform (Hz)
17 | UNET_PARAMS (dict): U-net configuration parameters
18 |
19 | Returns
20 | -------
21 | tensor_waveform_unet (tensor): U-net transformed waveform (shape: [batch, time])
22 | '''
23 | padding = UNET_PARAMS.get('padding', [[0,0], [480,480]])
24 | tensor_waveform_zero_padded = tf.pad(tensor_waveform, padding)
25 | tensor_waveform_expanded = tf.expand_dims(tensor_waveform_zero_padded,axis=2)
26 | unet_audio_separator = UnetAudioSeparator.UnetAudioSeparator(UNET_PARAMS)
27 | unet_audio_separator_output = unet_audio_separator.get_output(
28 | tensor_waveform_expanded,
29 | training=True,
30 | return_spectrogram=False,
31 | reuse=tf.AUTO_REUSE)
32 | tensor_waveform_unet = unet_audio_separator_output["enhancement"]
33 | tensor_waveform_unet = tensor_waveform_unet[:, padding[1][0]:-padding[1][1],:]
34 | tensor_waveform_unet = tf.squeeze(tensor_waveform_unet, axis=2)
35 | return tensor_waveform_unet
36 |
--------------------------------------------------------------------------------
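A minimal sketch of wiring `build_unet` into a graph, assuming TensorFlow 1.x and that `UNET_PARAMS` is filled with whatever configuration keys Wave-U-Net/UnetAudioSeparator.py expects (the empty dict below is only a placeholder):

import tensorflow as tf

import util_audio_transform

# Hypothetical U-Net configuration; the required keys are defined by
# Wave-U-Net/UnetAudioSeparator.py and are not reproduced here.
UNET_PARAMS = {}

tensor_waveform = tf.placeholder(tf.float32, [None, 40000])
tensor_waveform_unet = util_audio_transform.build_unet(
    tensor_waveform,
    signal_rate=20000,
    UNET_PARAMS=UNET_PARAMS)
# build_unet zero-pads the waveform, takes the Wave-U-Net "enhancement" output,
# and trims the padding, so the result keeps the input [batch, time] shape.
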
/util_auditory_model_loss.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import json
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | import util_recognition_network
9 | import util_cochlear_model
10 |
11 |
12 | class AuditoryModelLoss():
13 | def __init__(self,
14 | dir_recognition_networks='models/recognition_networks',
15 | list_recognition_networks=None,
16 | fn_weights='deep_feature_loss_weights.json',
17 | config_cochlear_model={},
18 | tensor_wave0=None,
19 | tensor_wave1=None):
20 | """
21 | The AuditoryModelLoss class creates an object to measure distances between
22 | auditory model representations of sounds.
23 |
24 | Args
25 | ----
26 | dir_recognition_networks (str): directory containing recognition network architectures
27 | and model checkpoints
28 | list_recognition_networks (list or None): list of recognition network keys specifying
29 | the deep feature loss (if None, use all checkpoints in `dir_recognition_networks`)
30 | fn_weights (str): filename for deep feature loss weights, which weight the contribution
31 | of each recognition network layer to the total deep feature loss
32 | config_cochlear_model (dict): parameters for util_cochlear_model.build_cochlear_model
33 | tensor_wave0 (tensor with shape [batch, 40000]): reference waveform tensor (for training)
34 | tensor_wave1 (tensor with shape [batch, 40000]): denoised waveform tensor (for training)
35 | """
36 | if not os.path.isabs(fn_weights):
37 | fn_weights = os.path.join(dir_recognition_networks, fn_weights)
38 | with open(fn_weights, 'r') as f_weights:
39 | deep_feature_loss_weights = json.load(f_weights)
40 | if list_recognition_networks is None:
41 | print(("`list_recognition_networks` not specified --> "
42 | "searching for all checkpoints in {}".format(dir_recognition_networks)))
43 | list_fn_ckpt = glob.glob(os.path.join(dir_recognition_networks, '*index'))
44 | list_fn_ckpt = [fn_ckpt.replace('.index', '') for fn_ckpt in list_fn_ckpt]
45 | else:
46 | list_fn_ckpt = []
47 | for network_key in list_recognition_networks:
48 | tmp = glob.glob(os.path.join(dir_recognition_networks, '{}*index'.format(network_key)))
49 | msg = "Failed to find exactly 1 checkpoint for recognition network {}".format(network_key)
50 | assert len(tmp) == 1, msg
51 | list_fn_ckpt.append(tmp[0].replace('.index', ''))
52 | print("{} recognition networks included for deep feature loss:".format(len(list_fn_ckpt)))
53 | config_recognition_networks = {}
54 | for fn_ckpt in list_fn_ckpt:
55 | network_key = os.path.basename(fn_ckpt).split('.')[0]
56 | if 'taskA' in network_key:
57 | n_classes_dict = {"task_audioset": 517}
58 | else:
59 | n_classes_dict = {"task_word": 794}
60 | config_recognition_networks[network_key] = {
61 | 'fn_ckpt': fn_ckpt,
62 | 'fn_arch': fn_ckpt[:fn_ckpt.rfind('_task')] + '.json',
63 | 'n_classes_dict': n_classes_dict,
64 | 'weights': deep_feature_loss_weights[network_key],
65 | }
66 | print('|__ {}: {}'.format(network_key, fn_ckpt))
67 | self.config_recognition_networks = config_recognition_networks
68 | self.config_cochlear_model = config_cochlear_model
69 | self.tensor_wave0 = tensor_wave0
70 | self.tensor_wave1 = tensor_wave1
71 | self.build_auditory_model()
72 | self.sess = None
73 | self.vars_loaded = False
74 | return
75 |
76 |
77 | def l1_distance(self, feature0, feature1):
78 | """
79 | Computes L1 distance between two features (preserving axis 0 for batch)
80 | """
81 | axis = np.arange(1, len(feature0.get_shape().as_list()))
82 | return tf.reduce_sum(tf.math.abs(feature0 - feature1), axis=axis)
83 |
84 |
85 | def build_auditory_model(self, dtype=tf.float32):
86 | """
87 | Constructs the full auditory model and losses. This function builds two
88 | waveform placeholders and constructs two identical auditory models operating
89 | on the two placeholders. Waveform, cochlear model, and deep feature losses are
90 | computed by measuring distances between identical stages of the two auditory
91 | models.
92 | """
93 | # Build placeholders for two waveforms (if needed) and compute waveform loss
94 | if (self.tensor_wave0 is None) or (self.tensor_wave1 is None):
95 | self.tensor_wave0 = tf.placeholder(dtype, [None, 40000])
96 | self.tensor_wave1 = tf.placeholder(dtype, [None, 40000])
97 | print('Building waveform loss')
98 | self.loss_waveform = self.l1_distance(self.tensor_wave0, self.tensor_wave1)
99 | # Build cochlear model for each waveform and compute cochlear model loss
100 | print('Building cochlear model loss')
101 | tensor_coch0, _ = util_cochlear_model.build_cochlear_model(
102 | self.tensor_wave0,
103 | **self.config_cochlear_model)
104 | tensor_coch1, _ = util_cochlear_model.build_cochlear_model(
105 | self.tensor_wave1,
106 | **self.config_cochlear_model)
107 | self.loss_cochlear_model = self.l1_distance(tensor_coch0, tensor_coch1)
108 | # Build network(s) for each waveform and compute deep feature losses
109 | self.loss_deep_features_dict = {}
110 | self.loss_deep_features = tf.zeros([], dtype=dtype)
111 | for network_key in sorted(self.config_recognition_networks.keys()):
112 | print('Building deep feature loss (recognition network: {})'.format(network_key))
113 | with open(self.config_recognition_networks[network_key]['fn_arch'], 'r') as f:
114 | list_layer_dict = json.load(f)
115 | # Build network for stimulus 0
116 | with tf.variable_scope(network_key + '0') as scope:
117 | _, tensors_network0 = util_recognition_network.build_network(
118 | tensor_coch0,
119 | list_layer_dict,
120 | n_classes_dict=self.config_recognition_networks[network_key]['n_classes_dict'])
121 | var_list = {
122 | v.name.replace(scope.name + '/', '').replace(':0', ''): v
123 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
124 | }
125 | self.config_recognition_networks[network_key]['saver0'] = tf.train.Saver(
126 | var_list=var_list,
127 | max_to_keep=0)
128 | # Build network for stimulus 1
129 | with tf.variable_scope(network_key + '1') as scope:
130 | _, tensors_network1 = util_recognition_network.build_network(
131 | tensor_coch1,
132 | list_layer_dict,
133 | n_classes_dict=self.config_recognition_networks[network_key]['n_classes_dict'])
134 | var_list = {
135 | v.name.replace(scope.name + '/', '').replace(':0', ''): v
136 | for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope.name)
137 | }
138 | self.config_recognition_networks[network_key]['saver1'] = tf.train.Saver(
139 | var_list=var_list,
140 | max_to_keep=0)
141 | # Compute deep feature losses (weighted sum across layers)
142 | self.loss_deep_features_dict[network_key] = tf.zeros([], dtype=dtype)
143 | layer_weights = self.config_recognition_networks[network_key]['weights']
144 | for layer_key in sorted(layer_weights.keys()):
145 | tmp = self.l1_distance(tensors_network0[layer_key], tensors_network1[layer_key])
146 | self.loss_deep_features_dict[network_key] += layer_weights[layer_key] * tmp
147 | self.loss_deep_features += self.loss_deep_features_dict[network_key]
148 |
149 |
150 | def load_auditory_model_vars(self, sess):
151 | """
152 | Loads the same variables into the two copies of each recognition network
153 | from recognition network checkpoints.
154 | """
155 | self.sess = sess
156 | for network_key in sorted(self.config_recognition_networks.keys()):
157 | fn_ckpt = self.config_recognition_networks[network_key]['fn_ckpt']
158 | saver0 = self.config_recognition_networks[network_key]['saver0']
159 | saver1 = self.config_recognition_networks[network_key]['saver1']
160 | print('Loading `{}` variables from {}'.format(network_key, fn_ckpt))
161 | saver0.restore(self.sess, fn_ckpt)
162 | saver1.restore(self.sess, fn_ckpt)
163 | self.vars_loaded = True
164 |
165 |
166 | def waveform_loss(self, y0, y1):
167 | """
168 | Method to compute waveform loss between two waveforms under
169 | the constructed auditory model.
170 |
171 | Args
172 | ----
173 | y0 (np.ndarray): waveform with shape [batch, 40000]
174 | y1 (np.ndarray): waveform with shape [batch, 40000]
175 |
176 | Returns
177 | -------
178 | (np.ndarray): L1 waveform loss with shape [batch]
179 | """
180 | assert (self.sess is not None) and (not self.sess._closed)
181 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1}
182 | return self.sess.run(self.loss_waveform, feed_dict=feed_dict)
183 |
184 |
185 | def cochlear_model_loss(self, y0, y1):
186 | """
187 | Method to compute cochlear model loss between two waveforms
188 | under the constructed auditory model.
189 |
190 | Args
191 | ----
192 | y0 (np.ndarray): waveform with shape [batch, 40000]
193 | y1 (np.ndarray): waveform with shape [batch, 40000]
194 |
195 | Returns
196 | -------
197 | (np.ndarray): L1 cochlear model loss with shape [batch]
198 | """
199 | assert (self.sess is not None) and (not self.sess._closed)
200 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1}
201 | return self.sess.run(self.loss_cochlear_model, feed_dict=feed_dict)
202 |
203 |
204 | def deep_feature_loss(self, y0, y1):
205 | """
206 | Method to compute deep feature loss between two waveforms
207 | under the constructed auditory model.
208 |
209 | Args
210 | ----
211 | y0 (np.ndarray): waveform with shape [batch, 40000]
212 | y1 (np.ndarray): waveform with shape [batch, 40000]
213 |
214 | Returns
215 | -------
216 | (np.ndarray): L1 deep feature loss with shape [batch]
217 | """
218 | assert (self.sess is not None) and (not self.sess._closed)
219 | if not self.vars_loaded:
220 | print(("WARNING: `deep_feature_loss` called before loading vars"))
221 | feed_dict={self.tensor_wave0: y0, self.tensor_wave1: y1}
222 | return self.sess.run(self.loss_deep_features, feed_dict=feed_dict)
223 |
--------------------------------------------------------------------------------
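A minimal usage sketch of `AuditoryModelLoss` for measuring distances between two sounds, assuming TensorFlow 1.x and that the recognition network checkpoints have been downloaded into models/recognition_networks/ (the `arch1_taskW` key below matches one of the weight sets in deep_feature_loss_weights.json):

import numpy as np
import tensorflow as tf

import util_auditory_model_loss

# Build the auditory model graph; with no waveform tensors supplied,
# internal placeholders of shape [None, 40000] are created.
auditory_model_loss = util_auditory_model_loss.AuditoryModelLoss(
    dir_recognition_networks='models/recognition_networks',
    list_recognition_networks=['arch1_taskW'])

with tf.Session() as sess:
    # Restore the same checkpoint into both copies of each recognition network.
    auditory_model_loss.load_auditory_model_vars(sess)
    y0 = np.random.randn(1, 40000).astype(np.float32)  # reference waveform
    y1 = np.random.randn(1, 40000).astype(np.float32)  # comparison waveform
    print(auditory_model_loss.waveform_loss(y0, y1))        # shape [1]
    print(auditory_model_loss.cochlear_model_loss(y0, y1))  # shape [1]
    print(auditory_model_loss.deep_feature_loss(y0, y1))    # shape [1]
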
/util_cochlear_model.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is closely based on pycochleagram and tfcochleagram,
3 | which have been previously released:
4 |
5 | https://github.com/mcdermottLab/pycochleagram
6 | https://github.com/jenellefeather/tfcochleagram
7 |
8 | Minor modifications have been made here to provide a single script
9 | containing all functions needed to build the cochlear model used
10 | in this project.
11 | """
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 |
17 | import os
18 | import sys
19 | import warnings
20 | import functools
21 | import numpy as np
22 | import tensorflow as tf
23 | import scipy.signal as signal
24 | import matplotlib.pyplot as plt
25 |
26 |
27 | def freq2erb(freq_hz):
28 | """Converts Hz to human-defined ERBs, using the formula of Glasberg and Moore.
29 |
30 | Args:
31 | freq_hz (array_like): frequency to use for ERB.
32 |
33 | Returns:
34 | ndarray: **n_erb** -- Human-defined ERB representation of input.
35 | """
36 | return 9.265 * np.log(1 + freq_hz / (24.7 * 9.265))
37 |
38 |
39 | def erb2freq(n_erb):
40 | """Converts human ERBs to Hz, using the formula of Glasberg and Moore.
41 | Args:
42 | n_erb (array_like): Human-defined ERB to convert to frequency.
43 | Returns:
44 | ndarray: **freq_hz** -- Frequency representation of input.
45 | """
46 | return 24.7 * 9.265 * (np.exp(n_erb / 9.265) - 1)
47 |
48 |
49 | def get_freq_rand_conversions(xp, seed=0, minval=0.0, maxval=1.0):
50 | """Generates freq2rand and rand2freq conversion functions.
51 |
52 | Args:
53 | xp (array_like): xvals for freq2rand linear interpolation.
54 | seed (int): numpy seed to generate yvals for linear interpolation.
55 | minval (float): yvals for linear interpolation are scaled to [minval, maxval].
56 | maxval (float): yvals for linear interpolation are scaled to [minval, maxval].
57 |
58 | Returns:
59 | freq2rand (function): converts Hz to random frequency scale
60 | rand2freq (function): converts random frequency scale to Hz
61 | """
62 | np.random.seed(seed)
63 | yp = np.cumsum(np.random.poisson(size=xp.shape))
64 | yp = ((maxval - minval) * (yp - yp.min())) / (yp.max() - yp.min()) + minval
65 | freq2rand = lambda x : np.interp(x, xp, yp)
66 | rand2freq = lambda y : np.interp(y, yp, xp)
67 | return freq2rand, rand2freq
68 |
69 |
70 | def make_cosine_filter(freqs, l, h, convert_to_erb=True):
71 | """Generate a half-cosine filter. Represents one subband of the cochleagram.
72 | A half-cosine filter is created using the values of freqs that are within the
73 | interval [l, h]. The half-cosine filter is centered at the center of this
74 | interval, i.e., (l + h) / 2. Values outside the valid interval [l, h] are
75 | discarded. So, if freqs = [1, 2, 3, ... 10], l = 4.5, h = 8, the cosine filter
76 | will only be defined on the domain [5, 6, 7] and the returned output will only
77 | contain 3 elements.
78 | Args:
79 | freqs (array_like): Array containing the domain of the filter, in ERB space;
80 | see convert_to_erb parameter below. A single half-cosine
81 | filter will be defined only on the valid section of these values;
82 | specifically, the values between cutoffs ``l`` and ``h``. A half-cosine filter
83 | centered at (l + h) / 2 is created on the interval [l, h].
84 | l (float): The lower cutoff of the half-cosine filter in ERB space; see
85 | convert_to_erb parameter below.
86 | h (float): The upper cutoff of the half-cosine filter in ERB space; see
87 | convert_to_erb parameter below.
88 | convert_to_erb (bool, default=True): If this is True, the values in
89 | input arguments ``freqs``, ``l``, and ``h`` will be transformed from Hz to ERB
90 | space before creating the half-cosine filter. If this is False, the
91 | input arguments are assumed to be in ERB space.
92 | Returns:
93 | ndarray: **half_cos_filter** -- A half-cosine filter defined using elements of
94 | freqs within [l, h].
95 | """
96 | if convert_to_erb:
97 | freqs_erb = freq2erb(freqs)
98 | l_erb = freq2erb(l)
99 | h_erb = freq2erb(h)
100 | else:
101 | freqs_erb = freqs
102 | l_erb = l
103 | h_erb = h
104 |
105 | avg_in_erb = (l_erb + h_erb) / 2 # center of filter
106 | rnge_in_erb = h_erb - l_erb # width of filter
107 | # return np.cos((freq2erb(freqs[a_l_ind:a_h_ind+1]) - avg)/rnge * np.pi) # h_ind+1 to include endpoint
108 | # return np.cos((freqs_erb[(freqs_erb >= l_erb) & (freqs_erb <= h_erb)]- avg_in_erb) / rnge_in_erb * np.pi) # map cutoffs to -pi/2, pi/2 interval
109 | return np.cos((freqs_erb[(freqs_erb > l_erb) & (freqs_erb < h_erb)]- avg_in_erb) / rnge_in_erb * np.pi) # map cutoffs to -pi/2, pi/2 interval
110 |
111 |
112 | def make_full_filter_set(filts, signal_length=None):
113 | """Create the full set of filters by extending the filterbank to negative FFT
114 | frequencies.
115 | Args:
116 | filts (array_like): Array containing the cochlear filterbank in frequency space,
117 | i.e., the output of make_cos_filters_nx. Each row of ``filts`` is a
118 | single filter, with columns indexing frequency.
119 | signal_length (int, optional): Length of the signal to be filtered with this filterbank.
120 | This should be equal to filter length * 2 - 1, i.e., 2*filts.shape[1] - 1, and if
121 | signal_length is None, this value will be computed with the above formula.
122 | This parameter might be deprecated later.
123 |
124 | Returns:
125 | ndarray: **full_filter_set** -- Array containing the complete filterbank in
126 | frequency space. This output can be directly applied to the frequency
127 | representation of a signal.
128 | """
129 | if signal_length is None:
130 | signal_length = 2 * filts.shape[1] - 1
131 |
132 | # note that filters are currently such that each ROW is a filter and COLUMN idxs freq
133 | if np.remainder(signal_length, 2) == 0: # even -- don't take the DC & don't double sample nyquist
134 | neg_filts = np.flipud(filts[1:filts.shape[0] - 1, :])
135 | else: # odd -- don't take the DC
136 | neg_filts = np.flipud(filts[1:filts.shape[0], :])
137 | fft_filts = np.vstack((filts, neg_filts))
138 | # we need to switch representation to apply filters to fft of the signal, not sure why, but do it here
139 | return fft_filts.T
140 |
141 |
142 | def make_cos_filters_nx(signal_length, sr, n, low_lim, hi_lim, sample_factor,
143 | padding_size=None, full_filter=True, strict=True,
144 | bandwidth_scale_factor=1.0, include_lowpass=True,
145 | include_highpass=True, filter_spacing='erb'):
146 | """Create cosine filters, oversampled by a factor provided by "sample_factor"
147 | Args:
148 | signal_length (int): Length of signal to be filtered with the generated
149 | filterbank. The signal length determines the length of the filters.
150 | sr (int): Sampling rate associated with the signal waveform.
151 | n (int): Number of filters (subbands) to be generated with standard
152 | sampling (i.e., using a sampling factor of 1). Note, the actual number of
153 | filters in the generated filterbank depends on the sampling factor, and
154 | may optionally include lowpass and highpass filters that allow for
155 | perfect reconstruction of the input signal (the exact number of lowpass
156 | and highpass filters is determined by the sampling factor). The
157 | number of filters in the generated filterbank is given below:
158 | +---------------+---------------+---+------------+---+---------------------+
159 | | sample factor |     n_out     | = |  bandpass  | + | highpass + lowpass  |
160 | +===============+===============+===+============+===+=====================+
161 | |       1       |      n+2      | = |     n      | + |        1 + 1        |
162 | +---------------+---------------+---+------------+---+---------------------+
163 | |       2       |    2*n+1+4    | = |   2*n+1    | + |        2 + 2        |
164 | +---------------+---------------+---+------------+---+---------------------+
165 | |       4       |    4*n+3+8    | = |   4*n+3    | + |        4 + 4        |
166 | +---------------+---------------+---+------------+---+---------------------+
167 | |       s       | s*(n+1)-1+2*s | = | s*(n+1)-1  | + |        s + s        |
168 | +---------------+---------------+---+------------+---+---------------------+
169 | low_lim (int): Lower limit of frequency range. Filters will not be defined
170 | below this limit.
171 | hi_lim (int): Upper limit of frequency range. Filters will not be defined
172 | above this limit.
173 | sample_factor (int): Positive integer that determines how densely ERB function
174 | will be sampled to create bandpass filters. 1 represents standard sampling;
175 | adjacent bandpass filters will overlap by 50%. 2 represents 2x overcomplete sampling;
176 | adjacent bandpass filters will overlap by 75%. 4 represents 4x overcomplete sampling;
177 | adjacent bandpass filters will overlap by 87.5%.
178 | padding_size (int, optional): If None (default), the signal will not be padded
179 | before filtering. Otherwise, the filters will be created assuming the
180 | waveform signal will be padded to length padding_size*signal_length.
181 | full_filter (bool, default=True): If True (default), the complete filter that
182 | is ready to apply to the signal is returned. If False, only the first
183 | half of the filter is returned (likely positive terms of FFT).
184 | strict (bool, default=True): If True (default), will throw an error if
185 | sample_factor is not a power of two. This facilitates comparison across
186 | sample_factors. Also, if True, will throw an error if provided hi_lim
187 | is greater than the Nyquist rate.
188 | bandwidth_scale_factor (float, default=1.0): scales the bandpass filter bandwidths.
189 | bandwidth_scale_factor=2.0 means half-cosine filters will be twice as wide.
190 | Note that values < 1 will cause frequency gaps between the filters.
191 | bandwidth_scale_factor requires sample_factor=1, include_lowpass=False, include_highpass=False.
192 | include_lowpass (bool, default=True): if set to False, lowpass filter will be discarded.
193 | include_highpass (bool, default=True): if set to False, highpass filter will be discarded.
194 | filter_spacing (str, default='erb'): Specifies the type of reference spacing for the
195 | half-cosine filters. Options include 'erb' and 'linear'.
196 | Returns:
197 | tuple:
198 | A tuple containing the output:
199 | * **filts** (*array*)-- The filterbank consisting of filters have
200 | cosine-shaped frequency responses, with center frequencies equally
201 | spaced from low_lim to hi_lim on a scale specified by filter_spacing
202 | * **center_freqs** (*array*) -- center frequencies of filterbank in filts
203 | * **freqs** (*array*) -- freq vector in Hz, same frequency dimension as filts
204 | Raises:
205 | ValueError: Various value errors for bad choices of sample_factor or frequency
206 | limits; see description for strict parameter.
207 | UserWarning: Raises warning if cochlear filters exceed the Nyquist
208 | limit or go below 0.
209 | NotImplementedError: Raises error if specified filter_spacing is not implemented
210 | """
211 |
212 | # Specify the type of filter spacing ('erb', 'erb_r', 'linear', or 'random')
213 | if filter_spacing == 'erb':
214 | _freq2ref = freq2erb
215 | _ref2freq = erb2freq
216 | elif filter_spacing == 'erb_r':
217 | _freq2ref = lambda x: freq2erb(hi_lim) - freq2erb(hi_lim - x)
218 | _ref2freq = lambda x: hi_lim - erb2freq(freq2erb(hi_lim) - x)
219 | elif (filter_spacing == 'lin') or (filter_spacing == 'linear'):
220 | _freq2ref = lambda x: x
221 | _ref2freq = lambda x: x
222 | elif 'random' in filter_spacing:
223 | _freq2ref, _ref2freq = get_freq_rand_conversions(
224 | np.linspace(low_lim, hi_lim, n),
225 | seed=int(filter_spacing.split('-')[1].replace('seed', '')),
226 | minval=freq2erb(low_lim),
227 | maxval=freq2erb(hi_lim))
228 | else:
229 | raise NotImplementedError('unrecognized spacing mode: %s' % filter_spacing)
230 | print('[make_cos_filters_nx] using filter_spacing=`{}`'.format(filter_spacing))
231 |
232 | if not bandwidth_scale_factor == 1.0:
233 | assert sample_factor == 1, "bandwidth_scale_factor only supports sample_factor=1"
234 | assert include_lowpass == False, "bandwidth_scale_factor only supports include_lowpass=False"
235 | assert include_highpass == False, "bandwidth_scale_factor only supports include_highpass=False"
236 |
237 | if not isinstance(sample_factor, int):
238 | raise ValueError('sample_factor must be an integer, not %s' % type(sample_factor))
239 | if sample_factor <= 0:
240 | raise ValueError('sample_factor must be positive')
241 |
242 | if sample_factor != 1 and np.remainder(sample_factor, 2) != 0:
243 | msg = 'sample_factor odd, and will change filter widths. Use even sample factors for comparison.'
244 | if strict:
245 | raise ValueError(msg)
246 | else:
247 | warnings.warn(msg, RuntimeWarning, stacklevel=2)
248 |
249 | if padding_size is not None and padding_size >= 1:
250 | signal_length += padding_size
251 |
252 | if np.remainder(signal_length, 2) == 0: # even length
253 | n_freqs = signal_length // 2 # .0 does not include DC, likely the sampling grid
254 | max_freq = sr / 2 # go all the way to nyquist
255 | else: # odd length
256 | n_freqs = (signal_length - 1) // 2 # .0
257 | max_freq = sr * (signal_length - 1) / 2 / signal_length # just under nyquist
258 |
259 | # verify the high limit is allowed by the sampling rate
260 | if hi_lim > sr / 2:
261 | hi_lim = max_freq
262 | msg = 'input arg "hi_lim" exceeds nyquist limit for max frequency; ignore with "strict=False"'
263 | if strict:
264 | raise ValueError(msg)
265 | else:
266 | warnings.warn(msg, RuntimeWarning, stacklevel=2)
267 |
268 | # changing the sampling density without changing the filter locations
269 | # (and, thereby changing their widths) requires that a certain number of filters
270 | # be used.
271 | n_filters = sample_factor * (n + 1) - 1
272 | n_lp_hp = 2 * sample_factor
273 | freqs = np.linspace(0, max_freq, n_freqs + 1)
274 | filts = np.zeros((n_freqs + 1, n_filters + n_lp_hp))
275 |
276 | # cutoffs are evenly spaced on the scale specified by filter_spacing; for ERB scale,
277 | # interpolate linearly in erb space then convert back.
278 | # Also return the actual spacing used to generate the sequence (in case numpy does
279 | # something weird)
280 | center_freqs, step_spacing = np.linspace(_freq2ref(low_lim), _freq2ref(hi_lim), n_filters + 2, retstep=True) # +2 for bin endpoints
281 | # we need to exclude the endpoints
282 | center_freqs = center_freqs[1:-1]
283 |
284 | freqs_ref = _freq2ref(freqs)
285 | for i in range(n_filters):
286 | i_offset = i + sample_factor
287 | l = center_freqs[i] - sample_factor * bandwidth_scale_factor * step_spacing
288 | h = center_freqs[i] + sample_factor * bandwidth_scale_factor * step_spacing
289 | if _ref2freq(h) > sr/2:
290 | cf = _ref2freq(center_freqs[i])
291 | msg = "High ERB cutoff of filter with cf={:.2f}Hz exceeds {:.2f}Hz (Nyquist frequency)"
292 | warnings.warn(msg.format(cf, sr/2))
293 | if _ref2freq(l) < 0:
294 | cf = _ref2freq(center_freqs[i])
295 | msg = 'Low ERB cutoff of filter with cf={:.2f}Hz is not strictly positive'
296 | warnings.warn(msg.format(cf))
297 | # the first sample_factor # of rows in filts will be lowpass filters
298 | filts[(freqs_ref > l) & (freqs_ref < h), i_offset] = make_cosine_filter(freqs_ref, l, h, convert_to_erb=False)
299 |
300 | # add lowpass and highpass filters (there will be sample_factor number of each)
301 | for i in range(sample_factor):
302 | # account for the fact that the first sample_factor # of filts are lowpass
303 | i_offset = i + sample_factor
304 | lp_h_ind = max(np.where(freqs < _ref2freq(center_freqs[i]))[0]) # lowpass filter goes up to peak of first cos filter
305 | lp_filt = np.sqrt(1 - np.power(filts[:lp_h_ind+1, i_offset], 2))
306 |
307 | hp_l_ind = min(np.where(freqs > _ref2freq(center_freqs[-1-i]))[0]) # highpass filter goes down to peak of last cos filter
308 | hp_filt = np.sqrt(1 - np.power(filts[hp_l_ind:, -1-i_offset], 2))
309 |
310 | filts[:lp_h_ind+1, i] = lp_filt
311 | filts[hp_l_ind:, -1-i] = hp_filt
312 |
313 | # get center freqs for lowpass and highpass filters
314 | cfs_low = np.copy(center_freqs[:sample_factor]) - sample_factor * step_spacing
315 | cfs_hi = np.copy(center_freqs[-sample_factor:]) + sample_factor * step_spacing
316 | center_freqs = np.concatenate((cfs_low, center_freqs, cfs_hi))
317 |
318 | # ensure that squared freq response adds to one
319 | filts = filts / np.sqrt(sample_factor)
320 |
321 | # convert center freqs from ERB numbers to Hz
322 | center_freqs = _ref2freq(center_freqs)
323 |
324 | # rectify
325 | center_freqs[center_freqs < 0] = 1
326 |
327 | # discard highpass and lowpass filters, if requested
328 | if include_lowpass == False:
329 | filts = filts[:, sample_factor:]
330 | center_freqs = center_freqs[sample_factor:]
331 | if include_highpass == False:
332 | filts = filts[:, :-sample_factor]
333 | center_freqs = center_freqs[:-sample_factor]
334 |
335 | # make the full filter by adding negative components
336 | if full_filter:
337 | filts = make_full_filter_set(filts, signal_length)
338 |
339 | return filts, center_freqs, freqs
340 |
341 |
342 | def tflog10(x):
343 | """Implements log base 10 in tensorflow """
344 | numerator = tf.log(x)
345 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
346 | return numerator / denominator
347 |
348 |
349 | @tf.custom_gradient
350 | def stable_power_compression_norm_grad(x):
351 | """With this power compression function, the gradients from the power compression are not applied via backprop, we just pass the previous gradient onwards"""
352 | e = tf.nn.relu(x) # add relu to x to avoid NaN in loss
353 | p = tf.pow(e,0.3)
354 | def grad(dy): #try to check for nans before we clip the gradients. (use tf.where)
355 | return dy
356 | return p, grad
357 |
358 |
359 | @tf.custom_gradient
360 | def stable_power_compression(x):
361 | """Clip the gradients for the power compression and remove nans. Clipped values are (-1,1), so any cochleagram value below ~0.2 will be clipped."""
362 | e = tf.nn.relu(x) # add relu to x to avoid NaN in loss
363 | p = tf.pow(e,0.3)
364 | def grad(dy): #try to check for nans before we clip the gradients. (use tf.where)
365 | g = 0.3 * pow(e,-0.7)
366 | is_nan_values = tf.is_nan(g)
367 | replace_nan_values = tf.ones(tf.shape(g), dtype=tf.float32)*1
368 | return dy * tf.where(is_nan_values,replace_nan_values,tf.clip_by_value(g, -1, 1))
369 | return p, grad
370 |
371 |
372 | def cochleagram_graph(nets, SIGNAL_SIZE, SR, ENV_SR=200, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, compression='none', WINDOW_SIZE=1001, debug=False, subbands_ifft=False, pycoch_downsamp=False, linear_max=796.87416837456942, input_node='input_signal', mean_subtract=False, rms_normalize=False, SMOOTH_ABS = False, return_subbands_only=False, include_all_keys=False, rectify_and_lowpass_subbands=False, pad_factor=None, return_coch_params=False, rFFT=False, linear_params=None, custom_filts=None, custom_compression_op=None, erb_filter_kwargs={}, reshape_kell2018=False, include_subbands_noise=False, subbands_noise_mean=0., subbands_noise_stddev=0., rate_level_kwargs={}, preprocess_kwargs={}):
373 | """
374 | Creates a tensorflow cochleagram graph using the pycochleagram erb filters to create the cochleagram with the tensorflow functions.
375 | Parameters
376 | ----------
377 | nets : dictionary
378 | dictionary containing parts of the cochleagram graph. At a minimum, nets['input_signal'] (or equivalent) should be defined containing a placeholder (if just constructing cochleagrams) or a variable (if optimizing over the cochleagrams), and can have a batch size>1.
379 | SIGNAL_SIZE : int
380 | the length of the audio signal used for the cochleagram graph
381 | SR : int
382 | raw sampling rate in Hz for the audio.
383 | ENV_SR : int
384 | the sampling rate for the cochleagram after downsampling
385 | LOW_LIM : int
386 | Lower frequency limits for the filters.
387 | HIGH_LIM : int
388 | Higher frequency limits for the filters.
389 | N : int
390 | Number of filters to uniquely span the frequency space
391 | SAMPLE_FACTOR : int
392 | number of times to overcomplete the filters.
393 | compression : string. see include_compression for compression options
394 | determines the compression type to use in the cochleagram graph. If return_subbands_only is true, the rectified subbands are compressed
395 | WINDOW_SIZE : int
396 | the size of a window to use for the downsampling filter
397 | debug : boolean
398 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True (default False).
399 | subbands_ifft : boolean
400 | If true, adds the ifft of the subbands to nets
401 | input_node : string
402 | Name of the top level of nets; this is the input into the cochleagram graph.
403 | mean_subtract : boolean
404 | If true, subtracts the mean of the waveform (explicitly removes the DC offset)
405 | rms_normalize : Boolean # ONLY USE WHEN GENERATING COCHLEAGRAMS
406 | If true, divides the input signal by its RMS value, such that the RMS value of the sound going into the cochleagram generation is equal to 1. This option should be false if inverting cochleagrams, as it can cause problems with the gradients
407 | linear_max : float
408 | The default value of 796.87416837456942 is the 5th percentile from the speech dataset when it is rms normalized to a value of 1. This value is only used if the compression is 'linearbelow1', 'linearbelow1sqrt', or 'stable_point3'
409 | SMOOTH_ABS : Boolean
410 | If True, uses a smoother version of the absolute value for the hilbert transform: sqrt(1e-10 + real(env)^2 + imag(env)^2)
411 | return_subbands_only : Boolean
412 | If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph
413 | include_all_keys : Boolean
414 | If True, returns all of the cochleagram and subbands processing keys in the dictionary
415 | rectify_and_lowpass_subbands : Boolean
416 | If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only)
417 | pad_factor : int
418 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
419 | return_coch_params : Boolean
420 | If True, returns the cochleagram generation parameters in addition to nets
421 | rFFT : Boolean
422 | If True, builds the graph using rFFT and irFFT operations whenever possible
423 | linear_params : list of floats
424 | used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel.
425 | custom_filts : None, or numpy array
426 | if not None, a numpy array containing the filters to use for the cochleagram generation. If None, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain the full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS]
427 | custom_compression_op : None or tensorflow partial function
428 | if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments
429 | erb_filter_kwargs : dictionary
430 | contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
431 | reshape_kell2018 : boolean (False)
432 | if true, reshapes the output cochleagram to be 256x256 as used by kell2018
433 | include_subbands_noise : boolean (False)
434 | if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True)
435 | subbands_noise_mean : float
436 | sets mean of subbands white noise if include_subbands_noise == True
437 | subbands_noise_stddev : float
438 | sets standard deviation of subbands white noise if include_subbands_noise == True
439 | rate_level_kwargs : dictionary
440 | contains keyword arguments for AN_rate_level_function (used if compression == 'rate_level')
441 | preprocess_kwargs : dictionary
442 | contains keyword arguments for preprocess_input function (used to randomize input dB SPL)
443 |
444 | Returns
445 | -------
446 | nets : dictionary
447 | a dictionary containing the parts of the cochleagram graph. Top node in this graph is nets['output_tfcoch_graph']
448 | COCH_PARAMS : dictionary (Optional)
449 | a dictionary containing all of the input parameters into the function
450 | """
451 | if return_coch_params:
452 | COCH_PARAMS = locals()
453 | COCH_PARAMS.pop('nets')
454 |
455 | # run preprocessing operations on the input (ie rms normalization, convert to complex)
456 | nets = preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT, **preprocess_kwargs)
457 |
458 | # fft of the input
459 | nets = fft_of_input(nets, pad_factor,debug, rFFT)
460 |
461 | # Make a wrapper for the compression function so it can be applied to the cochleagram and the subbands
462 | compression_function = functools.partial(include_compression, compression=compression, linear_max=linear_max, linear_params=linear_params, rate_level_kwargs=rate_level_kwargs, custom_compression_op=custom_compression_op)
463 |
464 | # make cochlear filters and compute the cochlear subbands
465 | nets = extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev)
466 |
467 | # Build the rest of the graph for the downsampled cochleagram, if we are returning the cochleagram or if we want to build the whole graph anyway.
468 | if (not return_subbands_only) or include_all_keys:
469 | # hilbert transform on subband fft
470 | nets = hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT)
471 |
472 | # absolute value of the envelopes (and expand to one channel)
473 | nets = abs_envelopes(nets, SMOOTH_ABS)
474 |
475 | # downsample and rectified nonlinearity
476 | nets = downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp)
477 |
478 | # compress cochleagram
479 | nets = compression_function(nets, input_node_name='cochleagram_no_compression', output_node_name='cochleagram')
480 |
481 | if reshape_kell2018:
482 | nets, output_node_name_coch = reshape_coch_kell_2018(nets)
483 | else:
484 | output_node_name_coch = 'cochleagram'
485 |
486 | if return_subbands_only:
487 | nets['output_tfcoch_graph'] = nets['subbands_time_processed']
488 | else:
489 | nets['output_tfcoch_graph'] = nets[output_node_name_coch]
490 |
491 | # return
492 | if return_coch_params:
493 | return nets, COCH_PARAMS
494 | else:
495 | return nets
496 |
497 |
498 | def preprocess_input(nets, SIGNAL_SIZE, input_node, mean_subtract, rms_normalize, rFFT,
499 | set_dBSPL=False, dBSPL_range=[60., 60.]):
500 | """
501 | Does preprocessing on the input (rms and converting to complex number)
502 | Parameters
503 | ----------
504 | nets : dictionary
505 | dictionary containing parts of the cochleagram graph. should already contain input_node
506 | input_node : string
507 | Name of the top level of nets; this is the input into the cochleagram graph.
508 | mean_subtract : boolean
509 | If true, subtracts the mean of the waveform (explicitly removes the DC offset)
510 | rms_normalize : Boolean # TODO: incorporate stable gradient code for RMS
511 | If true, divides the input signal by its RMS value, such that the RMS value of the sound going
512 | rFFT : Boolean
513 | If true, preprocess input for using the rFFT operations
514 | set_dBSPL : Boolean
515 | If true, re-scale input waveform to dB SPL sampled uniformly from dBSPL_range
516 | dBSPL_range : list
517 | Range of sound presentation levels in units of dB re 20e-6 Pa ([minval, maxval])
518 | Returns
519 | -------
520 | nets : dictionary
521 | updated dictionary containing parts of the cochleagram graph.
522 | """
523 |
524 | if rFFT:
525 | if SIGNAL_SIZE%2!=0:
526 | print('rFFT is only tested with even length signals. Change your input length.')
527 | return
528 |
529 | processed_input_node = input_node
530 |
531 | if mean_subtract:
532 | processed_input_node = processed_input_node + '_mean_subtract'
533 | nets[processed_input_node] = nets[input_node] - tf.reshape(tf.reduce_mean(nets[input_node],1),(-1,1))
534 | input_node = processed_input_node
535 |
536 | if rms_normalize: # TODO: incorporate stable RMS normalization
537 | processed_input_node = processed_input_node + '_rms_normalized'
538 | nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), 1))
539 | nets[processed_input_node] = tf.identity(nets[input_node]/tf.reshape(nets['rms_input'],(-1,1)),'rms_normalized_input')
540 | input_node = processed_input_node
541 |
542 | if set_dBSPL: # NOTE: unstable if RMS of input is zero
543 | processed_input_node = processed_input_node + '_set_dBSPL'
544 | assert rms_normalize == False, "rms_normalize must be False if set_dBSPL=True"
545 | assert len(dBSPL_range) == 2, "dBSPL_range must be specified as [minval, maxval]"
546 | nets['dBSPL_set'] = tf.random.uniform([tf.shape(nets[input_node])[0], 1],
547 | minval=dBSPL_range[0], maxval=dBSPL_range[1],
548 | dtype=nets[input_node].dtype, name='sample_dBSPL_set')
549 | nets['rms_set'] = 20e-6 * tf.math.pow(10., nets['dBSPL_set'] / 20.)
550 | nets['rms_input'] = tf.sqrt(tf.reduce_mean(tf.square(nets[input_node]), axis=1, keepdims=True))
551 | nets[processed_input_node] = tf.math.multiply(nets['rms_set'] / nets['rms_input'], nets[input_node],
552 | name='scale_input_to_dBSPL_set')
553 | input_node = processed_input_node
554 |
555 | if not rFFT:
556 | nets['input_signal_i'] = nets[input_node]*0.0
557 | nets['input_signal_complex'] = tf.complex(nets[input_node], nets['input_signal_i'], name='input_complex')
558 | else:
559 | nets['input_real'] = nets[input_node]
560 | return nets
561 |
562 |
563 | def fft_of_input(nets, pad_factor, debug, rFFT):
564 | """
565 | Computes the fft of the signal and adds appropriate padding
566 |
567 | Parameters
568 | ----------
569 | nets : dictionary
570 | dictionary containing parts of the cochleagram graph. 'subbands' are used for the hilbert transform
571 | pad_factor : int
572 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
573 | debug : boolean
574 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True.
575 | rFFT : Boolean
576 | If true, cochleagram graph is constructed using rFFT wherever possible
577 | Returns
578 | -------
579 | nets : dictionary
580 | updated dictionary containing parts of the cochleagram graph with the rFFT of the input
581 | """
582 | # fft of the input
583 | if not rFFT:
584 | if pad_factor is not None:
585 | nets['input_signal_complex'] = tf.concat([nets['input_signal_complex'], tf.zeros([nets['input_signal_complex'].get_shape()[0], nets['input_signal_complex'].get_shape()[1]*(pad_factor-1)], dtype=tf.complex64)], axis=1)
586 | nets['fft_input'] = tf.fft(nets['input_signal_complex'],name='fft_of_input')
587 | else:
588 | nets['fft_input'] = tf.spectral.rfft(nets['input_real'],name='fft_of_input') # Since the DFT of a real signal is Hermitian-symmetric, RFFT only returns the fft_length / 2 + 1 unique components of the FFT: the zero-frequency term, followed by the fft_length / 2 positive-frequency terms.
589 |
590 | nets['fft_input'] = tf.expand_dims(nets['fft_input'], 1, name='exd_fft_of_input')
591 |
592 | if debug: # return the real and imaginary parts of the fft separately
593 | nets['fft_input_r'] = tf.real(nets['fft_input'])
594 | nets['fft_input_i'] = tf.imag(nets['fft_input'])
595 | return nets
596 |
597 |
598 | def extract_cochlear_subbands(nets, SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, pad_factor, debug, subbands_ifft, return_subbands_only, rectify_and_lowpass_subbands, rFFT, custom_filts, erb_filter_kwargs, include_all_keys, compression_function, include_subbands_noise, subbands_noise_mean, subbands_noise_stddev):
599 | """
600 | Computes the cochlear subbands from the fft of the input signal
601 | Parameters
602 | ----------
603 | nets : dictionary
604 | dictionary containing parts of the cochleagram graph. 'fft_input' is multiplied by the cochlear filters
605 | SIGNAL_SIZE : int
606 | the length of the audio signal used for the cochleagram graph
607 | SR : int
608 | raw sampling rate in Hz for the audio.
609 | LOW_LIM : int
610 | Lower frequency limits for the filters.
611 | HIGH_LIM : int
612 | Higher frequency limits for the filters.
613 | N : int
614 | Number of filters to uniquely span the frequency space
615 | SAMPLE_FACTOR : int
616 | number of times to overcomplete the filters.
621 | pad_factor : int
622 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
623 | debug : boolean
624 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal
625 | subbands_ifft : boolean
626 | If true, adds the ifft of the subbands to nets
627 | return_subbands_only : Boolean
628 | If True, returns the non-envelope extracted subbands before taking the hilbert envelope as the output node of the graph
629 | rectify_and_lowpass_subbands : Boolean
630 | If True, rectifies and lowpasses the subbands before returning them (only works with return_subbands_only)
631 | rFFT : Boolean
632 | If true, cochleagram graph is constructed using rFFT wherever possible
633 | custom_filts : None, or numpy array
634 | if not None, a numpy array containing the filters to use for the cochleagram generation. If None, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain the full filters, shape [SIGNAL_SIZE, NUMBER_OF_FILTERS]
635 | erb_filter_kwargs : dictionary
636 | contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
637 | include_all_keys : Boolean
638 | If True, includes the time subbands and the cochleagram in the dictionary keys
639 | compression_function : function
640 | A partial function that takes in nets and the input and output names to apply compression
641 | include_subbands_noise : boolean (False)
642 | if include_subbands_noise and return_subbands_only are both true, white noise is added to subbands after compression (this feature is currently only accessible when return_subbands_only == True)
643 | subbands_noise_mean : float
644 | sets mean of subbands white noise if include_subbands_noise == True
645 | subbands_noise_stddev : float
646 | sets standard deviation of subbands white noise if include_subbands_noise == True
647 | Returns
648 | -------
649 | nets : dictionary
650 | updated dictionary containing parts of the cochleagram graph.
651 | """
652 |
653 | # make the erb filters tensor
654 | nets['filts_tensor'] = make_filts_tensor(SIGNAL_SIZE, SR, LOW_LIM, HIGH_LIM, N, SAMPLE_FACTOR, use_rFFT=rFFT, pad_factor=pad_factor, custom_filts=custom_filts, erb_filter_kwargs=erb_filter_kwargs)
655 |
656 | # make subbands by multiplying filts with fft of input
657 | nets['subbands'] = tf.multiply(nets['filts_tensor'],nets['fft_input'],name='mul_subbands')
658 | if debug: # return the real and imaginary parts of the subbands separately -- use if matching to their output
659 | nets['subbands_r'] = tf.real(nets['subbands'])
660 | nets['subbands_i'] = tf.imag(nets['subbands'])
661 |
662 | # TODO: the subbands_ifft flag is redundant here, since return_subbands_only and include_all_keys also trigger this branch.
663 | # make the time subband operations if we are returning the subbands or if we want to include all of the keys in the graph
664 | if subbands_ifft or return_subbands_only or include_all_keys:
665 | if not rFFT:
666 | nets['subbands_ifft'] = tf.real(tf.ifft(nets['subbands'],name='ifft_subbands'),name='ifft_subbands_r')
667 | else:
668 | nets['subbands_ifft'] = tf.spectral.irfft(nets['subbands'],name='ifft_subbands')
669 | if return_subbands_only or include_all_keys:
670 | nets['subbands_time'] = nets['subbands_ifft']
671 | if rectify_and_lowpass_subbands: # TODO: the subband operations are hard coded in?
672 | nets['subbands_time_relu'] = tf.nn.relu(nets['subbands_time'], name='rectified_subbands')
673 | nets['subbands_time_lowpassed'] = hanning_pooling_1d_no_depthwise(nets['subbands_time_relu'], downsample=2, length_of_window=2*4, make_plots=False, data_format='NCW', normalize=True, sqrt_window=False)
674 |
675 | # TODO: noise is only added in the case when we are calculating the time subbands, but we might want something similar for the cochleagram
676 | if return_subbands_only or include_all_keys:
677 | # Compress subbands if specified and add noise.
678 | nets = compression_function(nets, input_node_name='subbands_time_lowpassed', output_node_name='subbands_time_lowpassed_compressed')
679 | if include_subbands_noise:
680 | nets = add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise')
681 | nets['subbands_time_lowpassed_compressed_with_noise'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed_with_noise'],-1)
682 | nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed_with_noise']
683 | else:
684 | nets['subbands_time_lowpassed_compressed'] = tf.expand_dims(nets['subbands_time_lowpassed_compressed'],-1)
685 | nets['subbands_time_processed'] = nets['subbands_time_lowpassed_compressed']
686 |
687 | return nets
688 |
689 |
690 | def hilbert_transform_from_fft(nets, SR, SIGNAL_SIZE, pad_factor, debug, rFFT):
691 | """
692 | Performs the hilbert transform from the subband FFT -- gets ifft using only the real parts of the signal
693 | Parameters
694 | ----------
695 | nets : dictionary
696 | dictionary containing parts of the cochleagram graph. 'subbands' are used for the hilbert transform
697 | SR : int
698 | raw sampling rate in Hz for the audio.
699 | SIGNAL_SIZE : int
700 | the length of the audio signal used for the cochleagram graph
701 | pad_factor : int
702 | how much padding to add to the signal. Follows conventions of pycochleagram (ie pad of 2 doubles the signal length)
703 | debug : boolean
704 | Adds more nodes to the graph for explicitly defining the real and imaginary parts of the signal when set to True.
705 | rFFT : Boolean
706 | If true, cochleagram graph is constructed using rFFT wherever possible
707 | """
708 |
709 | if not rFFT:
710 | # make the step tensor for the hilbert transform (zero out the negative-frequency components)
711 | if pad_factor is not None:
712 | freq_signal = np.fft.fftfreq(SIGNAL_SIZE*pad_factor, 1./SR)
713 | else:
714 | freq_signal = np.fft.fftfreq(SIGNAL_SIZE,1./SR)
715 | nets['step_tensor'] = make_step_tensor(freq_signal)
716 |
717 | # envelopes in frequency domain -- hilbert transform of the subbands
718 | nets['envelopes_freq'] = tf.multiply(nets['subbands'],nets['step_tensor'],name='env_freq')
719 | else:
720 | # make the padding to turn rFFT into a step function
721 | num_filts = nets['filts_tensor'].get_shape().as_list()[1]
722 | # num_batch = nets['subbands'].get_shape().as_list()[0]
723 | num_batch = tf.shape(nets['subbands'])[0]
724 | # TODO: this also might be a problem when we have pad_factor > 1
725 | print(num_batch)
726 | print(num_filts)
727 | print(int(SIGNAL_SIZE/2)-1)
728 | nets['hilbert_padding'] = tf.zeros([num_batch,num_filts,int(SIGNAL_SIZE/2)-1], tf.complex64)
729 | nets['envelopes_freq'] = tf.concat([nets['subbands'],nets['hilbert_padding']],2,name='env_freq')
730 |
731 | if debug: # return real and imaginary parts separately
732 | nets['envelopes_freq_r'] = tf.real(nets['envelopes_freq'])
733 | nets['envelopes_freq_i'] = tf.imag(nets['envelopes_freq'])
734 |
735 | # ifft of the envelopes -- the complex analytic signal in the time domain.
736 | nets['envelopes_time'] = tf.ifft(nets['envelopes_freq'],name='ifft_envelopes')
737 |
738 | if not rFFT: # TODO: was this a bug in pycochleagram where the pad factor doesn't actually work?
739 | if pad_factor is not None:
740 | nets['envelopes_time'] = nets['envelopes_time'][:,:,:SIGNAL_SIZE]
741 |
742 | if debug: # return real and imaginary parts separately
743 | nets['envelopes_time_r'] = tf.real(nets['envelopes_time'])
744 | nets['envelopes_time_i'] = tf.imag(nets['envelopes_time'])
745 |
746 | return nets
747 |
748 |
749 | def abs_envelopes(nets, SMOOTH_ABS):
750 | """
751 | Takes the absolute value of the analytic (Hilbert) signal to obtain the envelopes, and expands them to one channel
752 |
753 | Parameters
754 | ----------
755 | nets : dictionary
756 | dictionary containing the cochleagram graph. The absolute value will be applied to 'envelopes_time'
757 | SMOOTH_ABS : Boolean
758 | If True, uses a smoother version of the absolute value for the hilbert transform: sqrt(1e-10 + real(env)^2 + imag(env)^2)
759 | Returns
760 | -------
761 | nets : dictionary
762 | dictionary containing the updated cochleagram graph
763 | """
764 |
765 | if SMOOTH_ABS:
766 | nets['envelopes_abs'] = tf.sqrt(1e-10 + tf.square(tf.real(nets['envelopes_time'])) + tf.square(tf.imag(nets['envelopes_time'])))
767 | else:
768 | nets['envelopes_abs'] = tf.abs(nets['envelopes_time'], name='complex_abs_envelopes')
769 | nets['envelopes_abs'] = tf.expand_dims(nets['envelopes_abs'],3, name='exd_abs_real_envelopes')
770 | return nets
771 |
772 |
773 | def downsample_and_rectify(nets, SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp):
774 | """
775 | Downsamples the cochleagram and then performs rectification on the output (in case the downsampling results in small negative numbers)
776 | Parameters
777 | ----------
778 | nets : dictionary
779 | dictionary containing the cochleagram graph. Downsampling will be applied to 'envelopes_abs'
780 | SR : int
781 | raw sampling rate of the audio signal
782 | ENV_SR : int
783 | end sampling rate of the envelopes
784 | WINDOW_SIZE : int
785 | the size of the downsampling window (should be large enough to go to zero on the edges).
786 | pycoch_downsamp : Boolean
787 | if true, uses a slightly different downsampling function
788 | Returns
789 | -------
790 | nets : dictionary
791 | dictionary containing parts of the cochleagram graph with added nodes for the downsampled subbands
792 | """
793 | # The stride for the downsample; cast to int so it can be used as a conv2d stride.
794 | DOWNSAMPLE = int(SR / ENV_SR)
795 | if not ENV_SR == SR:
796 | # make the downsample tensor
797 | nets['downsample_filt_tensor'] = make_downsample_filt_tensor(SR, ENV_SR, WINDOW_SIZE, pycoch_downsamp=pycoch_downsamp)
798 | nets['cochleagram_preRELU'] = tf.nn.conv2d(nets['envelopes_abs'], nets['downsample_filt_tensor'], [1, 1, DOWNSAMPLE, 1], 'SAME',name='conv2d_cochleagram_raw')
799 | else:
800 | nets['cochleagram_preRELU'] = nets['envelopes_abs']
801 | nets['cochleagram_no_compression'] = tf.nn.relu(nets['cochleagram_preRELU'], name='coch_no_compression')
802 |
803 | return nets
804 |
805 |
806 | def include_compression(nets, compression='none', linear_max=796.87416837456942, input_node_name='cochleagram_no_compression', output_node_name='cochleagram', linear_params=None, rate_level_kwargs={}, custom_compression_op=None):
807 | """
808 | Chooses the compression operation to use and adds the appropriate nodes to nets
809 | Parameters
810 | ----------
811 | nets : dictionary
812 | dictionary containing parts of the cochleagram graph. Compression will be applied to input_node_name
813 | compression : string
814 | type of compression to perform
815 | linear_max : float
816 | used for the linearbelow compression operations (compression is linear below a value and compressed above it)
817 | input_node_name : string
818 | name in nets to apply the compression
819 | output_node_name : string
820 | name in nets that will be used for the following operation (default is cochleagram, but if returning subbands then it can be changed)
821 | linear_params : list of floats
822 | used for the linear compression operation, [m, b] where the output of the compression is y=mx+b. m and b can be vectors of shape [1,num_filts,1] to apply different values to each frequency channel.
823 | custom_compression_op : None or tensorflow partial function
824 | if specified as a function, applies the tensorflow function as a custom compression operation. Should take the input node and 'name' as the arguments
825 | Returns
826 | -------
827 | nets : dictionary
828 | dictionary containing parts of the cochleagram graph with added nodes for the compressed cochleagram
829 | """
830 | # compression of the cochleagram
831 | if compression=='quarter':
832 | nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name], name=output_node_name))
833 | elif compression=='quarter_plus':
834 | nets[output_node_name] = tf.sqrt(tf.sqrt(nets[input_node_name]+1e-01, name=output_node_name))
835 | elif compression=='point3':
836 | nets[output_node_name] = tf.pow(nets[input_node_name],0.3, name=output_node_name)
837 | elif compression=='stable_point3':
838 | nets[output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name=output_node_name)
839 | elif compression=='stable_point3_norm_grads':
840 | nets[output_node_name] = tf.identity(stable_power_compression_norm_grad(nets[input_node_name]*linear_max),name=output_node_name)
841 | elif compression=='linearbelow1':
842 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.pow(nets[input_node_name]*linear_max,0.3), name=output_node_name)
843 | elif compression=='stable_linearbelow1':
844 | nets['stable_power_compressed_%s'%output_node_name] = tf.identity(stable_power_compression(nets[input_node_name]*linear_max),name='stable_power_compressed_%s'%output_node_name)
845 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, nets['stable_power_compressed_%s'%output_node_name], name=output_node_name)
846 | elif compression=='linearbelow1sqrt':
847 | nets[output_node_name] = tf.where((nets[input_node_name]*linear_max)<1, nets[input_node_name]*linear_max, tf.sqrt(nets[input_node_name]*linear_max), name=output_node_name)
848 | elif compression=='quarter_clipped':
849 | nets[output_node_name] = tf.sqrt(tf.sqrt(tf.maximum(nets[input_node_name],1e-01), name=output_node_name))
850 | elif compression=='none':
851 | nets[output_node_name] = nets[input_node_name]
852 | elif compression=='sqrt':
853 | nets[output_node_name] = tf.sqrt(nets[input_node_name], name=output_node_name)
854 | elif compression=='dB': # NOTE: this compression does not work well for the backwards pass, results in nans
855 | nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name])/tf.reduce_max(nets[input_node_name])
856 | nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60)
857 | elif compression=='dB_plus': # NOTE: this compression does not work well for the backwards pass, results in nans
858 | nets[output_node_name + '_noclipped'] = 20 * tflog10(nets[input_node_name]+1)/tf.reduce_max(nets[input_node_name]+1)
859 | nets[output_node_name] = tf.maximum(nets[output_node_name + '_noclipped'], -60, name=output_node_name)
860 | elif compression=='linear':
861 | assert (type(linear_params)==list) and len(linear_params)==2, "Specifying linear compression but not specifying the compression parameters in linear_params=[m, b]"
862 | nets[output_node_name] = linear_params[0]*nets[input_node_name] + linear_params[1]
863 | elif compression=='rate_level':
864 | nets[output_node_name] = AN_rate_level_function(nets[input_node_name], name=output_node_name, **rate_level_kwargs)
865 | elif compression=='custom':
866 | nets[output_node_name] = custom_compression_op(nets[input_node_name], name=output_node_name)
867 |
868 | return nets
869 |
870 |
871 | def make_step_tensor(freq_signal):
872 | """
873 | Make step tensor for calculating the analytic envelopes.
874 | Parameters
875 | ----------
876 | freq_signal : array
877 | numpy array containing the frequencies of the audio signal (as calculated by np.fft.fftfreq).
878 | Returns
879 | -------
880 | step_tensor : tensorflow tensor
881 | tensorflow tensor of shape [1, 1, len(freq_signal)] containing a step function where frequencies > 0 are 2 and all other frequencies are 0.
882 | """
883 | step_func = (freq_signal >= 0).astype(int) * 2 # Wikipedia says that this should be 2x the original
884 | step_func[freq_signal == 0] = 0 # https://en.wikipedia.org/wiki/Analytic_signal (this shouldn't actually matter, I think)
885 | step_tensor = tf.constant(step_func, dtype=tf.complex64)
886 | step_tensor = tf.expand_dims(step_tensor, 0)
887 | step_tensor = tf.expand_dims(step_tensor, 1)
888 | return step_tensor
889 |
890 |
891 | def make_filts_tensor(SIGNAL_SIZE, SR=16000, LOW_LIM=20, HIGH_LIM=8000, N=40, SAMPLE_FACTOR=4, use_rFFT=False, pad_factor=None, custom_filts=None, erb_filter_kwargs={}):
892 | """
893 | Use pycochleagram to make the filters with the specified parameters (make_erb_cos_filters_nx), then load them into a tensorflow tensor to be used in the tensorflow cochleagram graph.
894 | Parameters
895 | ----------
896 | SIGNAL_SIZE: int
897 | length of the audio signal to convert, and the size of cochleagram filters to make.
898 | SR : int
899 | raw sampling rate in Hz for the audio.
900 | LOW_LIM : int
901 | Lower frequency limits for the filters.
902 | HIGH_LIM : int
903 | Higher frequency limits for the filters.
904 | N : int
905 | Number of filters to uniquely span the frequency space
906 | SAMPLE_FACTOR : int
907 | number of times to overcomplete the filters.
908 | use_rFFT : Boolean
909 | if True, only returns the first half of the filters, corresponding to the positive-frequency components.
910 | custom_filts : None, or numpy array
911 | if not None, a numpy array containing the filters to use for the cochleagram generation. If None, uses erb.make_erb_cos_filters from pycochleagram to construct the filterbank. If using rFFT, should contain the full filters, shape [NUMBER_OF_FILTERS, SIGNAL_SIZE]
912 | erb_filter_kwargs : dictionary
913 | contains additional arguments with filter parameters to use with erb.make_erb_cos_filters
914 | Returns
915 | -------
916 | filts_tensor : tensorflow tensor, complex
917 | tensorflow tensor with shape [1, NUMBER_OF_FILTERS, SIGNAL_SIZE] (frequency axis truncated to SIGNAL_SIZE/2 + 1 when use_rFFT is True) that includes the erb filters created from make_erb_cos_filters_nx in pycochleagram
918 | """
919 | if pad_factor:
920 | padding_size = (pad_factor-1)*SIGNAL_SIZE
921 | else:
922 | padding_size=None
923 |
924 | if custom_filts is None:
925 | # make the filters
926 | filts, hz_cutoffs, freqs = make_erb_cos_filters_nx(SIGNAL_SIZE, SR, N, LOW_LIM, HIGH_LIM, SAMPLE_FACTOR, padding_size=padding_size, **erb_filter_kwargs) #TODO: decide if we want to change the pad_factor and full_filter arguments.
927 | else: # TODO: ADD CHECKS TO MAKE SURE THAT THESE MATCH UP WITH THE INPUT SIGNAL
928 | assert custom_filts.shape[1] == SIGNAL_SIZE, "CUSTOM FILTER SHAPE DOES NOT MATCH THE INPUT AUDIO SHAPE"
929 | filts = custom_filts
930 |
931 | if not use_rFFT:
932 | filts_tensor = tf.constant(filts, tf.complex64)
933 | else: # TODO I believe that this is where the pad factor problem comes in! We are only using part of the signal here.
934 | filts_tensor = tf.constant(filts[:,0:(int(SIGNAL_SIZE/2)+1)], tf.complex64)
935 |
936 | filts_tensor = tf.expand_dims(filts_tensor, 0)
937 |
938 | return filts_tensor
939 |
940 |
941 | def make_downsample_filt_tensor(SR=16000, ENV_SR=200, WINDOW_SIZE=1001, pycoch_downsamp=False):
942 | """
943 | Make the sinc filter that will be used to downsample the cochleagram
944 | Parameters
945 | ----------
946 | SR : int
947 | raw sampling rate of the audio signal
948 | ENV_SR : int
949 | end sampling rate of the envelopes
950 | WINDOW_SIZE : int
951 | the size of the downsampling window (should be large enough to go to zero on the edges).
952 | pycoch_downsamp : Boolean
953 | if true, uses a slightly different downsampling function
954 | Returns
955 | -------
956 | downsample_filt_tensor : tensorflow tensor, tf.float32
957 | a tensor of shape [1, WINDOW_SIZE, 1, 1] containing the sinc window with a Kaiser lowpass filter that is applied while downsampling the cochleagram
958 | """
959 | DOWNSAMPLE = SR/ENV_SR
960 | if not pycoch_downsamp:
961 | downsample_filter_times = np.arange(-WINDOW_SIZE/2,int(WINDOW_SIZE/2))
962 | downsample_filter_response_orig = np.sinc(downsample_filter_times/DOWNSAMPLE)/DOWNSAMPLE
963 | downsample_filter_window = signal.kaiser(WINDOW_SIZE, 5)
964 | downsample_filter_response = downsample_filter_window * downsample_filter_response_orig
965 | else:
966 | max_rate = DOWNSAMPLE
967 | f_c = 1. / max_rate # cutoff of FIR filter (rel. to Nyquist)
968 | half_len = int(10 * max_rate) # reasonable cutoff for our sinc-like function (cast to int for use as a filter length)
969 | if max_rate!=1:
970 | downsample_filter_response = signal.firwin(2 * half_len + 1, f_c, window=('kaiser', 5.0))
971 | else: # just in case we aren't downsampling -- I think this should work?
972 | downsample_filter_response = np.zeros(2 * half_len + 1)
973 | downsample_filter_response[half_len + 1] = 1
974 |
975 | # Zero-pad our filter to put the output samples at the center
976 | # n_pre_pad = int((DOWNSAMPLE - half_len % DOWNSAMPLE))
977 | # n_post_pad = 0
978 | # n_pre_remove = (half_len + n_pre_pad) // DOWNSAMPLE
979 | # We should rarely need to do this given our filter lengths...
980 | # while _output_len(len(h) + n_pre_pad + n_post_pad, x.shape[axis],
981 | # up, down) < n_out + n_pre_remove:
982 | # n_post_pad += 1
983 | # downsample_filter_response = np.concatenate((np.zeros(n_pre_pad), downsample_filter_response, np.zeros(n_post_pad)))
984 |
985 | downsample_filt_tensor = tf.constant(downsample_filter_response, tf.float32)
986 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 0)
987 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 2)
988 | downsample_filt_tensor = tf.expand_dims(downsample_filt_tensor, 3)
989 |
990 | return downsample_filt_tensor
991 |
992 |
993 | def add_neural_noise(nets, subbands_noise_mean, subbands_noise_stddev, input_node_name='subbands_time_lowpassed_compressed', output_node_name='subbands_time_lowpassed_compressed_with_noise'):
994 | # Add white noise variable with the same size to the rectified and compressed subbands
995 | nets['neural_noise'] = tf.random.normal(tf.shape(nets[input_node_name]), mean=subbands_noise_mean,
996 | stddev=subbands_noise_stddev, dtype=nets[input_node_name].dtype)
997 | nets[output_node_name] = tf.nn.relu(tf.math.add(nets[input_node_name], nets['neural_noise']))
998 | return nets
999 |
1000 |
1001 | def reshape_coch_kell_2018(nets):
1002 | """
1003 | Wrapper to reshape the cochleagram to 256x256 similar to that used in kell2018.
1004 | Note that this function relies on tf.image.resize_images which can have unexpected behavior... use with caution.
1005 | nets : dictionary
1006 | dictionary containing parts of the cochleagram graph. should already contain cochleagram
1007 | """
1008 | print('### WARNING: tf.image.resize_images is not trusted, use caution ###')
1009 | nets['min_cochleagram'] = tf.reduce_min(nets['cochleagram'])
1010 | nets['max_cochleagram'] = tf.reduce_max(nets['cochleagram'])
1011 | # it is possible that this scaling is going to mess up the gradients for the waveform generation
1012 | nets['scaled_cochleagram'] = 255*(1-((nets['max_cochleagram']-nets['cochleagram'])/(nets['max_cochleagram']-nets['min_cochleagram'])))
1013 | nets['reshaped_cochleagram'] = tf.image.resize_images(nets['scaled_cochleagram'],[256,256], align_corners=False, preserve_aspect_ratio=False)
1014 | return nets, 'reshaped_cochleagram'
1015 |
1016 |
1017 | def convert_Pa_to_dBSPL(pa):
1018 | """ Converts units of Pa to dB re 20e-6 Pa (dB SPL) """
1019 | return 20. * np.log10(pa / 20e-6)
1020 |
1021 |
1022 | def convert_dBSPL_to_Pa(dbspl):
1023 | """ Converts units of dB re 20e-6 Pa (dB SPL) to Pa """
1024 | return 20e-6 * np.power(10., dbspl / 20.)
1025 |
1026 |
1027 | def AN_rate_level_function(tensor_subbands, name='rate_level_fcn', rate_spont=70., rate_max=250.,
1028 | rate_normalize=True, beta=3., halfmax_dBSPL=20.):
1029 | """
1030 | Function implements the auditory nerve rate-level function described by Peter Heil
1031 | and colleagues (2011, J. Neurosci.): the "amplitude-additivity model".
1032 |
1033 | Args
1034 | ----
1035 | tensor_subbands (tensor): shape must be [batch, freq, time, (channel)], units are Pa
1036 | name (str): name for the tensorflow operation
1037 | rate_spont (float): spontaneous spiking rate (spikes/s)
1038 | rate_max (float): maximum spiking rate (spikes/s)
1039 | rate_normalize (bool): if True, output will be re-scaled between 0 and 1
1040 | beta (float or list): determines the steepness of rate-level function (dimensionless)
1041 | halfmax_dBSPL (float or list): determines threshold of rate-level function (units dB SPL)
1042 |
1043 | Returns
1044 | -------
1045 | tensor_rates (tensor): same shape as tensor_subbands, units are spikes/s or normalized
1046 | """
1047 | # Check arguments and compute shape for frequency-channel-specific parameters
1048 | assert rate_spont > 0, "rate_spont must be greater than zero to avoid division by zero"
1049 | if len(tensor_subbands.shape) == 3:
1050 | freq_specific_shape = [tensor_subbands.shape[1], 1]
1051 | elif len(tensor_subbands.shape) == 4:
1052 | freq_specific_shape = [tensor_subbands.shape[1], 1, 1]
1053 | else:
1054 | raise ValueError("tensor_subbands must have shape [batch, freq, time, (channel)]")
1055 | # Convert beta to tensor (can be a single value or frequency channel specific)
1056 | beta = np.array(beta).reshape([-1])
1057 | assert_msg = "beta must be one value or a list of length {}".format(tensor_subbands.shape[1])
1058 | assert len(beta) == 1 or len(beta) == tensor_subbands.shape[1], assert_msg
1059 | beta_vals = tf.constant(beta,
1060 | dtype=tensor_subbands.dtype,
1061 | shape=freq_specific_shape)
1062 | # Convert halfmax_dBSPL to tensor (can be a single value or frequency channel specific)
1063 | halfmax_dBSPL = np.array(halfmax_dBSPL).reshape([-1])
1064 | assert_msg = "halfmax_dBSPL must be one value or a list of length {}".format(tensor_subbands.shape[1])
1065 | assert len(halfmax_dBSPL) == 1 or len(halfmax_dBSPL) == tensor_subbands.shape[1], assert_msg
1066 | P_halfmax = tf.constant(convert_dBSPL_to_Pa(halfmax_dBSPL),
1067 | dtype=tensor_subbands.dtype,
1068 | shape=freq_specific_shape)
1069 | # Convert rate_spont and rate_max to tf.constants (single values)
1070 | R_spont = tf.constant(rate_spont, dtype=tensor_subbands.dtype, shape=[])
1071 | R_max = tf.constant(rate_max, dtype=tensor_subbands.dtype, shape=[])
1072 | # Implementation analogous to equation (8) from Heil et al. (2011, J. Neurosci.)
1073 | P_0 = P_halfmax / (tf.pow((R_max + R_spont) / R_spont, 1/beta_vals) - 1)
1074 | R_func = lambda P: R_max / (1 + ((R_max - R_spont) / R_spont) * tf.pow(P / P_0 + 1, -beta_vals))
1075 | tensor_rates = tf.map_fn(R_func, tensor_subbands, name=name)
1076 | # If rate_normalize is True, re-scale spiking rates to fall between 0 and 1
1077 | if rate_normalize:
1078 | tensor_rates = (tensor_rates - R_spont) / (R_max - R_spont)
1079 | return tensor_rates
1080 |
1081 |
1082 | def make_hanning_kernel_1d(downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True):
1083 | """
1084 | Make the symmetric 1d hanning kernel to use for the pooling filters
1085 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles
1086 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles
1087 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles
1088 | We want to reduce the frequencies above the nyquist by at least 20dB.
1089 | Parameters
1090 | ----------
1091 | downsample : int
1092 | proportion downsampling
1093 | length_of_window : int
1094 | how large of a window to use
1095 | make_plots: boolean
1096 | make plots of the filters
1097 | normalize : boolean
1098 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original.
1099 | sqrt_window : boolean
1100 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False
1101 | Returns
1102 | -------
1103 | one_dimensional_kernel : numpy array
1104 | hanning kernel in 1d to use as a kernel for filtering
1105 | """
1106 |
1107 | window = 0.5 * (1 - np.cos(2.0 * np.pi * (np.arange(length_of_window)) / (length_of_window - 1)))
1108 | if sqrt_window:
1109 | one_dimensional_kernel = np.sqrt(window)
1110 | else:
1111 | one_dimensional_kernel = window
1112 |
1113 | if normalize:
1114 | one_dimensional_kernel = one_dimensional_kernel/sum(one_dimensional_kernel)
1115 | window = one_dimensional_kernel
1116 |
1117 | if make_plots:
1118 | A = np.fft.fft(window, 2048) / (len(window) / 2.0)
1119 | freq = np.linspace(-0.5, 0.5, len(A))
1120 | response = 20.0 * np.log10(np.abs(np.fft.fftshift(A / abs(A).max())))
1121 |
1122 | nyquist = 1 / (2 * downsample)
1123 | ny_idx = np.where(np.abs(freq - nyquist) == np.abs(freq - nyquist).min())[0][0]
1124 | print(['Frequency response at ' + 'nyquist (%.3f Hz)'%nyquist + ' is ' + '%d'%response[ny_idx]])
1125 | plt.figure()
1126 | plt.plot(window)
1127 | plt.title(r"Hanning window")
1128 | plt.ylabel("Amplitude")
1129 | plt.xlabel("Sample")
1130 | plt.figure()
1131 | plt.plot(freq, response)
1132 | plt.axis([-0.5, 0.5, -120, 0])
1133 | plt.title(r"Frequency response of the Hanning window")
1134 | plt.ylabel("Normalized magnitude [dB]")
1135 | plt.xlabel("Normalized frequency [cycles per sample]")
1136 |
1137 | return one_dimensional_kernel
1138 |
1139 |
1140 | def make_hanning_kernel_tensor_1d(n_channels, downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True):
1141 | """
1142 | Make a tensor containing the symmetric 1d hanning kernel to use for the pooling filters
1143 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles
1144 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles
1145 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles
1146 | We want to reduce the frequencies above the nyquist by at least 20dB.
1147 | Parameters
1148 | ----------
1149 | n_channels : int
1150 | number of channels to copy the kernel into
1151 | downsample : int
1152 | proportion downsampling
1153 | length_of_window : int
1154 | how large of a window to use
1155 | make_plots: boolean
1156 | make plots of the filters
1157 | normalize : boolean
1158 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original.
1159 | sqrt_window : boolean
1160 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False
1161 | Returns
1162 | -------
1163 | hanning_tensor : tensorflow tensor
1164 | tensorflow tensor containing the hanning tensor with size [1 length_of_window n_channels 1]
1165 | """
1166 | hanning_kernel = make_hanning_kernel_1d(downsample=downsample,length_of_window=length_of_window,make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window)
1167 | hanning_kernel = np.expand_dims(np.dstack([hanning_kernel.astype(np.float32)]*n_channels),axis=3)
1168 | hanning_tensor = tf.constant(hanning_kernel)
1169 | return hanning_tensor
1170 |
1171 |
1172 | def hanning_pooling_1d(input_tensor, downsample=2, length_of_window=8, make_plots=False, data_format='NWC', normalize=False, sqrt_window=True):
1173 | """
1174 | Parameters
1175 | ----------
1176 | input_tensor : tensorflow tensor
1177 | tensor on which we will apply the hanning pooling operation
1178 | downsample : int
1179 | proportion downsampling
1180 | length_of_window : int
1181 | how large of a window to use
1182 | make_plots: boolean
1183 | make plots of the filters
1184 | data_format : 'NWC' or 'NCW'
1185 | Defaults to "NWC", the data is stored in the order of [batch, in_width, in_channels].
1186 | The "NCW" format stores data as [batch, in_channels, in_width].
1187 | normalize : boolean
1188 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original.
1189 | sqrt_window : boolean
1190 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False
1191 | Returns
1192 | -------
1193 | output_tensor : tensorflow tensor
1194 | tensorflow tensor containing the downsampled input_tensor of shape corresponding to data_format
1195 | """
1196 |
1197 | if data_format=='NWC':
1198 | n_channels = input_tensor.get_shape().as_list()[2]
1199 | elif data_format=='NCW':
1200 | batch_size, n_channels, in_width = input_tensor.get_shape().as_list()
1201 | input_tensor = tf.transpose(input_tensor, [0, 2, 1]) # transpose to [batch_size, in_width, in_channels]
1202 |
1203 | input_tensor = tf.expand_dims(input_tensor,1) # reshape to [batch_size, 1, in_width, in_channels]
1204 | h_tensor = make_hanning_kernel_tensor_1d(n_channels, downsample=downsample, length_of_window=length_of_window, make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window)
1205 |
1206 | output_tensor = tf.nn.depthwise_conv2d(input_tensor, h_tensor, strides=[1, downsample, downsample, 1], padding='SAME', name='hpooling')
1207 |
1208 | output_tensor = tf.squeeze(output_tensor, name='squeeze_output')
1209 | if data_format=='NWC':
1210 | return output_tensor
1211 | elif data_format=='NCW':
1212 | return tf.transpose(output_tensor, [0, 2, 1]) # reshape to [batch_size, in_channels, out_width]
1213 |
1214 |
1215 | def make_hanning_kernel_tensor_1d_no_depthwise(n_channels, downsample=2, length_of_window=8, make_plots=False, normalize=False, sqrt_window=True):
1216 | """
1217 | Make a tensor containing the symmetric 1d hanning kernel to use for the pooling filters
1218 | For downsample=2, using length_of_window=8 gives a reduction of -24.131545969216841 at 0.25 cycles
1219 | For downsample=3, using length_of_window=12 gives a reduction of -28.607805482176282 at 1/6 cycles
1220 | For downsample=4, using length_of_window=15 gives a reduction of -23 at 1/8 cycles
1221 | We want to reduce the frequencies above the nyquist by at least 20dB.
1222 | Parameters
1223 | ----------
1224 | n_channels : int
1225 | number of channels to copy the kernel into
1226 | downsample : int
1227 | proportion downsampling
1228 | length_of_window : int
1229 | how large of a window to use
1230 | make_plots: boolean
1231 | make plots of the filters
1232 | normalize : boolean
1233 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original.
1234 | sqrt_window : boolean
1235 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False
1236 | Returns
1237 | -------
1238 | hanning_tensor : tensorflow tensor
1239 | tensorflow tensor containing the hanning tensor with size [length_of_window, num_channels, num_channels]
1240 | """
1241 | hanning_kernel = make_hanning_kernel_1d(downsample=downsample,length_of_window=length_of_window,make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window).astype(np.float32)
1242 | hanning_kernel_expanded = np.expand_dims(hanning_kernel,0) * np.expand_dims(np.eye(n_channels),2).astype(np.float32) # [n_channels, n_channels, length_of_window]
1243 | hanning_tensor = tf.constant(hanning_kernel_expanded) # [n_channels, n_channels, length_of_window]
1244 | hanning_tensor = tf.transpose(hanning_tensor, [2, 0, 1]) # [length_of_window, n_channels, n_channels]
1245 | return hanning_tensor
1246 |
1247 |
1248 | def hanning_pooling_1d_no_depthwise(input_tensor, downsample=2, length_of_window=8, make_plots=False, data_format='NWC', normalize=False, sqrt_window=True):
1249 | """
1250 | Parameters
1251 | ----------
1252 | input_tensor : tensorflow tensor
1253 | tensor on which we will apply the hanning pooling operation
1254 | downsample : int
1255 | proportion downsampling
1256 | length_of_window : int
1257 | how large of a window to use
1258 | make_plots: boolean
1259 | make plots of the filters
1260 | data_format : 'NWC' or 'NCW'
1261 | Defaults to "NWC", the data is stored in the order of [batch, in_width, in_channels].
1262 | The "NCW" format stores data as [batch, in_channels, in_width].
1263 | normalize : boolean
1264 | if true, divide the filter by the sum of its values, so that the smoothed signal is the same amplitude as the original.
1265 |
1266 | sqrt_window : boolean
1267 | if true, takes the sqrt of the window (old version) -- normal window generation has sqrt_window=False
1268 | Returns
1269 | -------
1270 | output_tensor : tensorflow tensor
1271 | tensorflow tensor containing the downsampled input_tensor of shape corresponding to data_format
1272 | """
1273 |
1274 | if data_format=='NWC':
1275 | n_channels = input_tensor.get_shape().as_list()[2]
1276 | elif data_format=='NCW':
1277 | batch_size, n_channels, in_width = input_tensor.get_shape().as_list()
1278 | input_tensor = tf.transpose(input_tensor, [0, 2, 1]) # transpose to [batch_size, in_width, in_channels]
1279 |
1280 | h_tensor = make_hanning_kernel_tensor_1d_no_depthwise(n_channels, downsample=downsample, length_of_window=length_of_window, make_plots=make_plots, normalize=normalize, sqrt_window=sqrt_window)
1281 |
1282 | output_tensor = tf.nn.conv1d(input_tensor, h_tensor, stride=downsample, padding='SAME', name='hpooling')
1283 |
1284 | if data_format=='NWC':
1285 | return output_tensor
1286 | elif data_format=='NCW':
1287 | return tf.transpose(output_tensor, [0, 2, 1]) # reshape to [batch_size, in_channels, out_width]
1288 |
1289 |
1290 | def build_cochlear_model(tensor_waveform,
1291 | signal_rate=20000,
1292 | filter_type='half-cosine',
1293 | filter_spacing='erb',
1294 | HIGH_LIM=8000,
1295 | LOW_LIM=20,
1296 | N=40,
1297 | SAMPLE_FACTOR=1,
1298 | bandwidth_scale_factor=1.0,
1299 | compression='stable_point3',
1300 | include_highpass=False,
1301 | include_lowpass=False,
1302 | linear_max=1.0,
1303 | rFFT=True,
1304 | rectify_and_lowpass_subbands=True,
1305 | return_subbands_only=True,
1306 | **kwargs):
1307 | """
1308 | This function serves as a wrapper for `tfcochleagram_graph` and builds the cochlear model graph.
1309 | * * * * * * Default arguments are set to those used to train recognition networks * * * * * *
1310 |
1311 | Parameters
1312 | ----------
1313 | tensor_waveform (tensor): input signal waveform (with shape [batch, time])
1314 | signal_rate (int): sampling rate of signal waveform in Hz
1315 | filter_type (str): type of cochlear filters to build ('half-cosine')
1316 | filter_spacing (str, default='erb'): Specifies the type of reference spacing for the
1317 | half-cosine filters. Options include 'erb' and 'linear'.
1318 | HIGH_LIM (float): high frequency cutoff of filterbank (only used for 'half-cosine')
1319 | LOW_LIM (float): low frequency cutoff of filterbank (only used for 'half-cosine')
1320 | N (int): number of cochlear bandpass filters
1321 | SAMPLE_FACTOR (int): specifies how densely to sample cochlea (only used for 'half-cosine')
1322 | bandwidth_scale_factor (float): factor by which to symmetrically scale the filter bandwidths
1323 | bandwidth_scale_factor=2.0 means filters will be twice as wide.
1324 | Note that values < 1 will cause frequency gaps between the filters.
1325 | include_highpass (bool): determines if filterbank includes highpass filter(s) (only used for 'half-cosine')
1326 | include_lowpass (bool): determines if filterbank includes lowpass filter(s) (only used for 'half-cosine')
1327 | linear_max (float): used for the linearbelow compression operations
1328 | (compression is linear below a value and compressed above it)
1329 | rFFT (bool): If True, builds the graph using rFFT and irFFT operations whenever possible
1330 | rectify_and_lowpass_subbands (bool): If True, rectifies and lowpass-filters subbands before returning
1331 | return_subbands_only (bool): If True, returns subbands before taking the hilbert envelope as the output node
1332 | kwargs (dict): additional keyword arguments passed directly to tfcochleagram_graph
1333 |
1334 | Returns
1335 | -------
1336 | tensor_cochlear_representation (tensor): output cochlear representation
1337 | coch_container (dict): dictionary containing cochlear model stages
1338 | """
1339 | signal_length = tensor_waveform.get_shape().as_list()[-1]
1340 |
1341 | if filter_type == 'half-cosine':
1342 | assert HIGH_LIM <= signal_rate/2, "cochlear filterbank high_lim is above Nyquist frequency"
1343 | filts, center_freqs, freqs = make_cos_filters_nx(
1344 | signal_length,
1345 | signal_rate,
1346 | N,
1347 | LOW_LIM,
1348 | HIGH_LIM,
1349 | SAMPLE_FACTOR,
1350 | padding_size=None,
1351 | full_filter=True,
1352 | strict=True,
1353 | bandwidth_scale_factor=bandwidth_scale_factor,
1354 | include_lowpass=include_lowpass,
1355 | include_highpass=include_highpass,
1356 | filter_spacing=filter_spacing)
1357 | assert filts.shape[1] == signal_length, "filter array shape must match signal length"
1358 | else:
1359 | raise ValueError('Specified filter_type {} is not supported'.format(filter_type))
1360 |
1361 | coch_container = {'input_signal': tensor_waveform}
1362 | coch_container = cochleagram_graph(
1363 | coch_container,
1364 | signal_length,
1365 | signal_rate,
1366 | LOW_LIM=LOW_LIM,
1367 | HIGH_LIM=HIGH_LIM,
1368 | N=N,
1369 | SAMPLE_FACTOR=SAMPLE_FACTOR,
1370 | custom_filts=filts,
1371 | linear_max=linear_max,
1372 | rFFT=rFFT,
1373 | rectify_and_lowpass_subbands=rectify_and_lowpass_subbands,
1374 | return_subbands_only=return_subbands_only,
1375 | **kwargs)
1376 |
1377 | tensor_cochlear_representation = coch_container['output_tfcoch_graph']
1378 | return tensor_cochlear_representation, coch_container
1379 |
--------------------------------------------------------------------------------
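
A minimal usage sketch for `build_cochlear_model` (a hypothetical example, not taken from the repository's training scripts; it assumes the function above is exposed by `util_cochlear_model.py`, a TensorFlow 1.x graph-mode session, and an illustrative batch of 2-second, 20 kHz waveforms):

    import numpy as np
    import tensorflow as tf
    import util_cochlear_model

    # Placeholder for a batch of 2-second waveforms sampled at 20 kHz (illustrative shape)
    tensor_waveform = tf.placeholder(tf.float32, shape=[None, 40000])

    # Build the cochlear model graph with its default arguments; returns the output
    # representation (rectified, lowpass-filtered, compressed subbands) and a dict of
    # intermediate graph nodes
    tensor_coch, coch_container = util_cochlear_model.build_cochlear_model(
        tensor_waveform,
        signal_rate=20000,
        N=40,
        return_subbands_only=True)

    with tf.Session() as sess:
        waveform = np.random.randn(1, 40000).astype(np.float32)  # stand-in audio
        print(sess.run(tensor_coch, feed_dict={tensor_waveform: waveform}).shape)
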
/util_recognition_network.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 | def build_network(tensor_input, list_layer_dict, n_classes_dict={}):
8 | """
9 | Build tensorflow graph for a feedforward neural network given an
10 | input tensor and list of layer descriptions
11 | """
12 | tensor_output = tensor_input
13 | tensors_dict = {}
14 | for layer_dict in list_layer_dict:
15 | layer_type = layer_dict['layer_type']
16 | if 'batch_normalization' in layer_type:
17 | layer = tf.keras.layers.BatchNormalization(**layer_dict['args'])
18 | elif 'conv2d' in layer_type:
19 | layer = PaddedConv2D(**layer_dict['args'])
20 | elif 'dense' in layer_type:
21 | layer = tf.keras.layers.Dense(**layer_dict['args'])
22 | elif 'dropout' in layer_type:
23 | layer = tf.keras.layers.Dropout(**layer_dict['args'])
24 | elif 'hpool' in layer_type:
25 | layer = HanningPooling(**layer_dict['args'])
26 | elif 'flatten' in layer_type:
27 | layer = tf.keras.layers.Flatten(**layer_dict['args'])
28 | elif 'leaky_relu' in layer_type:
29 | layer = tf.keras.layers.LeakyReLU(**layer_dict['args'])
30 | elif 'relu' in layer_type:
31 | layer = tf.keras.layers.ReLU(**layer_dict['args'])
32 | elif 'fc_top' in layer_type:
33 | layer = DenseTaskHeads(
34 | n_classes_dict=n_classes_dict,
35 | **layer_dict['args'])
36 | else:
37 | msg = "layer_type={} not recognized".format(layer_type)
38 | raise NotImplementedError(msg)
39 | tensor_output = layer(tensor_output)
40 | if layer_dict.get('args', {}).get('name', None) is not None:
41 | tensors_dict[layer_dict['args']['name']] = tensor_output
42 | return tensor_output, tensors_dict
43 |
44 |
45 | def same_pad_along_axis(tensor_input,
46 | kernel_length=1,
47 | stride_length=1,
48 | axis=1,
49 | **kwargs):
50 | """
51 | Adds 'SAME' padding to only specified axis of tensor_input
52 | for 2D convolution
53 | """
54 | x = tensor_input.shape.as_list()[axis]
55 | if x % stride_length == 0:
56 | p = kernel_length - stride_length
57 | else:
58 | p = kernel_length - (x % stride_length)
59 | p = tf.math.maximum(p, 0)
60 | paddings = [(0, 0)] * len(tensor_input.shape)
61 | paddings[axis] = (p // 2, p - p // 2)
62 | return tf.pad(tensor_input, paddings, **kwargs)
63 |
64 |
65 | def PaddedConv2D(filters,
66 | kernel_size,
67 | strides=(1, 1),
68 | padding='VALID',
69 | **kwargs):
70 | """
71 | Wrapper function around tf.keras.layers.Conv2D to support
72 | custom padding options
73 | """
74 | if padding.upper() == 'VALID_TIME':
75 | pad_function = lambda x : same_pad_along_axis(
76 | x,
77 | kernel_length=kernel_size[0],
78 | stride_length=strides[0],
79 | axis=1)
80 | padding = 'VALID'
81 | else:
82 | pad_function = lambda x: x
83 | def layer(tensor_input):
84 | conv2d_layer = tf.keras.layers.Conv2D(
85 | filters,
86 | kernel_size,
87 | strides=strides,
88 | padding=padding,
89 | **kwargs)
90 | return conv2d_layer(pad_function(tensor_input))
91 | return layer
92 |
93 |
94 | def HanningPooling(strides=2,
95 | pool_size=8,
96 | padding='SAME',
97 | sqrt_window=False,
98 | normalize=False,
99 | name=None):
100 | """
101 | Weighted average pooling layer with Hanning window applied via
102 | 2D convolution (with identity transform as depthwise component)
103 | """
104 | if isinstance(strides, int):
105 | strides = (strides, strides)
106 | if isinstance(pool_size, int):
107 | pool_size = (pool_size, pool_size)
108 | assert len(strides) == 2, "HanningPooling expects 2D args"
109 | assert len(pool_size) == 2, "HanningPooling expects 2D args"
110 |
111 | (dim0, dim1) = pool_size
112 | if dim0 == 1:
113 | win0 = np.ones(dim0)
114 | else:
115 | win0 = (1 - np.cos(2 * np.pi * np.arange(dim0) / (dim0 - 1))) / 2
116 | if dim1 == 1:
117 | win1 = np.ones(dim1)
118 | else:
119 | win1 = (1 - np.cos(2 * np.pi * np.arange(dim1) / (dim1 - 1))) / 2
120 | hanning_window = np.outer(win0, win1)
121 | if sqrt_window:
122 | hanning_window = np.sqrt(hanning_window)
123 | if normalize:
124 | hanning_window = hanning_window / hanning_window.sum()
125 |
126 | if padding.upper() == 'VALID_TIME':
127 | pad_function = lambda x : same_pad_along_axis(
128 | x,
129 | kernel_length=pool_size[0],
130 | stride_length=strides[0],
131 | axis=1)
132 | padding = 'VALID'
133 | else:
134 | pad_function = lambda x: x
135 |
136 | def layer(tensor_input):
137 | tensor_hanning_window = tf.constant(
138 | hanning_window[:, :, np.newaxis, np.newaxis],
139 | dtype=tensor_input.dtype,
140 | name="{}_hanning_window".format(name))
141 | tensor_eye = tf.eye(
142 | num_rows=tensor_input.shape.as_list()[-1],
143 | num_columns=None,
144 | batch_shape=[1, 1],
145 | dtype=tensor_input.dtype,
146 | name=None)
147 | tensor_output = tf.nn.convolution(
148 | pad_function(tensor_input),
149 | tensor_hanning_window * tensor_eye,
150 | strides=strides,
151 | padding=padding,
152 | data_format=None,
153 | name=name)
154 | return tensor_output
155 |
156 | return layer
157 |
158 |
159 | def DenseTaskHeads(n_classes_dict={}, name='logits', **kwargs):
160 | """
161 | Dense layer for each task head specified in n_classes_dict
162 | """
163 | def layer(tensor_input):
164 | tensors_logits = {}
165 | for key in sorted(n_classes_dict.keys()):
166 | if len(n_classes_dict.keys()) > 1:
167 | classification_name = '{}_{}'.format(name, key)
168 | else:
169 | classification_name = name
170 | classification_layer = tf.keras.layers.Dense(
171 | units=n_classes_dict[key],
172 | name=classification_name,
173 | **kwargs)
174 | tensors_logits[key] = classification_layer(tensor_input)
175 | return tensors_logits
176 | return layer
177 |
--------------------------------------------------------------------------------
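
A minimal sketch of how `build_network` consumes a layer-description list (the layer dicts and task head below are illustrative stand-ins, not one of the shipped models/recognition_networks/arch*.json configurations; assumes TensorFlow 1.x with the tf.keras layers used above):

    import tensorflow as tf
    import util_recognition_network

    # Illustrative layer list; real configurations are loaded from models/recognition_networks/arch*.json
    list_layer_dict = [
        {'layer_type': 'conv2d', 'args': {'filters': 32, 'kernel_size': [4, 8], 'strides': [1, 1], 'padding': 'SAME', 'name': 'conv_0'}},
        {'layer_type': 'relu', 'args': {'name': 'relu_0'}},
        {'layer_type': 'hpool', 'args': {'strides': 2, 'pool_size': 8, 'padding': 'SAME', 'name': 'hpool_0'}},
        {'layer_type': 'flatten', 'args': {'name': 'flatten_0'}},
        {'layer_type': 'fc_top', 'args': {'name': 'logits'}},
    ]
    n_classes_dict = {'speech': 794}  # hypothetical task head with 794 classes

    # Input follows the [batch, freq, time, channel] convention of the cochlear model output
    tensor_input = tf.placeholder(tf.float32, shape=[None, 40, 200, 1])
    tensor_logits, tensors_dict = util_recognition_network.build_network(
        tensor_input, list_layer_dict, n_classes_dict=n_classes_dict)
    # With an 'fc_top' layer, tensor_logits is a dict mapping each task key to its logits tensor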