├── .github └── workflows │ └── python-app.yml ├── .gitignore ├── .pylintrc ├── .vscode └── settings.json ├── Audio_Toolbox_Demonstrations.ipynb ├── LICENSE.txt ├── README.md ├── __init__.py ├── python_auditory_toolbox ├── auditory_toolbox.py ├── auditory_toolbox_comparison_test.py ├── auditory_toolbox_jax.py ├── auditory_toolbox_jax_test.py ├── auditory_toolbox_test.py ├── auditory_toolbox_torch.py ├── auditory_toolbox_torch_test.py └── examples │ ├── CorrelogramPitchExample.png │ ├── DudaTones.wav │ ├── DudaVowelsCorrelogram.mp4 │ ├── GammatoneFilterResponse.png │ ├── LeonPitch.wav │ ├── LeonVowels.wav │ ├── TapestryFilterbank.png │ ├── TapestryGammatoneFeatures.png │ ├── TapestryReconstruction.png │ ├── TapestrySpectrogram.png │ └── tapestry.wav ├── requirements.txt └── setup.py /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | permissions: 13 | contents: read 14 | 15 | jobs: 16 | build: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 3.11 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: "3.11" 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install pylint pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with pylint 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | pylint $(git ls-files '*.py') 35 | - name: Test with pytest 36 | run: | 37 | pytest 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *cpython* 2 | **/__pycache__ 3 | .vscode 4 | **/.DS_Store 5 | dependency_links.txt 6 | requires.txt 7 | PKG-INFO 8 | SOURCES.txt 9 | top_level.txt 10 | dist/** 11 | 12 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MAIN] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code.
29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). You can 39 | # either give multiple identifier separated by comma (,) or put this option 40 | # multiple time (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=R, 54 | abstract-method, 55 | apply-builtin, 56 | arguments-differ, 57 | attribute-defined-outside-init, 58 | backtick, 59 | bad-option-value, 60 | basestring-builtin, 61 | buffer-builtin, 62 | c-extension-no-member, 63 | consider-using-enumerate, 64 | cmp-builtin, 65 | cmp-method, 66 | coerce-builtin, 67 | coerce-method, 68 | delslice-method, 69 | div-method, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | input-builtin, 84 | intern-builtin, 85 | invalid-str-codec, 86 | locally-disabled, 87 | long-builtin, 88 | long-suffix, 89 | map-builtin-not-iterating, 90 | misplaced-comparison-constant, 91 | missing-function-docstring, 92 | metaclass-assignment, 93 | next-method-called, 94 | next-method-defined, 95 | no-absolute-import, 96 | no-init, # added 97 | no-member, 98 | no-name-in-module, 99 | no-self-use, 100 | nonzero-method, 101 | oct-method, 102 | old-division, 103 | old-ne-operator, 104 | old-octal-literal, 105 | old-raise-syntax, 106 | parameter-unpacking, 107 | print-statement, 108 | raising-string, 109 | range-builtin-not-iterating, 110 | raw_input-builtin, 111 | rdiv-method, 112 | reduce-builtin, 113 | relative-import, 114 | reload-builtin, 115 | round-builtin, 116 | setslice-method, 117 | signature-differs, 118 | standarderror-builtin, 119 | suppressed-message, 120 | sys-max-int, 121 | trailing-newlines, 122 | unichr-builtin, 123 | unicode-builtin, 124 | unnecessary-pass, 125 | unpacking-in-except, 126 | useless-else-on-loop, 127 | useless-suppression, 128 | using-cmp-argument, 129 | wrong-import-order, 130 | xrange-builtin, 131 | zip-builtin-not-iterating, 132 | 133 | 134 | [REPORTS] 135 | 136 | # Set the output format. Available formats are text, parseable, colorized, msvs 137 | # (visual studio) and html. You can also give a reporter class, eg 138 | # mypackage.mymodule.MyReporterClass. 
139 | output-format=text 140 | 141 | # Tells whether to display a full report or only the messages 142 | reports=no 143 | 144 | # Python expression which should return a note less than 10 (10 is the highest 145 | # note). You have access to the variables errors warning, statement which 146 | # respectively contain the number of errors / warnings messages and the total 147 | # number of statements analyzed. This is used by the global evaluation report 148 | # (RP0004). 149 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 150 | 151 | # Template used to display messages. This is a python new-style format string 152 | # used to format the message information. See doc for all details 153 | #msg-template= 154 | 155 | 156 | [BASIC] 157 | 158 | # Good variable names which should always be accepted, separated by a comma 159 | good-names=main,_ 160 | 161 | # Bad variable names which should always be refused, separated by a comma 162 | bad-names= 163 | 164 | # Colon-delimited sets of names that determine each other's naming style when 165 | # the name regexes allow several styles. 166 | name-group= 167 | 168 | # Include a hint for the correct naming format with invalid-name 169 | include-naming-hint=no 170 | 171 | # List of decorators that produce properties, such as abc.abstractproperty. Add 172 | # to this list to register other decorators that produce valid properties. 173 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 174 | 175 | # Regular expression matching correct function names 176 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 177 | 178 | # Regular expression matching correct variable names 179 | variable-rgx=^[a-z][a-z0-9_]*$ 180 | 181 | # Regular expression matching correct constant names 182 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 183 | 184 | # Regular expression matching correct attribute names 185 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 186 | 187 | # Regular expression matching correct argument names 188 | argument-rgx=^[a-z][a-z0-9_]*$ 189 | 190 | # Regular expression matching correct class attribute names 191 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 192 | 193 | # Regular expression matching correct inline iteration names 194 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 195 | 196 | # Regular expression matching correct class names 197 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 198 | 199 | # Regular expression matching correct module names 200 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 201 | 202 | # Regular expression matching correct method names 203 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 204 | 205 | # Regular expression which should only match function or class names that do 206 | # not require a docstring. 207 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 208 | 209 | # Minimum line length for functions/classes that require docstrings, shorter 210 | # ones are exempt. 211 | docstring-min-length=12 212 | 213 | 214 | [TYPECHECK] 215 | 216 | # List of decorators that produce context managers, such as 217 | # contextlib.contextmanager.
Add to this list to register other decorators that 218 | # produce valid context managers. 219 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 220 | 221 | # List of module names for which member attributes should not be checked 222 | # (useful for modules/projects where namespaces are manipulated during runtime 223 | # and thus existing member attributes cannot be deduced by static analysis. It 224 | # supports qualified module names, as well as Unix pattern matching. 225 | ignored-modules= 226 | 227 | # List of class names for which member attributes should not be checked (useful 228 | # for classes with dynamically set attributes). This supports the use of 229 | # qualified names. 230 | ignored-classes=optparse.Values,thread._local,_thread._local 231 | 232 | # List of members which are set dynamically and missed by pylint inference 233 | # system, and so shouldn't trigger E1101 when accessed. Python regular 234 | # expressions are accepted. 235 | generated-members= 236 | 237 | 238 | [FORMAT] 239 | 240 | # Maximum number of characters on a single line. 241 | max-line-length=80 242 | 243 | # TODO(https://github.com/pylint-dev/pylint/issues/3352): Direct pylint to exempt 244 | # lines made too long by directives to pytype. 245 | 246 | # Regexp for a line that is allowed to be longer than the limit. 247 | ignore-long-lines=(?x)( 248 | ^\s*(\#\ )?<?https?://\S+>?$| 249 | ^\s*(from\s+\S+\s+)?import\s+.+$) 250 | 251 | # Allow the body of an if to be on the same line as the test if there is no 252 | # else. 253 | single-line-if-stmt=yes 254 | 255 | # Maximum number of lines in a module 256 | max-module-lines=99999 257 | 258 | # String used as indentation unit. The internal Google style guide mandates 2 259 | # spaces. Google's externally-published style guide says 4, consistent with 260 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 261 | # projects (like TensorFlow). 262 | indent-string=' ' 263 | 264 | # Number of spaces of indent required inside a hanging or continued line. 265 | indent-after-paren=4 266 | 267 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 268 | expected-line-ending-format= 269 | 270 | 271 | [MISCELLANEOUS] 272 | 273 | # List of note tags to take in consideration, separated by a comma. 274 | notes=TODO 275 | 276 | 277 | [STRING] 278 | 279 | # This flag controls whether inconsistent-quotes generates a warning when the 280 | # character used as a quote delimiter is used inconsistently within a module. 281 | check-quote-consistency=yes 282 | 283 | 284 | [VARIABLES] 285 | 286 | # Tells whether we should check for unused import in __init__ files. 287 | init-import=no 288 | 289 | # A regular expression matching the name of dummy variables (i.e. expectedly 290 | # not used). 291 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 292 | 293 | # List of additional names supposed to be defined in builtins. Remember that 294 | # you should avoid to define new builtins when possible. 295 | additional-builtins= 296 | 297 | # List of strings which can identify a callback function by name. A callback 298 | # name must start or end with one of those strings. 299 | callbacks=cb_,_cb 300 | 301 | # List of qualified module names which can have objects that can redefine 302 | # builtins.
303 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 304 | 305 | 306 | [LOGGING] 307 | 308 | # Logging modules to check that the string format arguments are in logging 309 | # function parameter format 310 | logging-modules=logging,absl.logging,tensorflow.io.logging 311 | 312 | 313 | [SIMILARITIES] 314 | 315 | # Minimum lines number of a similarity. 316 | min-similarity-lines=4 317 | 318 | # Ignore comments when computing similarities. 319 | ignore-comments=yes 320 | 321 | # Ignore docstrings when computing similarities. 322 | ignore-docstrings=yes 323 | 324 | # Ignore imports when computing similarities. 325 | ignore-imports=no 326 | 327 | 328 | [SPELLING] 329 | 330 | # Spelling dictionary name. Available dictionaries: none. To make it working 331 | # install python-enchant package. 332 | spelling-dict= 333 | 334 | # List of comma separated words that should not be checked. 335 | spelling-ignore-words= 336 | 337 | # A path to a file that contains private dictionary; one word per line. 338 | spelling-private-dict-file= 339 | 340 | # Tells whether to store unknown words to indicated private dictionary in 341 | # --spelling-private-dict-file option instead of raising a message. 342 | spelling-store-unknown-words=no 343 | 344 | 345 | [IMPORTS] 346 | 347 | # Deprecated modules which should not be used, separated by a comma 348 | deprecated-modules=regsub, 349 | TERMIOS, 350 | Bastion, 351 | rexec, 352 | sets 353 | 354 | # Create a graph of every (i.e. internal and external) dependencies in the 355 | # given file (report RP0402 must not be disabled) 356 | import-graph= 357 | 358 | # Create a graph of external dependencies in the given file (report RP0402 must 359 | # not be disabled) 360 | ext-import-graph= 361 | 362 | # Create a graph of internal dependencies in the given file (report RP0402 must 363 | # not be disabled) 364 | int-import-graph= 365 | 366 | # Force import order to recognize a module as part of the standard 367 | # compatibility libraries. 368 | known-standard-library= 369 | 370 | # Force import order to recognize a module as part of a third party library. 371 | known-third-party=enchant, absl 372 | 373 | # Analyse import fallback blocks. This can be used to support both Python 2 and 374 | # 3 compatible code, which means that the block might have code that exists 375 | # only in one or another interpreter, leading to false positives when analysed. 376 | analyse-fallback-blocks=no 377 | 378 | 379 | [CLASSES] 380 | 381 | # List of method names used to declare (i.e. assign) instance attributes. 382 | defining-attr-methods=__init__, 383 | __new__, 384 | setUp 385 | 386 | # List of member names, which should be excluded from the protected access 387 | # warning. 388 | exclude-protected=_asdict, 389 | _fields, 390 | _replace, 391 | _source, 392 | _make 393 | 394 | # List of valid names for the first argument in a class method. 395 | valid-classmethod-first-arg=cls, 396 | class_ 397 | 398 | # List of valid names for the first argument in a metaclass class method. 
399 | valid-metaclass-classmethod-first-arg=mcs 400 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "workbench.colorTheme": "Default Light Modern" 3 | } -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. 
For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Python Auditory Toolbox 2 | 3 | This is a Python port of (portions of) the 4 | [Matlab Auditory Toolbox](https://engineering.purdue.edu/~malcolm/interval/1998-010/). 
5 | This package provides code built upon the 6 | [Numpy](https://numpy.org/doc/stable/index.html), 7 | [PyTorch](https://pytorch.org/), and 8 | [JAX](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) 9 | numerical libraries. 10 | 11 | The Python Auditory Toolbox includes these functions from the original Matlab toolbox: 12 | - Patterson-Holdsworth ERB (Gammatone) Filter Bank 13 | - MakeErbFilters 14 | - ErbFilterBank 15 | - Correlogram Processing 16 | - CorrelogramFrame 17 | - CorrelogramArray 18 | - CorrelogramPitch 19 | - Demonstrations 20 | - MakeVowel 21 | - FMPoints 22 | - Spectrogram 23 | 24 | This toolbox does not include Lyon's Passive Long-wave Cochlear model, as that model 25 | has been superseded by [CARFAC](https://github.com/google/carfac). 26 | 27 | All functions are available on top of any of these three computational libraries: 28 | [JAX](https://github.com/google/jax), 29 | [NumPy](https://numpy.org/) or 30 | [PyTorch](https://pytorch.org/). 31 | 32 | [This colab](https://colab.research.google.com/drive/1JGm24f1kOBl-EmtscJck58LGgWkfWGO8#scrollTo=1dB7di7Nv622) 33 | provides examples of calling (and testing) this library using the NumPy functionality. 34 | 35 | This toolbox can be used to build biophysically inspired models of the auditory periphery using JAX, 36 | PyTorch and NumPy. These implementations can hopefully be developed into more realistic models that better explain 37 | what changes as we optimize them to match different psychoacoustic tests. It may further be useful for developing 38 | auditory models such as those developed in Sarah Verhulst's 39 | ([Hearing Technology Lab on GitHub](https://github.com/HearingTechnology)) 40 | and Josh McDermott's 41 | ([Model Metamers Paper](https://www.nature.com/articles/s41593-023-01442-0)) labs. 42 | 43 | You can include the python_auditory_toolbox in your work in several ways. Via the Python package installer: 44 | 45 | pip install python_auditory_toolbox 46 | 47 | From GitHub at: 48 | 49 | https://github.com/MalcolmSlaney/python_auditory_toolbox 50 | 51 | Or see the toolbox in action (with pretty pictures) via Colab at: 52 | 53 | https://colab.research.google.com/drive/1JGm24f1kOBl-EmtscJck58LGgWkfWGO8?usp=sharing 54 | 55 | ## Note 56 | This package includes three different implementations of the auditory toolbox, and thus the union 57 | of the three implementations' import requirements. Most users will probably use only 58 | one of the three libraries (NumPy, JAX, or PyTorch), will only need to import one of the 59 | auditory_toolbox files, and will not need all the prerequisite libraries. 60 | 61 | Please cite this work as: 62 | 63 | Malcolm Slaney and Søren Fuglsang, Python Auditory Toolbox, 2023. https://github.com/MalcolmSlaney/python_auditory_toolbox. 64 | 65 | ## Examples 66 | Here are the frequency responses for a 10-channel ERB gammatone filterbank. 67 | 68 | ![Gammatone (ERB) Filter Response](python_auditory_toolbox/examples/GammatoneFilterResponse.png) 69 | 70 | Here is an example of a correlogram, computed for a number of harmonic signals 71 | that demonstrate the correlogram representation. 72 | 73 | See the animation via [YouTube](https://youtu.be/kTqhfxHPcVo) 74 | 75 | MFCC (mel-frequency cepstral coefficients) is a classic speech representation 76 | that was often used in (pre-DNN) speech recognizers. 77 | It converts the original spectrogram, shown here, 78 | 79 | ![Original tapestry spectrogram](python_auditory_toolbox/examples/TapestrySpectrogram.png) 80 | 81 | into a 40-channel filterbank.
And finally into a 13-dimensional cepstral representation. 82 | 83 | We can invert these steps to reconstruct the original filterbank representation: 84 | 85 | ![Reconstruction of filterbank representation](python_auditory_toolbox/examples/TapestryFilterbank.png) 86 | 87 | And then reconstruct the original spectrogram. 88 | 89 | ![Reconstruction of spectrogram](python_auditory_toolbox/examples/TapestryReconstruction.png) 90 | 91 | Note, in particular, that the pitch harmonics (the horizontal banding) have been 92 | filtered out by the cepstral processing. 93 | 94 | ## Examples: PyTorch 95 | The following code block demonstrates a feature extraction scheme that involves a 96 | 64-channel ERB gammatone filterbank. While the NumPy and JAX versions mimic the original 97 | Matlab API, the PyTorch version defines a class. The output features are shown below. 98 | 99 | ```python 100 | import torch 101 | import torchaudio 102 | import matplotlib.pyplot as plt 103 | import auditory_toolbox_torch as pat 104 | 105 | class CustomPipeline(torch.nn.Module): 106 | def __init__(self, sampling_rate: int = 16000) -> None: 107 | super().__init__() 108 | self.erbbank = pat.ErbFilterBank(sampling_rate,64,100) 109 | self.relu1 = torch.nn.ReLU() 110 | self.avgpool1 = torch.nn.AvgPool1d(80, stride=20) 111 | 112 | def forward(self, x: torch.Tensor) -> torch.Tensor: 113 | x = self.erbbank(x) 114 | x = self.relu1(x) 115 | x = self.avgpool1(x) 116 | x = torch.pow(x,0.3) 117 | return x 118 | 119 | wav, fs = torchaudio.load('./examples/tapestry.wav') 120 | 121 | pipeline = CustomPipeline(fs) 122 | pipeline.to(dtype=torch.float32) 123 | 124 | fig = plt.figure() 125 | plt.imshow(pipeline.forward(wav).squeeze(), aspect='auto', cmap='Blues') 126 | ``` 127 | ![Gammatone features](python_auditory_toolbox/examples/TapestryGammatoneFeatures.png) 128 | 129 | 130 | ## Authors 131 | Malcolm Slaney (malcolm@ieee.org) and 132 | Søren A. Fuglsang (sorenaf@drcmr.dk) 133 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The python_auditory_toolbox Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 Malcolm Slaney (malcolm@ieee.org) and 16 | # Søren A. Fuglsang (sorenaf@drcmr.dk) 17 | # 18 | # Licensed under the Apache License, Version 2.0 (the "License"); 19 | # you may not use this file except in compliance with the License. 20 | # You may obtain a copy of the License at 21 | # 22 | # http://www.apache.org/licenses/LICENSE-2.0 23 | # 24 | # Unless required by applicable law or agreed to in writing, software 25 | # distributed under the License is distributed on an "AS IS" BASIS, 26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 27 | # See the License for the specific language governing permissions and 28 | # limitations under the License.
29 | # ============================================================================== 30 | 31 | """Package init file for python_auditory_toolbox 32 | """ 33 | 34 | # No imports here, instead use 35 | # from python_auditory_toolbox import XXX 36 | # where XXX is the individual implementation name (auditory_toolbox, 37 | # auditory_toolbox_jax or auditory_toolbox_torch). 38 | -------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox.py: -------------------------------------------------------------------------------- 1 | """A python port of portions of the Matlab Auditory Toolbox. 2 | """ 3 | import math 4 | 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | from scipy import signal 8 | 9 | from typing import List 10 | 11 | 12 | def ErbSpace(low_freq: float = 100, high_freq: float = 44100/4, 13 | n: int = 100) -> np.ndarray: 14 | """This function computes an array of N frequencies uniformly spaced between 15 | high_freq and low_freq on an erb scale. N is set to 100 if not specified. 16 | 17 | See also linspace, logspace, and MakeErbFilters. 18 | 19 | For a definition of erb, see Moore, B. C. J., and Glasberg, B. R. (1983). 20 | "Suggested formulae for calculating auditory-filter bandwidths and 21 | excitation patterns," J. Acoust. Soc. Am. 74, 750-753. 22 | 23 | Args: 24 | low_freq: The center frequency in Hz of the lowest channel 25 | high_freq: The upper limit in Hz of the channel bank. The center frequency 26 | of the highest channel will be below this frequency. 27 | n: Number of channels 28 | 29 | Returns: 30 | An array of center frequencies, equally spaced on the ERB scale. 31 | """ 32 | 33 | # Change the following three parameters if you wish to use a different 34 | # erb scale. Must change in MakeErbFilters too. 35 | ear_q = 9.26449 # Glasberg and Moore Parameters 36 | min_bw = 24.7 37 | 38 | # All of the following expressions are derived in Apple TR #35, "An 39 | # Efficient Implementation of the Patterson-Holdsworth Cochlear 40 | # Filter Bank." See pages 33-34. 41 | cf_array = (-(ear_q*min_bw) + 42 | np.exp(np.arange(1, 1+n)* 43 | (-np.log(high_freq + ear_q*min_bw) + 44 | np.log(low_freq + ear_q*min_bw))/n) * (high_freq + 45 | ear_q*min_bw)) 46 | return cf_array 47 | 48 | 49 | def MakeErbFilters(fs: float, num_channels: int, 50 | low_freq: float = 20) -> List[np.ndarray]: 51 | """This function computes the filter coefficients for a bank of 52 | Gammatone filters. These filters were defined by Patterson and 53 | Holdsworth for simulating the cochlea. 54 | 55 | The result is returned as an array of filter coefficients. Each row 56 | of the filter arrays contains the coefficients for four second order 57 | filters. The transfer functions for these four filters share the same 58 | denominator (poles) but have different numerators (zeros). All of these 59 | coefficients are assembled into one vector that the ErbFilterBank 60 | can take apart to implement the filter. 61 | 62 | The filter bank contains "num_channels" channels that extend from 63 | half the sampling rate (fs) to "low_freq". Alternatively, if the num_channels 64 | input argument is a vector, then the values of this vector are taken to 65 | be the center frequency of each desired filter. (The low_freq argument is 66 | ignored in this case.) 67 | 68 | Note this implementation fixes a problem in the original code by 69 | computing four separate second order filters. This avoids a big
This avoids a big 70 | problem with round off errors in cases of very small cfs (100Hz) and 71 | large sample rates (44kHz). The problem is caused by roundoff error 72 | when a number of poles are combined, all very close to the unit 73 | circle. Small errors in the eigth order coefficient, are multiplied 74 | when the eigth root is taken to give the pole location. These small 75 | errors lead to poles outside the unit circle and instability. Thanks 76 | to Julius Smith for leading me to the proper explanation. 77 | 78 | Execute the following code to evaluate the frequency 79 | response of a 10 channel filterbank. 80 | n = 512 81 | fs = 16000 82 | fcoefs = pat.MakeErbFilters(16000,10,100) 83 | y = pat.ErbFilterBank(np.array([1.0] + [0] * (n-1), dtype=float), fcoefs) 84 | resp = 20*np.log10(np.abs(np.fft.fft(y, axis=1))).T 85 | freq_scale = np.expand_dims(np.linspace(0, 16000, 512), 1) 86 | plt.semilogx(freq_scale[:n//2, :], resp[:n//2, :]) 87 | plt.axis((100, fs/2, -60, 0)) 88 | plt.xlabel('Frequency (Hz)') 89 | plt.ylabel('Filter Response (dB)'); 90 | 91 | Args: 92 | fs: Sampling rate (in Hz) of the filterbank (needed to determine CFs). 93 | num_channel: How many channels in the filterbank. 94 | low_freq: The lowest center frequency of the filterbank. 95 | 96 | Returns: 97 | A list of 10 num_channel-D arrays containing the filter parameters. 98 | """ 99 | 100 | t = 1/fs 101 | if isinstance(num_channels, int): 102 | cf = ErbSpace(low_freq, fs/2, num_channels) 103 | else: 104 | cf = num_channels 105 | 106 | # So equations below match the original Matlab syntax 107 | pi = np.pi 108 | abs = np.abs # pylint: disable=redefined-builtin 109 | sqrt = np.sqrt 110 | sin = np.sin 111 | cos = np.cos 112 | exp = np.exp 113 | i = np.array([1j], dtype=np.csingle) 114 | 115 | # Change the follow_freqing three parameters if you wish to use a different 116 | # erb scale. Must change in ErbSpace too. 
117 | ear_q = 9.26449 # Glasberg and Moore Parameters 118 | min_bw = 24.7 119 | order = 1 120 | 121 | erb = ((cf/ear_q)**order + min_bw**order)**(1/order) 122 | b = 1.019*2*pi*erb 123 | 124 | a0 = t # Feedforward coefficients (zeros) 125 | a2 = 0 126 | b0 = 1 # Feedback coefficients (poles) 127 | b1 = -2*cos(2*cf*pi*t)/exp(b*t) 128 | b2 = exp(-2*b*t) 129 | 130 | a11 = -(2*t*cos(2*cf*pi*t)/exp(b*t) + 2*sqrt(3+2**1.5)*t*sin(2*cf*pi*t)/ 131 | exp(b*t))/2 132 | a12 = -(2*t*cos(2*cf*pi*t)/exp(b*t) - 2*sqrt(3+2**1.5)*t*sin(2*cf*pi*t)/ 133 | exp(b*t))/2 134 | a13 = -(2*t*cos(2*cf*pi*t)/exp(b*t) + 2*sqrt(3-2**1.5)*t*sin(2*cf*pi*t)/ 135 | exp(b*t))/2 136 | a14 = -(2*t*cos(2*cf*pi*t)/exp(b*t) - 2*sqrt(3-2**1.5)*t*sin(2*cf*pi*t)/ 137 | exp(b*t))/2 138 | 139 | gain = abs((-2*exp(4*i*cf*pi*t)*t + 140 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t* 141 | (cos(2*cf*pi*t) - sqrt(3 - 2**(3/2))* 142 | sin(2*cf*pi*t))) * 143 | (-2*exp(4*i*cf*pi*t)*t + 144 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t* 145 | (cos(2*cf*pi*t) + sqrt(3 - 2**(3/2)) * 146 | sin(2*cf*pi*t)))* 147 | (-2*exp(4*i*cf*pi*t)*t + 148 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t* 149 | (cos(2*cf*pi*t) - 150 | sqrt(3 + 2**(3/2))*sin(2*cf*pi*t))) * 151 | (-2*exp(4*i*cf*pi*t)*t + 2*exp(-(b*t) + 2*i*cf*pi*t)*t* 152 | (cos(2*cf*pi*t) + sqrt(3 + 2**(3/2))*sin(2*cf*pi*t))) / 153 | (-2 / exp(2*b*t) - 2*exp(4*i*cf*pi*t) + 154 | 2*(1 + exp(4*i*cf*pi*t))/exp(b*t))**4) 155 | 156 | allfilts = np.ones(len(cf)) 157 | fcoefs = [a0*allfilts, a11, a12, a13, a14, a2*allfilts, 158 | b0*allfilts, b1, b2, gain] 159 | return fcoefs 160 | 161 | 162 | def ErbFilterBank(x: np.ndarray, fcoefs: List[np.ndarray]) -> np.ndarray: 163 | """Filter an input signal with a filterbank, producing one output vector 164 | per channel. 165 | 166 | Args: 167 | x: The input signal, one-dimensional 168 | fcoefs: A list of 10 num_channels-dimensional arrays that describe the 169 | filterbank. 170 | 171 | Returns: 172 | num_channels outputs in a num_channels x time array. 173 | """ 174 | [a0, a11, a12, a13, a14, a2, b0, b1, b2, gain] = fcoefs 175 | n_chan = a0.shape[0] 176 | assert n_chan == a11.shape[0] 177 | assert n_chan == a12.shape[0] 178 | assert n_chan == a13.shape[0] 179 | assert n_chan == a14.shape[0] 180 | assert n_chan == b0.shape[0] 181 | assert n_chan == b1.shape[0] 182 | assert n_chan == b2.shape[0] 183 | assert n_chan == gain.shape[0] 184 | 185 | sos = np.stack((np.stack([a0/gain, a0, a0, a0], axis=1), 186 | np.stack([a11/gain, a12, a13, a14], axis=1), 187 | np.stack([a2/gain, a2, a2, a2], axis=1), 188 | np.stack([b0, b0, b0, b0], axis=1), 189 | np.stack([b1, b1, b1, b1], axis=1), 190 | np.stack([b2, b2, b2, b2], axis=1)), 191 | axis=2) 192 | 193 | all_y = None 194 | for c in range(n_chan): 195 | y = signal.sosfilt(sos[c, :, :], x) 196 | if all_y is None: 197 | all_y = np.zeros((n_chan, len(y)), dtype=x.dtype) 198 | all_y[c, :] = y 199 | return all_y 200 | 201 | 202 | def CorrelogramFrame(data: np.ndarray, pic_width: int, 203 | start: int = 0, win_len: int = 0) -> np.ndarray: 204 | """Generate one frame of a correlogram using FFTs to calculate autocorrelation. 205 | 206 | Args: 207 | data: A num_channels x time array of input waveforms, one time domain signal 208 | per channel. 209 | pic_width: Number of pixels (time lags) in the final correlogram frame. 210 | start: The starting sample 211 | win_len: How much data to take from the input signal when computing the 212 | autocorrelation. 213 | 214 | Returns: 215 | A two dimensional array, of size num_channels x pic_width, containing one 216 | frame of the correlogram.
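  Example (a sketch only; `filterbank_output` stands in for the
  num_channels x num_samples array returned by ErbFilterBank):
    frame = CorrelogramFrame(filterbank_output, pic_width=256,
                             start=0, win_len=512)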
217 | """ 218 | channels, data_len = data.shape 219 | if not win_len: 220 | win_len = data_len 221 | 222 | # Round up to double the window size, and then the next power of 2. 223 | fft_size = int(2**np.ceil(np.log2(2*max(pic_width, win_len)))) 224 | 225 | start = max(0, start) 226 | last = min(data_len, start+win_len) 227 | a = .54 228 | b = -.46 229 | wr = math.sqrt(64/256) 230 | phi = np.pi/win_len 231 | ws = 2*wr/np.sqrt(4*a*a+2*b*b)*( 232 | a + b*np.cos(2*np.pi*(np.arange(win_len))/win_len + phi)) 233 | 234 | f = np.zeros((channels, fft_size), dtype=data.dtype) 235 | f[:, :last-start] = data[:, start:last] * ws[:last-start] 236 | f = np.fft.fft(f, axis=1) 237 | f = np.fft.ifft(f*np.conj(f), axis=1) 238 | pic = np.maximum(0, np.real(f[:, :pic_width])) 239 | good_rows = np.logical_and( # Make sure first column is bigger than the rest. 240 | pic[:, 0] > 0, 241 | np.logical_and(pic[:, 0] > pic[:, 1], pic[:, 0] > pic[:, 2])) 242 | pic = np.where(np.expand_dims(good_rows, axis=-1), 243 | pic / np.tile(np.sqrt(pic[:, :1]), (1, pic_width)), 244 | np.array([0])) 245 | 246 | return pic 247 | 248 | 249 | def FMPoints(sample_len, freq, fm_freq=6, fm_amp=None, fs=22050): 250 | """Generate an impulse train corresponding to a vibrato. 251 | 252 | points=FMPoints(sample_len, freq, fm_freq, fm_amp, fs) 253 | Generates (fractional) sample locations for frequency-modulated impulses 254 | sample_len = number of samples 255 | freq = pitch frequency (Hz) 256 | fm_freq = vibrato frequency (Hz) (defaults to 6 Hz) 257 | fm_amp = max change in pitch (defaults to 5% of freq) 258 | fs = sample frequency (defaults to 22050 samples/s) 259 | 260 | Basic formula: phase angle = 2*pi*freq*t + 261 | (fm_amp/fm_freq)*sin(2*pi*fm_freq*t) 262 | k-th zero crossing approximately at sample number 263 | (fs/freq)*(k - (fm_amp/(2*pi*fm_freq))*sin(2*pi*k*(fm_freq/freq))) 264 | 265 | Args: 266 | sample_len: How much data to generate, in samples 267 | freq: Base frequency of the output signal (Hz) 268 | fm_freq: Vibrato frequency (in Hz) 269 | fm_amp: Magnitude of the FM deviation (in Hz) 270 | fs: Sample rate for the output signal. 271 | 272 | Returns: 273 | An array of fractional sample locations, one for each positive-going 274 | zero crossing of the phase function. 275 | """ 276 | if fm_amp is None: 277 | fm_amp = 0.05*freq 278 | 279 | kmax = int(freq*(sample_len/fs)) 280 | points = np.arange(kmax) 281 | points = (fs/freq)*(points-( 282 | fm_amp/(2*np.pi*fm_freq))*np.sin(2*np.pi*(fm_freq/freq)*points)) 283 | return points 284 | 285 | 286 | def MakeVowel(sample_len, pitch, sample_rate, f1=0, f2=0, f3=0, bw=50): 287 | """Synthesize an artificial vowel using formant filters. 288 | 289 | MakeVowel(sample_len, pitch [, sample_rate, f1, f2, f3]) - 290 | Make a vowel with 291 | "sample_len" samples and the given pitch. The sample rate is given by 292 | sample_rate (the Matlab original defaulted to 22254.545454 Hz, the 293 | native Macintosh sampling rate). The formant frequencies are f1, f2 & f3. Some common vowels are 294 | Vowel f1 f2 f3 295 | /a/ 730 1090 2440 296 | /i/ 270 2290 3010 297 | /u/ 300 870 2240 298 | 299 | The pitch variable can either be a scalar indicating the actual 300 | pitch frequency, or an array of impulse locations. Using an 301 | array of impulses allows this routine to compute vowels with 302 | varying pitch. 303 | 304 | Alternatively, f1 can be replaced with one of the following strings 305 | 'a', 'i', 'u' and the appropriate formant frequencies are 306 | automatically selected.
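  For example (a sketch): MakeVowel(10000, 100, 16000, 'a') synthesizes
  about 0.6 seconds (10000 samples at 16 kHz) of an /a/ vowel with a
  constant 100 Hz pitch.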
307 | 308 | Args: 309 | sample_len: How many samples to generate 310 | pitch: Either a single floating point value indicating a constant 311 | pitch (in Hz), or a train of impulses generated by FMPoints. 312 | sample_rate: The sample rate for the output signal (Hz) 313 | f1: Either a vowel spec, one of /a/, /i/, or /u/, or the frequency 314 | of the first formant. 315 | f2: Optional 2nd formant frequency (if f1 is not a vowel name) 316 | f3: Optional 3rd formant frequency (if f1 is not a vowel name) 317 | bw: Bandwidth of the formant filters 318 | 319 | Returns: 320 | A time domain waveform containing the synthetic vowel sound. 321 | """ 322 | if isinstance(f1, str): 323 | if f1 == 'a' or f1 == '/a/': 324 | f1, f2, f3 = (730, 1090, 2440) 325 | elif f1 == 'i' or f1 == '/i/': 326 | f1, f2, f3 = (270, 2290, 3010) 327 | elif f1 == 'u' or f1 == '/u/': 328 | f1, f2, f3 = (300, 870, 2240) 329 | 330 | 331 | # GlottalPulses(pitch, fs, sample_len) - Generate a stream of 332 | # glottal pulses with the given pitch (in Hz) and sampling 333 | # frequency (sample_rate). A vector of the requested length is returned. 334 | y = np.zeros(sample_len, float) 335 | if isinstance(pitch, (int, float)): 336 | points = np.arange(0, sample_len-1, sample_rate/pitch) 337 | else: 338 | points = np.sort(np.asarray(pitch)) 339 | points = points[points < sample_len-1] 340 | indices = np.floor(points).astype(int) 341 | 342 | # Use a triangular approximation to an impulse function. The important 343 | # part is to keep the total amplitude the same. 344 | y[indices] = (indices+1)-points 345 | y[indices+1] = points-indices 346 | 347 | # GlottalFilter(x,fs) - Filter an impulse train and simulate the glottal 348 | # transfer function. The sampling interval (sample_rate) is given in Hz. 349 | # The filtering performed by this function is two first-order filters 350 | # at 250Hz. 351 | a = np.exp(-250*2*np.pi/sample_rate) 352 | #y=filter([1,0,-1],[1,-2*a,a*a],y) # Not as good as one below.... 353 | y = signal.lfilter([1],[1,0,-a*a],y) 354 | 355 | # FormantFilter(input, f, fs) - Filter an input sequence to model one 356 | # formant in a speech signal. The formant frequency (in Hz) is given 357 | # by f and the bandwidth of the formant is given by bw (50 Hz by 358 | # default). The sampling frequency in Hz is given by fs. 359 | if f1 > 0: 360 | cft = f1/sample_rate 361 | q = f1/bw 362 | rho = np.exp(-np.pi * cft / q) 363 | theta = 2 * np.pi * cft * np.sqrt(1-1/(4 * q*q)) 364 | a2 = -2*rho*np.cos(theta) 365 | a3 = rho*rho 366 | y = signal.lfilter([1+a2+a3],[1,a2,a3],y) 367 | 368 | # FormantFilter(input, f, fs) - Filter an input sequence to model one 369 | # formant in a speech signal. The formant frequency (in Hz) is given 370 | # by f and the bandwidth of the formant is given by bw (50 Hz by 371 | # default). The sampling frequency in Hz is given by fs. 372 | if f2 > 0: 373 | cft = f2/sample_rate 374 | q = f2/bw 375 | rho = np.exp(-np.pi * cft / q) 376 | theta = 2 * np.pi * cft * np.sqrt(1-1/(4 * q*q)) 377 | a2 = -2*rho*np.cos(theta) 378 | a3 = rho*rho 379 | y = signal.lfilter([1+a2+a3],[1,a2,a3],y) 380 | 381 | # FormantFilter(input, f, fs) - Filter an input sequence to model one 382 | # formant in a speech signal. The formant frequency (in Hz) is given 383 | # by f and the bandwidth of the formant is given by bw (50 Hz by 384 | # default). The sampling frequency in Hz is given by fs.
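# (Each of these resonators is a two-pole section: poles at radius
# rho = exp(-pi*bw/fs) and angle theta of roughly 2*pi*f/fs. The
# numerator gain 1+a2+a3 makes the response at DC exactly one.)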
385 | if f3 > 0: 386 | cft = f3/sample_rate 387 | q = f3/bw 388 | rho = np.exp(-np.pi * cft / q) 389 | theta = 2 * np.pi * cft * np.sqrt(1-1/(4 * q*q)) 390 | a2 = -2*rho*np.cos(theta) 391 | a3 = rho*rho 392 | y = signal.lfilter([1+a2+a3],[1,a2,a3],y) 393 | return y 394 | 395 | 396 | def CorrelogramArray(data, sr=16000, frame_rate=12, width=256): 397 | """Generate an array of correlogram frames. 398 | 399 | Args: 400 | data: The filterbank's output, size num_channels x time. 401 | sr: The sample rate for the data (needed when computing the frame times) 402 | frame_rate: How often (in Hz) correlogram frames should be generated. 403 | width: The width (in lags) of the correlogram 404 | 405 | Returns: 406 | A num_frames x num_channels x width tensor of correlogram frames. 407 | """ 408 | _, sample_len = data.shape 409 | frame_increment = int(sr/frame_rate) 410 | frame_count = int((sample_len-width)/frame_increment) + 1 411 | movie = None 412 | for i in range(frame_count): 413 | start = i*frame_increment 414 | frame = CorrelogramFrame(data, width, start, frame_increment*4) 415 | if movie is None: 416 | movie = np.zeros((frame_count, frame.shape[0], 417 | frame.shape[1]), dtype=float) 418 | movie[i, :, :] = frame 419 | return movie 420 | 421 | def CorrelogramPitch(correlogram, width, sr=22254.54, 422 | low_pitch=0, high_pitch=20000): 423 | """Compute the summary of a correlogram to find the pitch. 424 | 425 | pitch=CorrelogramPitch(correlogram, width, sr, low_pitch, high_pitch) 426 | computes the pitch of a correlogram sequence by finding the time lag 427 | with the largest correlation energy. 428 | 429 | Args: 430 | correlogram: A 3D correlogram array, output from CorrelogramArray. 431 | num_frames x num_channels x num_times 432 | width: Width of the correlogram. Historical parameter. Should be 433 | equal to correlogram.shape[2] 434 | low_pitch: Lowest allowable pitch (Hz). Pitch peaks are only searched 435 | within the region low_pitch to high_pitch 436 | high_pitch: Highest allowable pitch (Hz). 437 | 438 | Returns: 439 | A 2-element tuple, containing 440 | 1) a one-dimensional array of length num_frames indicating the pitch 441 | or 0 if no pitch is found 442 | 2) A one-dimensional array indicating the pitch salience on a scale 443 | from 0 (no pitch found) to 1 (clear pitch). 444 | """ 445 | 446 | drop_low = int(sr/high_pitch) 447 | if low_pitch > 0: 448 | drop_high = int(min(width,math.ceil(sr/low_pitch))) 449 | else: 450 | drop_high = width 451 | 452 | frames = correlogram.shape[0] 453 | 454 | pitch = np.zeros(frames) 455 | salience = np.zeros(frames) 456 | for j in range(frames): 457 | # Get one frame from the correlogram and compute 458 | # the sum (as a function of time lag) across all channels. 459 | summary = np.sum(correlogram[j, :, :], axis=0) 460 | zero_lag = summary[0] 461 | # Now we need to find the first pitch past the peak at zero 462 | # lag. The following lines smooth the summary pitch a bit, then 463 | # look for the first point where the summary goes back up. 464 | # Everything up to this point is zeroed out.
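    # (The filter below is a window_length-point moving sum; its first
    # window_length difference values are start-up transients, so they are
    # cleared before searching for the first upward-going slope.)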
465 | window_length = 16 466 | sumfilt = signal.lfilter(np.ones(window_length), [1,], summary) 467 | sumdif = sumfilt[1:width] - sumfilt[:width-1] 468 | sumdif[:window_length] = 0 469 | valleys = np.argwhere(sumdif>0) 470 | summary[:int(valleys[0, 0])] = 0 471 | summary[1:drop_low] = 0 472 | summary[drop_high:] = 0 473 | # plt.plot(summary) 474 | # Now find the location of the biggest peak and call this the pitch 475 | p = np.argmax(summary) 476 | if p > 0: 477 | pitch[j] = sr/float(p) 478 | salience[j] = summary[p]/zero_lag 479 | 480 | return pitch, salience 481 | 482 | def Mfcc(input_signal, sampling_rate=16000, frame_rate=100, debug=False): 483 | """Mfcc - Mel frequency cepstrum coefficient analysis. 484 | 485 | Find the cepstral coefficients (ceps) corresponding to the 486 | input. 487 | 488 | Args: 489 | input_signal: The one-dimensional time-domain audio signal 490 | sampling_rate: The sample rate of the input in Hz. 491 | frame_rate: The desired output sampling rate 492 | debug: A debug flag that turns on various plots. 493 | 494 | Returns: 495 | A five-tuple consisting of: 496 | 1) The MFCC representation, a 13 x num_frames output. 497 | 2) The detailed fft magnitude (freqresp) used in MFCC calculation, 498 | 3) The mel-scale filter bank output (fb) 499 | 4) The filter bank output by inverting the cepstrals with a cosine 500 | transform (fbrecon), 501 | 5) The smooth frequency response by interpolating the fb reconstruction 502 | (freqrecon) 503 | 504 | Modified a bit to make testing an algorithm easier... 4/15/94 505 | Fixed Cosine Transform (indices of cos() were swapped) - 5/26/95 506 | Added optional frame_rate argument - 6/8/95 507 | Added proper filterbank reconstruction using inverse DCT - 10/27/95 508 | Added filterbank inversion to reconstruct spectrum - 11/1/95 509 | """ 510 | # Filter bank parameters 511 | lowest_frequency = 133.3333 512 | linear_filters = 13 513 | linear_spacing = 66.66666666 514 | log_filters = 27 515 | log_spacing = 1.0711703 516 | fft_size = 512 517 | cepstral_coefficients = 13 518 | window_size = 400 519 | window_size = 256 # Standard says 400, but 256 makes more sense 520 | # Really should be a function of the sample 521 | # rate (and the lowest_frequency) and the 522 | # frame rate. 523 | 524 | # Keep this around for later.... 525 | total_filters = linear_filters + log_filters 526 | 527 | # Now figure the band edges. Interesting frequencies are spaced 528 | # by linear_spacing for a while, then go logarithmic. First figure 529 | # all the interesting frequencies. Lower, center, and upper band 530 | # edges are all consecutive interesting frequencies. 531 | 532 | freqs = np.zeros(total_filters+2) 533 | freqs[:linear_filters] = (lowest_frequency + 534 | np.arange(linear_filters)*linear_spacing) 535 | freqs[linear_filters:total_filters+2] = ( 536 | freqs[linear_filters-1] * log_spacing**np.arange(1, log_filters+3)) 537 | # print('freqs:', freqs) 538 | lower = freqs[:total_filters] 539 | center = freqs[1:total_filters+1] 540 | upper = freqs[2:total_filters+2] 541 | 542 | # We now want to combine FFT bins so that each filter has unit 543 | # weight, assuming a triangular weighting function.
First figure
544 | # out the height of the triangle, then we can figure out each
545 | # frequency's contribution
546 | mfcc_filter_weights = np.zeros((total_filters,fft_size))
547 | triangle_height = 2/(upper-lower)
548 | fft_freqs = np.arange(fft_size)/fft_size*sampling_rate
549 |
550 | for chan in range(total_filters):
551 | mfcc_filter_weights[chan,:] = (
552 | np.logical_and(fft_freqs > lower[chan], fft_freqs <= center[chan]) *
553 | triangle_height[chan]*(fft_freqs-lower[chan])/(center[chan]-
554 | lower[chan]) +
555 | np.logical_and(fft_freqs > center[chan], fft_freqs < upper[chan]) *
556 | triangle_height[chan]*(upper[chan]-fft_freqs)/(upper[chan]-
557 | center[chan]))
558 |
559 | if debug:
560 | plt.semilogx(fft_freqs,mfcc_filter_weights.T)
561 | #axis([lower(1) upper(total_filters) 0 max(max(mfcc_filter_weights))])
562 |
563 | ham_window = 0.54 - 0.46*np.cos(2*np.pi*np.arange(window_size)/window_size)
564 |
565 | if False: # Window it like ComplexSpectrum # pylint: disable=using-constant-test
566 | window_step = sampling_rate/frame_rate
567 | a = .54
568 | b = -.46
569 | wr = np.sqrt(window_step/window_size)
570 | phi = np.pi/window_size
571 | ham_window = (2*wr/np.sqrt(4*a*a+2*b*b)*
572 | (a + b*np.cos(2*np.pi*np.arange(window_size)/window_size +
573 | phi)))
574 |
575 | # Figure out Discrete Cosine Transform. We want a matrix
576 | # dct(i,j) which is total_filters x cepstral_coefficients in size.
577 | # The i,j component is given by
578 | # cos( i * (j+0.5)/total_filters pi )
579 | # where we have assumed that i and j start at 0.
580 |
581 | cepstral_indices = np.reshape(np.arange(cepstral_coefficients),
582 | (cepstral_coefficients, 1))
583 | filter_indices = np.reshape(2*np.arange(0, total_filters)+1,
584 | (1, total_filters))
585 | cos_term = np.cos(np.matmul(cepstral_indices,
586 | filter_indices)*np.pi/2/total_filters)
587 |
588 | mfcc_dct_matrix = 1/np.sqrt(total_filters/2)*cos_term
589 | mfcc_dct_matrix[0,:] = mfcc_dct_matrix[0,:] * np.sqrt(2)/2
590 |
591 | if debug:
592 | plt.imshow(mfcc_dct_matrix)
593 | plt.xlabel('Filter Coefficient')
594 | plt.ylabel('Cepstral Coefficient')
595 |
596 | # Filter the input with the preemphasis filter. Also figure how
597 | # many columns of data we will end up with.
598 | if True: # pylint: disable=using-constant-test
599 | pre_emphasized = signal.lfilter([1, -.97], 1, input_signal)
600 | else:
601 | pre_emphasized = input_signal
602 |
603 | window_step = sampling_rate/frame_rate
604 | cols = int((len(input_signal)-window_size)/window_step)
605 |
606 | # Allocate all the space we need for the output arrays.
607 | ceps = np.zeros((cepstral_coefficients, cols))
608 | freqresp = np.zeros((fft_size//2, cols))
609 | fb = np.zeros((total_filters, cols))
610 |
611 | # Invert the filter bank center frequencies. For each FFT bin
612 | # we want to know the exact position in the filter bank to find
613 | # the original frequency response. The next block of code finds the
614 | # integer and fractional sampling positions.
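# (For instance, an FFT bin whose frequency lies 30% of the way from the
# center of filter 4 to the center of filter 5 is encoded as fr = 4.3;
# below, that splits into fri = 4 and frac = 0.3, and the spectral
# reconstruction interpolates linearly between those two filter outputs.)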
615 | if True: # pylint: disable=using-constant-test
616 | fr = np.arange(fft_size//2)/(fft_size/2)*sampling_rate/2
617 | j = 0
618 | for i in np.arange(fft_size//2):
619 | if fr[i] > center[j+1]:
620 | j = j + 1
621 | j = min(j, total_filters-2)
622 | # if j > total_filters-2:
623 | # j = total_filters-1
624 | fr[i] = min(total_filters-1-.0001,
625 | max(0,j + (fr[i]-center[j])/(center[j+1]-center[j])))
626 | fri = fr.astype(int)
627 | frac = fr - fri
628 |
629 | freqrecon = np.zeros((fft_size//2, cols))
630 | fbrecon = np.zeros((total_filters, cols))
631 | # Ok, now let's do the processing. For each chunk of data:
632 | # * Window the data with a hamming window,
633 | # * Shift it into FFT order,
634 | # * Find the magnitude of the fft,
635 | # * Convert the fft data into filter bank outputs,
636 | # * Find the log base 10,
637 | # * Find the cosine transform to reduce dimensionality.
638 | for start in np.arange(cols):
639 | first = round(start*window_step) # Round added by Malcolm
640 | last = round(first + window_size)
641 | fft_data = np.zeros(fft_size)
642 | fft_data[:window_size] = pre_emphasized[first:last]*ham_window
643 | fft_mag = np.abs(np.fft.fft(fft_data))
644 | ear_mag = np.log10(np.matmul(mfcc_filter_weights, fft_mag.T))
645 |
646 | ceps[:,start] = np.matmul(mfcc_dct_matrix, ear_mag)
647 | freqresp[:,start] = fft_mag[:fft_size//2].T
648 | fb[:,start] = ear_mag
649 | fbrecon[:,start] = np.matmul(mfcc_dct_matrix[:cepstral_coefficients,:].T,
650 | ceps[:,start])
651 |
652 | if True: # pylint: disable=using-constant-test
653 | f10 = 10**fbrecon[:,start]
654 | freqrecon[:,start] = sampling_rate/fft_size * (f10[fri]*(1-frac) +
655 | f10[fri+1]*frac)
656 |
657 | # OK, just to check things, let's also reconstruct the original FB
658 | # output. We do this by multiplying the cepstral data by the transpose
659 | # of the original DCT matrix. This all works because we were careful to
660 | # scale the DCT matrix so it was orthonormal.
661 | fbrecon = np.matmul(mfcc_dct_matrix[:cepstral_coefficients,:].T, ceps)
662 | return ceps,freqresp,fb,fbrecon,freqrecon
663 |
664 |
665 | def Spectrogram(wave: np.ndarray,
666 | segsize: int = 128,
667 | nlap: int = 8,
668 | ntrans: int = 4,
669 | normalize: bool = False) -> np.ndarray:
670 | """
671 | Compute a pretty spectrogram. Pre-emphasize the audio to preserve the high
672 | frequencies, and normalize the result using fourth-root compression to more
673 | closely match human perception (both auditory and visual). Original algorithm
674 | by Richard F. Lyon.
675 |
676 | Args:
677 | wave: The one-dimensional signal
678 | segsize: How much of the signal to consider for each output frame
679 | nlap: The number of Hamming windows overlapping a point
680 | ntrans: The factor by which the transform is bigger than the segment
681 | normalize: Whether to normalize & compress the output for display.
682 | Returns:
683 | A spectrogram 'array' with fourth root of amplitude, filter smoothed and
684 | formatted for display.
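Example (a minimal sketch; the 900 Hz test tone and these parameter
values are illustrative choices, mirroring the unit tests):
fs = 22050
tone = np.sin(2*np.pi*900*np.arange(fs)/fs)
spec = Spectrogram(tone, segsize=128, nlap=8, ntrans=4, normalize=True)
# spec.shape == (256, 1371): rows are frequency bins, columns are frames.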
685 | """
686 | wave = signal.lfilter([1, -0.95], [1], wave)
687 |
688 | s = len(wave)
689 | nsegs = math.floor(s/(segsize/nlap) - nlap + 1)
690 | array = np.zeros((ntrans//2*segsize, nsegs))
691 | window = 0.54-0.46*np.cos(2*np.pi/(segsize+1)*(np.arange(segsize)))
692 | for i in range(nsegs):
693 | seg = np.zeros(ntrans*segsize) # leave half full of zeroes
694 | start = i*segsize//nlap
695 | stop = start + segsize
696 | piece = wave[start:stop]
697 | seg[:segsize] = window*piece
698 | seg = np.abs(np.fft.fft(seg))
699 | array[:, i] = seg[:array.shape[0]] # seg[ntrans/2*segsize:ntrans*segsize]
700 | if normalize:
701 | # compress with square root of amplitude (fourth root of power)
702 | off = 0.0001*np.max(array) # low end stabilization offset,
703 | array = (off+array)**0.25-off**0.25 # better than a threshold hack!
704 | array = 255/np.max(array)*array
705 | return array
706 |
-------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_comparison_test.py: --------------------------------------------------------------------------------
1 | """ Direct comparison of the three implementations."""
2 | from absl.testing import absltest
3 | import torch
4 | import torch.fft
5 | import numpy as np
6 | import jax.numpy as jnp
7 |
8 | import auditory_toolbox_torch as pat_torch
9 | import auditory_toolbox as pat_np
10 | import auditory_toolbox_jax as pat_jax
11 |
12 | class AuditoryToolboxTests(absltest.TestCase):
13 |
14 | """Compare the three different auditory toolbox implementations."""
15 |
16 | def test_erbfilter_fcoefs(self):
17 | def compare_fcoefs(num_chan,fs,low_freq):
18 | fcoefs_np = pat_np.MakeErbFilters(fs, num_chan, low_freq)
19 | fcoefs_jax = pat_jax.MakeErbFilters(fs, num_chan, low_freq)
20 | model_torch = pat_torch.ErbFilterBank(fs, num_chan, low_freq)
21 |
22 | # Compare fcoefs
23 | fcoefs_np = np.c_[fcoefs_np].T
24 | fcoefs_jax = np.asarray(jnp.c_[fcoefs_jax]).T
25 | fcoefs_torch = torch.cat(model_torch.fcoefs,dim=1).numpy()
26 |
27 | np.testing.assert_almost_equal(fcoefs_np, fcoefs_torch)
28 | np.testing.assert_almost_equal(fcoefs_jax, fcoefs_np, decimal=6)
29 | np.testing.assert_almost_equal(fcoefs_jax, fcoefs_torch, decimal=6)
30 |
31 | compare_fcoefs(10, 16000, 60)
32 | compare_fcoefs(64, 16000, 60)
33 | compare_fcoefs(128, 16000, 60)
34 | compare_fcoefs(128, 44100, 60)
35 |
36 | def test_erbfilter_output(self):
37 | def compare_erbfilterbank_output(input_vec, num_chan, fs, low_freq):
38 | x = input_vec
39 | fcoefs_np = pat_np.MakeErbFilters(fs, num_chan, low_freq)
40 | fcoefs_jax = pat_jax.MakeErbFilters(fs, num_chan, low_freq)
41 | model_torch = pat_torch.ErbFilterBank(fs, num_chan,low_freq)
42 |
43 | y_np = pat_np.ErbFilterBank(x, fcoefs_np)
44 | y_jax = np.asarray(pat_jax.ErbFilterBank(x, fcoefs_jax))
45 | y_torch = model_torch(torch.tensor(x).unsqueeze(0)).numpy().squeeze()
46 |
47 | # Compare output
48 | np.testing.assert_almost_equal(y_np, y_torch)
49 | np.testing.assert_almost_equal(y_np, y_jax, decimal=4)
50 | np.testing.assert_almost_equal(y_torch, y_jax, decimal=4)
51 |
52 | # Simulation 1
53 | x = np.zeros(512)
54 | x[0] = 1
55 | compare_erbfilterbank_output(x, 10, 16000, 100)
56 | compare_erbfilterbank_output(x, 64, 16000, 100)
57 | compare_erbfilterbank_output(x, 128, 44100, 100)
58 |
59 | # Simulation 2
60 | compare_erbfilterbank_output(np.random.randn(10000), 10, 16000, 100)
61 |
62 | # Simulation 3
63 | sample_len = 20000
64 | sample_rate = 22254
65 | pitch_center = 120
66 | x =
pat_np.MakeVowel(sample_len, pat_np.FMPoints(sample_len, pitch_center),
67 | sample_rate, 'u')
68 | compare_erbfilterbank_output(x, 10, sample_rate, 100)
69 | compare_erbfilterbank_output(x, 64, sample_rate, 100)
70 | compare_erbfilterbank_output(x, 128, sample_rate, 100)
71 |
72 | def test_fmpoints(self):
73 | def compare_fm_points_outputs(sample_len, freq, fm_freq, fm_amp, fs):
74 | points_np = pat_np.FMPoints(sample_len, freq,
75 | fm_freq, fm_amp, fs)
76 | points_jax = pat_jax.FMPoints(sample_len, freq,
77 | fm_freq, fm_amp, fs)
78 | points_torch = pat_torch.fm_points(sample_len, freq,
79 | fm_freq, fm_amp, fs)
80 | points_jax = np.asarray(points_jax)
81 | points_torch = points_torch.squeeze().numpy()
82 | np.testing.assert_almost_equal(points_np, points_torch)
83 | np.testing.assert_almost_equal(points_np, points_jax, decimal=3)
84 | np.testing.assert_almost_equal(points_jax, points_torch, decimal=3)
85 |
86 | base_pitch = 160
87 | sample_rate = 16000
88 | fmfreq = 10
89 | fmamp = 20
90 | sample_len = 10000
91 | compare_fm_points_outputs(sample_len,base_pitch,
92 | fmfreq,fmamp,sample_rate)
93 |
94 | base_pitch = 160
95 | sample_rate = 16000
96 | fmfreq = 100
97 | fmamp = 20
98 | sample_len = 10000
99 | compare_fm_points_outputs(sample_len,base_pitch,
100 | fmfreq,fmamp,sample_rate)
101 |
102 | base_pitch = 560
103 | sample_rate = 16000
104 | fmfreq = 50
105 | fmamp = 20
106 | sample_len = 10000
107 | compare_fm_points_outputs(sample_len,base_pitch,
108 | fmfreq,fmamp,sample_rate)
109 |
110 | def test_correlogram_array(self):
111 | def compare_correlogram_output(input_vec, num_chan, fs, low_freq,
112 | frame_width):
113 | x = input_vec
114 | fcoefs_np = pat_np.MakeErbFilters(fs, num_chan, low_freq)
115 | fcoefs_jax = pat_jax.MakeErbFilters(fs, num_chan, low_freq)
116 | model_torch = pat_torch.ErbFilterBank(fs,num_chan,low_freq)
117 |
118 | y_np = pat_np.ErbFilterBank(x, fcoefs_np)
119 | y_jax = pat_jax.ErbFilterBank(x, fcoefs_jax)
120 | y_torch = model_torch(torch.tensor(x).unsqueeze(0))
121 |
122 |
123 | y_np = pat_np.CorrelogramFrame(y_np, frame_width)
124 | y_jax = pat_jax.CorrelogramFrame(y_jax, frame_width)
125 | y_torch = pat_torch.correlogram_frame(y_torch, frame_width)
126 | y_torch = y_torch.squeeze().numpy()
127 | y_jax = np.asarray(y_jax)
128 | # Compare output
129 | np.testing.assert_almost_equal(y_np,y_torch)
130 | np.testing.assert_almost_equal(y_np,y_jax,decimal=3)
131 | np.testing.assert_almost_equal(y_torch,y_jax,decimal=3)
132 |
133 | # Simulation 1
134 | x = np.zeros(512)
135 | x[0] = 1
136 | compare_correlogram_output(x, 10, 16000, 100, 256)
137 | compare_correlogram_output(x, 64, 16000, 100, 256)
138 | compare_correlogram_output(x, 128, 44100, 100, 256)
139 | compare_correlogram_output(x, 10, 16000, 100, 512)
140 | compare_correlogram_output(x, 64, 16000, 100, 512)
141 | compare_correlogram_output(x, 128, 44100, 100, 512)
142 |
143 | # Simulation 2
144 | sample_len = 20000
145 | sample_rate = 22254
146 | pitch_center = 120
147 | x = pat_np.MakeVowel(sample_len,
148 | pat_np.FMPoints(sample_len, pitch_center),
149 | sample_rate,
150 | 'u')
151 | compare_correlogram_output(x, 10, 16000, 100, 256)
152 | compare_correlogram_output(x, 64, 16000, 100, 256)
153 | compare_correlogram_output(x, 128, 44100, 100, 256)
154 | compare_correlogram_output(x, 10, 16000, 100, 512)
155 | compare_correlogram_output(x, 64, 16000, 100, 512)
156 | compare_correlogram_output(x, 128, 44100, 100, 512)
157 |
158 |
159 | if __name__ == '__main__':
160 | absltest.main()
161 |
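# A note on the tolerances used above: the jax comparisons use looser
# decimal targets because jax computes in float32 by default, while the
# numpy and torch comparisons hold to the default, tighter tolerance.
# If closer agreement is ever needed, one option (an aside, not required
# by these tests) is to enable jax's 64-bit mode before any arrays are
# created:
#
# import jax
# jax.config.update('jax_enable_x64', True)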
-------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_jax.py: --------------------------------------------------------------------------------
1 | """A JAX port of portions of the Matlab Auditory Toolbox.
2 | """
3 | import math
4 |
5 | import matplotlib.pyplot as plt
6 | import jax
7 | import jax.numpy as jnp
8 |
9 | from typing import List, Union
10 |
11 |
12 | def ErbSpace(low_freq: float = 100, high_freq: float = 44100/4,
13 | n: int = 100) -> jnp.ndarray:
14 | """This function computes an array of N frequencies uniformly spaced between
15 | high_freq and low_freq on an erb scale. N is set to 100 if not specified.
16 |
17 | See also linspace, logspace, MakeErbCoeffs, MakeErbFilters.
18 |
19 | For a definition of erb, see Moore, B. C. J., and Glasberg, B. R. (1983).
20 | "Suggested formulae for calculating auditory-filter bandwidths and
21 | excitation patterns," J. Acoust. Soc. Am. 74, 750-753.
22 |
23 | Args:
24 | low_freq: The center frequency in Hz of the lowest channel
25 | high_freq: The upper limit in Hz of the channel bank. The center frequency
26 | of the highest channel will be below this frequency.
27 | n: Number of channels
28 |
29 | Returns:
30 | An array of center frequencies, equally spaced on the ERB scale.
31 | """
32 |
33 | # Change the following three parameters if you wish to use a different
34 | # erb scale. Must change in MakeErbCoeffs too.
35 | ear_q = 9.26449 # Glasberg and Moore Parameters
36 | min_bw = 24.7
37 |
38 | # All of the following expressions are derived in Apple TR #35, "An
39 | # Efficient Implementation of the Patterson-Holdsworth Cochlear
40 | # Filter Bank." See pages 33-34.
41 | cf_array = (-(ear_q*min_bw) +
42 | jnp.exp(jnp.arange(1, 1+n)*
43 | (-jnp.log(high_freq + ear_q*min_bw) +
44 | jnp.log(low_freq + ear_q*min_bw))/n) * (high_freq +
45 | ear_q*min_bw))
46 | return cf_array
47 |
48 |
49 | def MakeErbFilters(fs: float, num_channels: int,
50 | low_freq:float = 20) -> List[jnp.ndarray]:
51 | """This function computes the filter coefficients for a bank of
52 | Gammatone filters. These filters were defined by Patterson and
53 | Holdsworth for simulating the cochlea.
54 |
55 | The result is returned as an array of filter coefficients. Each row
56 | of the filter arrays contains the coefficients for four second order
57 | filters. The transfer function for these four filters share the same
58 | denominator (poles) but have different numerators (zeros). All of these
59 | coefficients are assembled into one vector that the ErbFilterBank
60 | can take apart to implement the filter.
61 |
62 | The filter bank contains "num_channels" channels that extend from
63 | half the sampling rate (fs) to "low_freq". Alternatively, if the num_channels
64 | input argument is a vector, then the values of this vector are taken to
65 | be the center frequency of each desired filter. (The low_freq argument is
66 | ignored in this case.)
67 |
68 | Note this implementation fixes a problem in the original code by
69 | computing four separate second order filters. This avoids a big
70 | problem with round off errors in cases of very small cfs (100Hz) and
71 | large sample rates (44kHz). The problem is caused by roundoff error
72 | when a number of poles are combined, all very close to the unit
73 | circle. Small errors in the eighth order coefficient are multiplied
74 | when the eighth root is taken to give the pole location.
These small
75 | errors lead to poles outside the unit circle and instability. Thanks
76 | to Julius Smith for leading me to the proper explanation.
77 |
78 | Execute the following code to evaluate the frequency
79 | response of a 10 channel filterbank.
80 | n = 512
81 | fs = 16000
82 | fcoefs = pat.MakeErbFilters(16000,10,100)
83 | y = pat.ErbFilterBank(jnp.array([1.0] + [0] * (n-1), dtype=float), fcoefs)
84 | resp = 20*jnp.log10(jnp.abs(jnp.fft.fft(y, axis=1))).T
85 | freq_scale = jnp.expand_dims(jnp.linspace(0, 16000, 512), 1)
86 | plt.semilogx(freq_scale[:n//2, :], resp[:n//2, :])
87 | plt.axis((100, fs/2, -60, 0))
88 | plt.xlabel('Frequency (Hz)')
89 | plt.ylabel('Filter Response (dB)');
90 |
91 | Args:
92 | fs: Sampling rate (in Hz) of the filterbank (needed to determine CFs).
93 | num_channels: How many channels in the filterbank.
94 | low_freq: The lowest center frequency of the filterbank.
95 |
96 | Returns:
97 | A list of 10 num_channel-D arrays containing the filter parameters.
98 | """
99 |
100 | t = 1/fs
101 | if isinstance(num_channels, int):
102 | cf = ErbSpace(low_freq, fs/2, num_channels)
103 | else:
104 | cf = num_channels
105 |
106 | # So equations below match the original Matlab syntax
107 | pi = jnp.pi
108 | abs = jnp.abs # pylint: disable=redefined-builtin
109 | sqrt = jnp.sqrt
110 | sin = jnp.sin
111 | cos = jnp.cos
112 | exp = jnp.exp
113 | i = jnp.array([1j], dtype=jnp.csingle)
114 |
115 | # Change the following three parameters if you wish to use a different
116 | # erb scale. Must change in ErbSpace too.
117 | ear_q = 9.26449 # Glasberg and Moore Parameters
118 | min_bw = 24.7
119 | order = 1
120 |
121 | erb = ((cf/ear_q)**order + min_bw**order)**(1/order)
122 | b=1.019*2*pi*erb
123 |
124 | a0 = t # Feedback coefficients (poles)
125 | a2 = 0
126 | b0 = 1 # Feedforward coefficients (zeros)
127 | b1 = -2*cos(2*cf*pi*t)/exp(b*t)
128 | b2 = exp(-2*b*t)
129 |
130 | a11 = -(2*t*cos(2*cf*pi*t)/exp(b*t) + 2*sqrt(3+2**1.5)*t*sin(2*cf*pi*t)/
131 | exp(b*t))/2
132 | a12 = -(2*t*cos(2*cf*pi*t)/exp(b*t) - 2*sqrt(3+2**1.5)*t*sin(2*cf*pi*t)/
133 | exp(b*t))/2
134 | a13 = -(2*t*cos(2*cf*pi*t)/exp(b*t) + 2*sqrt(3-2**1.5)*t*sin(2*cf*pi*t)/
135 | exp(b*t))/2
136 | a14 = -(2*t*cos(2*cf*pi*t)/exp(b*t) - 2*sqrt(3-2**1.5)*t*sin(2*cf*pi*t)/
137 | exp(b*t))/2
138 |
139 | gain = abs((-2*exp(4*i*cf*pi*t)*t +
140 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t*
141 | (cos(2*cf*pi*t) - sqrt(3 - 2**(3/2))*
142 | sin(2*cf*pi*t))) *
143 | (-2*exp(4*i*cf*pi*t)*t +
144 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t*
145 | (cos(2*cf*pi*t) + sqrt(3 - 2**(3/2)) *
146 | sin(2*cf*pi*t)))*
147 | (-2*exp(4*i*cf*pi*t)*t +
148 | 2*exp(-(b*t) + 2*i*cf*pi*t)*t*
149 | (cos(2*cf*pi*t) -
150 | sqrt(3 + 2**(3/2))*sin(2*cf*pi*t))) *
151 | (-2*exp(4*i*cf*pi*t)*t + 2*exp(-(b*t) + 2*i*cf*pi*t)*t*
152 | (cos(2*cf*pi*t) + sqrt(3 + 2**(3/2))*sin(2*cf*pi*t))) /
153 | (-2 / exp(2*b*t) - 2*exp(4*i*cf*pi*t) +
154 | 2*(1 + exp(4*i*cf*pi*t))/exp(b*t))**4)
155 |
156 | allfilts = jnp.ones(len(cf))
157 | fcoefs = [a0*allfilts, a11, a12, a13, a14, a2*allfilts,
158 | b0*allfilts, b1, b2, gain]
159 | return fcoefs
160 |
161 |
162 | def ErbFilterBank(x: jnp.ndarray,
163 | fcoefs: Union[jnp.ndarray, List[jnp.ndarray]]) -> jnp.ndarray:
164 | """Filter an input signal with a filterbank, producing one output vector
165 | per channel.
166 |
167 | Args:
168 | x: The input signal, one-dimensional
169 | fcoefs: A list of 10 num-channel-dimensional arrays that describe the
170 | filterbank. Alternatively, this can be a 10 x num_channel array.
171 |
172 | Returns:
173 | num-channel outputs in a num_channel x time array.
174 | """
175 | # Take apart the list of coefficients, or split apart the single array.
176 | [a0, a11, a12, a13, a14, a2, b0, b1, b2, gain] = fcoefs
177 | n_chan = a0.shape[0]
178 | assert n_chan == a11.shape[0]
179 | assert n_chan == a12.shape[0]
180 | assert n_chan == a13.shape[0]
181 | assert n_chan == a14.shape[0]
182 | assert n_chan == b0.shape[0]
183 | assert n_chan == b1.shape[0]
184 | assert n_chan == b2.shape[0]
185 | assert n_chan == gain.shape[0]
186 |
187 | sos = jnp.stack((jnp.stack([a0/gain, a0, a0, a0], axis=1),
188 | jnp.stack([a11/gain, a12, a13, a14], axis=1),
189 | jnp.stack([a2/gain, a2, a2, a2], axis=1),
190 | jnp.stack([b0, b0, b0, b0], axis=1),
191 | jnp.stack([b1, b1, b1, b1], axis=1),
192 | jnp.stack([b2, b2, b2, b2], axis=1)),
193 | axis=2)
194 |
195 | def ErbKernel(f):
196 | return SosFilt(f, x)
197 |
198 | return jax.vmap(ErbKernel, in_axes=0)(sos)
199 |
200 |
201 | def CorrelogramFrame(data: jnp.ndarray, pic_width: int,
202 | start: int = 0, win_len: int = 0) -> jnp.ndarray:
203 | """Generate one frame of a correlogram using FFTs to calculate autocorrelation.
204 |
205 | Args:
206 | data: A num_channel x time array of input waveforms, one time domain signal
207 | per channel.
208 | pic_width: Number of pixels (time lags) in the final correlogram frame.
209 | start: The starting sample
210 | win_len: How much data to take from the input signal when computing the
211 | autocorrelation.
212 |
213 | Returns:
214 | A two dimensional array, of size num_channels x pic_width, containing one
215 | frame of the correlogram.
216 | """
217 | _, data_len = data.shape
218 | if not win_len:
219 | win_len = data_len
220 |
221 | # Round up to double the window size, and then the next power of 2.
222 | fft_size = int(2**jnp.ceil(jnp.log2(2*max(pic_width, win_len))))
223 |
224 | start = max(0, start)
225 | last = min(data_len, start+win_len)
226 | a = .54
227 | b = -.46
228 | wr = math.sqrt(64/256)
229 | phi = jnp.pi/win_len
230 | ws = 2*wr/jnp.sqrt(4*a*a+2*b*b)*(
231 | a + b*jnp.cos(2*jnp.pi*(jnp.arange(win_len))/win_len + phi))
232 |
233 | f = jnp.hstack((data[:, start:last] * ws[:last-start],
234 | jnp.zeros((data.shape[0], fft_size - (last-start)))))
235 | f = jnp.fft.fft(f, axis=1)
236 | f = jnp.fft.ifft(f*jnp.conj(f), axis=1)
237 | pic = jnp.maximum(0, jnp.real(f[:, :pic_width]))
238 | good_rows = jnp.logical_and( # Make sure first column is bigger than the rest.
239 | pic[:, 0] > 0,
240 | jnp.logical_and(pic[:, 0] > pic[:, 1], pic[:, 0] > pic[:, 2]))
241 | pic = jnp.where(jnp.expand_dims(good_rows, axis=-1),
242 | pic / jnp.tile(jnp.sqrt(pic[:, :1]), (1, pic_width)),
243 | jnp.array([0]))
244 |
245 | return pic
246 |
247 |
248 | def FMPoints(sample_len, freq, fm_freq=6, fm_amp=None, fs=22050):
249 | """Generate impulse train corresponding to a vibrato.
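For example (illustrative numbers only): FMPoints(22050, 100) returns
pulse locations spaced roughly 220.5 samples apart, with the spacing
wobbling by up to 5% of the pitch (the default fm_amp) at the default
6 Hz vibrato rate.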
250 |
251 | points=FMPoints(sample_len, freq, fm_freq, fm_amp, fs)
252 | Generates (fractional) sample locations for frequency-modulated impulses
253 | sample_len = number of samples
254 | freq = pitch frequency (Hz)
255 | fm_freq = vibrato frequency (Hz) (defaults to 6 Hz)
256 | fm_amp = max change in pitch (defaults to 5% of freq)
257 | fs = sample frequency (defaults to 22050 samples/s)
258 |
259 | Basic formula: phase angle = 2*pi*freq*t +
260 | (fm_amp/fm_freq)*sin(2*pi*fm_freq*t)
261 | k-th zero crossing approximately at sample number
262 | (fs/freq)*(k - (fm_amp/(2*pi*fm_freq))*sin(2*pi*k*(fm_freq/freq)))
263 |
264 | Args:
265 | sample_len: How much data to generate, in samples
266 | freq: Base frequency of the output signal (Hz)
267 | fm_freq: Vibrato frequency (in Hz)
268 | fm_amp: Magnitude of the FM deviation (in Hz)
269 | fs: Sample rate for the output signal.
270 |
271 | Returns:
272 | An impulse train, indicating the positive-going zero crossing
273 | of the phase function.
274 | """
275 | if fm_amp is None:
276 | fm_amp = 0.05*freq
277 |
278 | kmax = int(freq*(sample_len/fs))
279 | points = jnp.arange(kmax)
280 | points = (fs/freq)*(points-(
281 | fm_amp/(2*jnp.pi*fm_freq))*jnp.sin(2*jnp.pi*(fm_freq/freq)*points))
282 | return points
283 |
284 |
285 | def MakeVowel(sample_len, pitch, sample_rate, f1=0, f2=0, f3=0, bw=50):
286 | """Synthesize an artificial vowel using formant filters.
287 |
288 | MakeVowel(sample_len, pitch [, sample_rate, f1, f2, f3]) -
289 | Make a vowel with
290 | "sample_len" samples and the given pitch. The sample rate defaults to
291 | be 22254.545454 Hz (the native Macintosh Sampling Rate). The
292 | formant frequencies are f1, f2 & f3. Some common vowels are
293 | Vowel f1 f2 f3
294 | /a/ 730 1090 2440
295 | /i/ 270 2290 3010
296 | /u/ 300 870 2240
297 |
298 | The pitch variable can either be a scalar indicating the actual
299 | pitch frequency, or an array of impulse locations. Using an
300 | array of impulses allows this routine to compute vowels with
301 | varying pitch.
302 |
303 | Alternatively, f1 can be replaced with one of the following strings
304 | 'a', 'i', 'u' and the appropriate formant frequencies are
305 | automatically selected.
306 |
307 | Args:
308 | sample_len: How many samples to generate
309 | pitch: Either a single floating point value indicating a constant
310 | pitch (in Hz), or a train of impulses generated by FMPoints.
311 | sample_rate: The sample rate for the output signal (Hz)
312 | f1: Either a vowel spec, one of /a/, /i/, or /u/, or the frequency
313 | of the first formant.
314 | f2: Optional 2nd formant frequency (if f1 is not a vowel name)
315 | f3: Optional 3rd formant frequency (if f1 is not a vowel name)
316 | bw: Width of the formant filter (in Hz)
317 |
318 | Returns:
319 | A time domain waveform containing the synthetic vowel sound.
320 | """
321 | if isinstance(f1, str):
322 | if f1 == 'a' or f1 == '/a/':
323 | f1, f2, f3 = (730, 1090, 2440)
324 | elif f1 == 'i' or f1 == '/i/':
325 | f1, f2, f3 = (270, 2290, 3010)
326 | elif f1 == 'u' or f1 == '/u/':
327 | f1, f2, f3 = (300, 870, 2240)
328 |
329 |
330 | # GlottalPulses(pitch, fs, sample_len) - Generate a stream of
331 | # glottal pulses with the given pitch (in Hz) and sampling
332 | # frequency (sample_rate). A vector of the requested length is returned.
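# (For example, in the code below a pulse that should land at sample 220.5
# contributes 0.5 at sample 220 and 0.5 at sample 221, so the pulse
# amplitude is preserved even though the true location is fractional.)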
333 | y = jnp.zeros(sample_len, float)
334 | if isinstance(pitch, (int, float)):
335 | points = jnp.arange(0, sample_len-1, sample_rate/pitch)
336 | else:
337 | points = jnp.sort(jnp.asarray(pitch))
338 | points = points[points < sample_len-1]
339 | indices = jnp.floor(points).astype(int)
340 |
341 | # Use a triangular approximation to an impulse function. The important
342 | # part is to keep the total amplitude the same.
343 | y = y.at[indices].set((indices+1)-points)
344 | y = y.at[indices+1].set(points-indices)
345 |
346 | # GlottalFilter(x,fs) - Filter an impulse train and simulate the glottal
347 | # transfer function. The sampling interval (sample_rate) is given in Hz.
348 | # The filtering performed by this function is two first-order filters
349 | # at 250Hz.
350 | a = jnp.exp(-250*2*jnp.pi/sample_rate)
351 | #y=filter([1,0,-1],[1,-2*a,a*a],y) # Not as good as one below....
352 | y = SignalFilter([1, 0, 0],[1,0,-a*a],y)
353 |
354 | # FormantFilter(input, f, fs) - Filter an input sequence to model one
355 | # formant in a speech signal. The formant frequency (in Hz) is given
356 | # by f and the bandwidth of the formant is a constant 50Hz. The
357 | # sampling frequency in Hz is given by fs.
358 | if f1 > 0:
359 | cft = f1/sample_rate
360 | q = f1/bw
361 | rho = jnp.exp(-jnp.pi * cft / q)
362 | theta = 2 * jnp.pi * cft * jnp.sqrt(1-1/(4 * q*q))
363 | a2 = -2*rho*jnp.cos(theta)
364 | a3 = rho*rho
365 | y=SignalFilter([1+a2+a3, 0, 0],[1,a2,a3],y)
366 |
367 | # FormantFilter(input, f, fs) - Filter an input sequence to model one
368 | # formant in a speech signal. The formant frequency (in Hz) is given
369 | # by f and the bandwidth of the formant is a constant 50Hz. The
370 | # sampling frequency in Hz is given by fs.
371 | if f2 > 0:
372 | cft = f2/sample_rate
373 | q = f2/bw
374 | rho = jnp.exp(-jnp.pi * cft / q)
375 | theta = 2 * jnp.pi * cft * jnp.sqrt(1-1/(4 * q*q))
376 | a2 = -2*rho*jnp.cos(theta)
377 | a3 = rho*rho
378 | y= SignalFilter([1+a2+a3, 0, 0],[1,a2,a3],y)
379 |
380 | # FormantFilter(input, f, fs) - Filter an input sequence to model one
381 | # formant in a speech signal. The formant frequency (in Hz) is given
382 | # by f and the bandwidth of the formant is a constant 50Hz. The
383 | # sampling frequency in Hz is given by fs.
384 | if f3 > 0:
385 | cft = f3/sample_rate
386 | q = f3/bw
387 | rho = jnp.exp(-jnp.pi * cft / q)
388 | theta = 2 * jnp.pi * cft * jnp.sqrt(1-1/(4 * q*q))
389 | a2 = -2*rho*jnp.cos(theta)
390 | a3 = rho*rho
391 | y= SignalFilter([1+a2+a3, 0, 0],[1,a2,a3],y)
392 | return y
393 |
394 |
395 | def CorrelogramArray(data: jnp.ndarray, sr: float = 16000,
396 | frame_rate: int = 12, width: int = 256) -> jnp.ndarray:
397 | """Generate an array of correlogram frames.
398 |
399 | Args:
400 | data: The filterbank's output, size num_channel x time.
401 | sr: The sample rate for the data (needed when computing the frame times)
402 | frame_rate: How often (in Hz) correlogram frames should be generated.
403 | width: The width (in lags) of the correlogram
404 |
405 | Returns:
406 | A num_frames x num_channels x width tensor of correlogram frames.
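Example (a sketch that mirrors the unit tests; u is assumed to be a
one-dimensional time-domain signal at 22254 Hz):
coch = ErbFilterBank(u, MakeErbFilters(22254, 100, 60))
movie = CorrelogramArray(coch, sr=22254, frame_rate=50, width=256)
# movie.shape == (num_frames, 100, 256)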
407 | """
408 | _, sample_len = data.shape
409 | frame_increment = int(sr/frame_rate)
410 | frame_count = int((sample_len-width)/frame_increment) + 1
411 | movie = []
412 | for i in range(frame_count):
413 | start = i*frame_increment
414 | frame = CorrelogramFrame(data, width, start, frame_increment*4)
415 | movie.append(frame)
416 | return jnp.asarray(movie)
417 |
418 | def CorrelogramPitch(correlogram, width, sr=22254.54,
419 | low_pitch=0, high_pitch=20000):
420 | """Compute the summary of a correlogram to find the pitch.
421 |
422 | pitch=CorrelogramPitch(correlogram, width, sr, low_pitch, high_pitch)
423 | computes the pitch of a correlogram sequence by finding the time lag
424 | with the largest correlation energy.
425 |
426 | Args:
427 | correlogram: A 3D correlogram array, output from CorrelogramArray.
428 | num_frames x num_channels x num_times
429 | width: Width of the correlogram. Historical parameter. Should be
430 | equal to correlogram.shape[2]
431 | low_pitch: Lowest allowable pitch (Hz). Pitch peaks are only searched
432 | within the region low_pitch to high_pitch
433 | high_pitch: Highest allowable pitch (Hz).
434 |
435 | Returns:
436 | A 2-element tuple, containing
437 | 1) a one-dimensional array of length num_frames indicating the pitch
438 | or 0 if no pitch is found
439 | 2) A one-dimensional array indicating the pitch salience on a scale
440 | from 0 (no pitch found) to 1 (clear pitch).
441 | """
442 | assert correlogram.ndim == 3
443 | width = correlogram.shape[2] # Someday remove this unneeded parameter
444 |
445 | freqs = sr/jnp.arange(width) # Pitch frequency of each lag bin
446 | valid_pitch_lags = jnp.logical_and(freqs > low_pitch,
447 | freqs < high_pitch)
448 |
449 | pitch = []
450 | salience = []
451 | for j in range(correlogram.shape[0]):
452 | # Get one frame from the correlogram and compute
453 | # the sum (as a function of time lag) across all channels.
454 | summary = jnp.sum(correlogram[j, :, :], axis=0)
455 | zero_lag = summary[0]
456 | # Now we need to find the first pitch past the peak at zero
457 | # lag. The following lines smooth the summary pitch a bit, then
458 | # look for the first point where the summary goes back up.
459 | # Everything up to this point is zeroed out.
460 | window_length = 16
461 | sumfilt = jnp.convolve(summary, jnp.ones(window_length)/window_length,
462 | 'same')
463 |
464 | # Find the local maximums in the filtered summary correlogram.
465 | local_peak = jnp.logical_and(sumfilt[1:-1] > sumfilt[0:-2],
466 | sumfilt[1:-1] > sumfilt[2:])
467 | local_peak = jnp.hstack((0, local_peak, 0))
468 | peaks = jnp.where(jnp.logical_and(local_peak,
469 | valid_pitch_lags),
470 | summary,
471 | 0*summary)
472 | # Now find the location of the biggest peak and call this the pitch
473 | p = jnp.argmax(peaks)
474 | pitch.append(jnp.where(p > 0,
475 | freqs[p],
476 | 0))
477 | salience.append(summary[p]/zero_lag)
478 |
479 | return jnp.array(pitch), jnp.array(salience)
480 |
481 |
482 | def Mfcc(input_signal, sampling_rate=16000, frame_rate=100, debug=False):
483 | """Mfcc - Mel frequency cepstrum coefficient analysis.
484 |
485 | Find the cepstral coefficients (ceps) corresponding to the
486 | input.
487 |
488 | Args:
489 | input_signal: The one-dimensional time-domain audio signal
490 | sampling_rate: The sample rate of the input in Hz.
491 | frame_rate: The desired output sampling rate
492 | debug: A debug flag that turns on various plots.
493 |
494 | Returns:
495 | A five-tuple consisting of:
496 | 1) The MFCC representation, a 13 x num_frames output.
497 | 2) The detailed fft magnitude (freqresp) used in MFCC calculation,
498 | 3) The mel-scale filter bank output (fb)
499 | 4) The filter bank output by inverting the cepstrals with a cosine
500 | transform (fbrecon),
501 | 5) The smooth frequency response by interpolating the fb reconstruction
502 | (freqrecon)
503 |
504 | Modified a bit to make testing an algorithm easier... 4/15/94
505 | Fixed Cosine Transform (indices of cos() were swapped) - 5/26/95
506 | Added optional frame_rate argument - 6/8/95
507 | Added proper filterbank reconstruction using inverse DCT - 10/27/95
508 | Added filterbank inversion to reconstruct spectrum - 11/1/95
509 | """
510 | # Filter bank parameters
511 | lowest_frequency = 133.3333
512 | linear_filters = 13
513 | linear_spacing = 66.66666666
514 | log_filters = 27
515 | log_spacing = 1.0711703
516 | fft_size = 512
517 | cepstral_coefficients = 13
518 | window_size = 400
519 | window_size = 256 # Standard says 400, but 256 makes more sense
520 | # Really should be a function of the sample
521 | # rate (and the lowest_frequency) and the
522 | # frame rate.
523 |
524 | # Keep this around for later....
525 | total_filters = linear_filters + log_filters
526 |
527 | # Now figure the band edges. Interesting frequencies are spaced
528 | # by linear_spacing for a while, then go logarithmic. First figure
529 | # all the interesting frequencies. Lower, center, and upper band
530 | # edges are all consecutive interesting frequencies.
531 |
532 | linear_freqs = (lowest_frequency +
533 | jnp.arange(linear_filters)*linear_spacing)
534 | log_freqs = linear_freqs[-1] * log_spacing**jnp.arange(1,
535 | log_filters+3)
536 | freqs = jnp.hstack((linear_freqs, log_freqs))
537 | lower = freqs[:total_filters]
538 | center = freqs[1:total_filters+1]
539 | upper = freqs[2:total_filters+2]
540 |
541 | # We now want to combine FFT bins so that each filter has unit
542 | # weight, assuming a triangular weighting function. First figure
543 | # out the height of the triangle, then we can figure out each
544 | # frequency's contribution
545 | mfcc_filter_weights = jnp.zeros((total_filters,fft_size))
546 | triangle_height = 2/(upper-lower)
547 | fft_freqs = jnp.arange(fft_size)/fft_size*sampling_rate
548 |
549 | for chan in range(total_filters):
550 | mfcc_filter_weights = mfcc_filter_weights.at[chan,:].set(
551 | jnp.logical_and(fft_freqs > lower[chan], fft_freqs <= center[chan]) *
552 | triangle_height[chan]*(fft_freqs-lower[chan])/(center[chan]-
553 | lower[chan]) +
554 | jnp.logical_and(fft_freqs > center[chan], fft_freqs < upper[chan]) *
555 | triangle_height[chan]*(upper[chan]-fft_freqs)/(upper[chan]-
556 | center[chan]))
557 |
558 | if debug:
559 | plt.semilogx(fft_freqs,mfcc_filter_weights.T)
560 | #axis([lower(1) upper(total_filters) 0 max(max(mfcc_filter_weights))])
561 |
562 | ham_window = 0.54 - 0.46*jnp.cos(2*jnp.pi*jnp.arange(window_size)/window_size)
563 |
564 | if False: # Window it like ComplexSpectrum # pylint: disable=using-constant-test
565 | window_step = sampling_rate/frame_rate
566 | a = .54
567 | b = -.46
568 | wr = jnp.sqrt(window_step/window_size)
569 | phi = jnp.pi/window_size
570 | ham_window = (2*wr/jnp.sqrt(4*a*a+2*b*b)*
571 | (a + b*jnp.cos(2*jnp.pi*jnp.arange(window_size)/window_size +
572 | phi)))
573 |
574 | # Figure out Discrete Cosine Transform. We want a matrix
575 | # dct(i,j) which is total_filters x cepstral_coefficients in size.
576 | # The i,j component is given by
577 | # cos( i * (j+0.5)/total_filters pi )
578 | # where we have assumed that i and j start at 0.
579 |
580 | cepstral_indices = jnp.reshape(jnp.arange(cepstral_coefficients),
581 | (cepstral_coefficients, 1))
582 | filter_indices = jnp.reshape(2*jnp.arange(0, total_filters)+1,
583 | (1, total_filters))
584 | cos_term = jnp.cos(jnp.matmul(cepstral_indices,
585 | filter_indices)*jnp.pi/2/total_filters)
586 |
587 | mfcc_dct_matrix = 1/jnp.sqrt(total_filters/2)*cos_term
588 | mfcc_dct_matrix = mfcc_dct_matrix.at[0,:].set(mfcc_dct_matrix[0,:] *
589 | jnp.sqrt(2)/2)
590 |
591 | if debug:
592 | plt.imshow(mfcc_dct_matrix)
593 | plt.xlabel('Filter Coefficient')
594 | plt.ylabel('Cepstral Coefficient')
595 |
596 | # Filter the input with the preemphasis filter. Also figure how
597 | # many columns of data we will end up with.
598 | if True: # pylint: disable=using-constant-test
599 | pre_emphasized = SignalFilter([1, -.97, 0], [1, 0, 0], input_signal)
600 | else:
601 | pre_emphasized = input_signal
602 |
603 | window_step = sampling_rate/frame_rate
604 | cols = int((len(input_signal)-window_size)/window_step)
605 |
606 | # Allocate lists that will accumulate the output arrays.
607 | ceps = []
608 | freqresp = []
609 | fb = []
610 |
611 | # Invert the filter bank center frequencies. For each FFT bin
612 | # we want to know the exact position in the filter bank to find
613 | # the original frequency response. The next block of code finds the
614 | # integer and fractional sampling positions.
615 | fr = jnp.arange(fft_size//2)/(fft_size/2)*sampling_rate/2
616 | j = 0
617 | for i in jnp.arange(fft_size//2):
618 | if fr[i] > center[j+1]:
619 | j = j + 1
620 | j = min(j, total_filters-2)
621 | # if j > total_filters-2:
622 | # j = total_filters-1
623 | fr = fr.at[i].set(min(total_filters-1-.0001,
624 | max(0,j + (fr[i]-center[j])/(center[j+1]-
625 | center[j]))))
626 | fri = fr.astype(int)
627 | frac = fr - fri
628 |
629 | freqrecon = []
630 | fbrecon = []
631 | # Ok, now let's do the processing. For each chunk of data:
632 | # * Window the data with a hamming window,
633 | # * Shift it into FFT order,
634 | # * Find the magnitude of the fft,
635 | # * Convert the fft data into filter bank outputs,
636 | # * Find the log base 10,
637 | # * Find the cosine transform to reduce dimensionality.
638 | for start in jnp.arange(cols):
639 | first = round(start*window_step) # Round added by Malcolm
640 | last = round(first + window_size)
641 | fft_data = jnp.zeros(fft_size)
642 | fft_data = fft_data.at[:window_size].set(pre_emphasized[first:last]*
643 | ham_window)
644 | fft_mag = jnp.abs(jnp.fft.fft(fft_data))
645 | ear_mag = jnp.log10(jnp.matmul(mfcc_filter_weights, fft_mag.T))
646 |
647 | ceps.append(jnp.expand_dims(jnp.matmul(mfcc_dct_matrix, ear_mag), axis=-1))
648 | freqresp.append(jnp.expand_dims(fft_mag[:fft_size//2].T, axis=-1))
649 | fb.append(jnp.expand_dims(ear_mag, axis=-1))
650 | fbrecon.append(jnp.matmul(mfcc_dct_matrix[:cepstral_coefficients,:].T,
651 | ceps[-1]))
652 |
653 | f10 = 10**fbrecon[-1]
654 | recon = sampling_rate/fft_size * (f10[fri, 0]*(1-frac) +
655 | f10[fri+1, 0]*frac)
656 | freqrecon.append(jnp.expand_dims(recon, axis=-1))
657 | ceps = jnp.hstack(ceps)
658 | freqresp = jnp.hstack(freqresp)
659 | fb = jnp.hstack(fb)
660 | fbrecon = jnp.hstack(fbrecon)
661 | freqrecon = jnp.hstack(freqrecon)
662 |
663 | # OK, just to check things, let's also reconstruct the original FB
664 | # output.
We do this by multiplying the cepstral data by the transpose
665 | # of the original DCT matrix. This all works because we were careful to
666 | # scale the DCT matrix so it was orthonormal.
667 | fbrecon = jnp.matmul(mfcc_dct_matrix[:cepstral_coefficients,:].T, ceps)
668 | return ceps, freqresp, fb, fbrecon, freqrecon
669 |
670 |
671 | @jax.jit
672 | def FilterScan(carry, x, a, b):
673 | """Internal function needed for jax.lax.scan."""
674 | vzm1, vzm2 = carry
675 | v = x - a[1]*vzm1 - a[2]*vzm2
676 | y = b[0] * v + b[1]*vzm1 + b[2]*vzm2
677 | vzm2 = vzm1 # v delayed by two samples
678 | vzm1 = v # v delayed by one sample
679 | carry = vzm1, vzm2
680 | return carry, y
681 |
682 |
683 | @jax.jit
684 | def SignalFilter(b, a, x):
685 | """Reimplement the filter function scipy.signal.lfilter. This version only
686 | does second-order sections, and always filters over the last dimension."""
687 | b = jnp.asarray(b) / a[0]
688 | a = jnp.asarray(a) / a[0]
689 |
690 | # Define a scan function with the filter parameters.
691 | def FilterKernel(carry, x):
692 | return FilterScan(carry, x, a, b)
693 |
694 | _, y = jax.lax.scan(FilterKernel, (0.0, 0.0), x)
695 | return y
696 |
697 |
698 | @jax.jit
699 | def SosFilt(sos, x):
700 | """Reimplement the sosfilt function from scipy.signal. This version
701 | only filters over the last dimension.
702 | """
703 | stages, six = sos.shape
704 | assert six == 6
705 |
706 | for i in range(stages):
707 | x = SignalFilter(sos[i, :3], sos[i, 3:], x)
708 | return x
709 |
-------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_jax_test.py: --------------------------------------------------------------------------------
1 | """Code to test the auditory toolbox."""
2 | from absl.testing import absltest
3 | import jax.numpy as jnp
4 | import numpy as np # For testing
5 | import scipy
6 | import matplotlib.pyplot as plt
7 |
8 | import auditory_toolbox_jax as pat
9 |
10 |
11 | class AuditoryToolboxTests(absltest.TestCase):
12 | """Test cases for auditory toolbox."""
13 | def test_erb_space(self):
14 | low_freq = 100.0
15 | high_freq = 44100/4.0
16 | num_channels = 100
17 | cf_array = pat.ErbSpace(low_freq = low_freq, high_freq = high_freq,
18 | n = num_channels)
19 | self.assertLen(cf_array, num_channels)
20 | # Make sure low and high CF's are where we expect them to be.
21 | self.assertAlmostEqual(cf_array[-1], low_freq, delta=0.001)
22 | self.assertLess(cf_array[0], high_freq)
23 |
24 | def test_make_erb_filters(self):
25 | # Ten channel ERB Filterbank. Make sure return has the right size.
26 | # Will test coefficients when we test the filterbank.
27 | fs = 16000
28 | low_freq = 100
29 | num_chan = 10
30 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq)
31 | self.assertLen(fcoefs, 10)
32 |
33 | # Test all the filter coefficient array shapes
34 | a0, a11, a12, a13, a14, a2, b0, b1, b2, gain = fcoefs
35 | self.assertEqual(a0.shape, (num_chan,))
36 | self.assertEqual(a11.shape, (num_chan,))
37 | self.assertEqual(a12.shape, (num_chan,))
38 | self.assertEqual(a13.shape, (num_chan,))
39 | self.assertEqual(a14.shape, (num_chan,))
40 | self.assertEqual(a2.shape, (num_chan,))
41 | self.assertEqual(b0.shape, (num_chan,))
42 | self.assertEqual(b1.shape, (num_chan,))
43 | self.assertEqual(b2.shape, (num_chan,))
44 | self.assertEqual(gain.shape, (num_chan,))
45 |
46 |
47 | def test_erb_filterbank(self):
48 | fs = 16000
49 | low_freq = 100
50 | num_chan = 10
51 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq)
52 |
53 | impulse_len = 512
54 | x = jnp.hstack((jnp.ones(1), jnp.zeros(impulse_len-1)))
55 |
56 | y = pat.ErbFilterBank(x, fcoefs)
57 | self.assertEqual(y.shape, (num_chan, impulse_len))
58 | self.assertAlmostEqual(np.max(y), 0.10657410, delta=0.01)
59 |
60 | resp = 20*jnp.log10(jnp.abs(jnp.fft.fft(y.T, axis=0)))
61 |
62 | # Test to make sure spectral peaks are in the right place for each channel
63 | matlab_peak_locs = [184, 132, 94, 66, 46, 32, 21, 14, 8, 4]
64 | python_peak_locs = jnp.argmax(resp[:impulse_len//2], axis=0)
65 |
66 | # Add one to python locs because Matlab arrays start at 1
67 | self.assertEqual(matlab_peak_locs, list(python_peak_locs+1))
68 |
69 | # Test using a single array for the fcoefs
70 | fcoefs_array = jnp.stack(fcoefs)
71 | self.assertEqual(fcoefs_array.shape, (10, 10))
72 | y = pat.ErbFilterBank(x, fcoefs_array)
73 | self.assertEqual(y.shape, (num_chan, impulse_len))
74 | self.assertAlmostEqual(np.max(y), 0.10657410, delta=0.01)
75 |
76 | def test_erb_filterbank_example(self):
77 | """Just to make sure the example code keeps working."""
78 | n = 512
79 | fs = 16000
80 | fcoefs = pat.MakeErbFilters(16000,10,100)
81 | y = pat.ErbFilterBank(jnp.array([1.0] + [0] * (n-1), dtype=float), fcoefs)
82 | resp = 20*jnp.log10(jnp.abs(jnp.fft.fft(y, axis=1))).T
83 | freq_scale = jnp.expand_dims(jnp.linspace(0, 16000, 512), 1)
84 | plt.semilogx(freq_scale[:n//2, :], resp[:n//2, :])
85 | plt.axis((100, fs/2, -60, 0))
86 | plt.xlabel('Frequency (Hz)')
87 | plt.ylabel('Filter Response (dB)')
88 |
89 | def test_correlogram_array(self):
90 | def local_peaks(x):
91 | i = np.argwhere(np.logical_and(x[:-2] < x[1:-1],
92 | x[2:] < x[1:-1])) + 1
93 | return [j[0] for j in i]
94 |
95 | test_impulses = jnp.zeros((1,1024))
96 | for i in range(0, test_impulses.shape[1], 100):
97 | test_impulses = test_impulses.at[:, i].set(1)
98 | test_frame = pat.CorrelogramFrame(test_impulses, 256)
99 | self.assertEqual(list(jnp.where(test_frame > 0.1)[1]),
100 | [0, 100, 200])
101 |
102 | # Now test with cochlear input to correlogram
103 | impulse_len = 512
104 | fs = 16000
105 | low_freq = 100
106 | num_chan = 64
107 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq)
108 |
109 | # Make harmonic input signal
110 | s = 0
111 | pitch_lag = 200
112 | for h in range(1, 10):
113 | s = s + jnp.sin(2*jnp.pi*jnp.arange(impulse_len)/pitch_lag*h)
114 |
115 | y = pat.ErbFilterBank(s, fcoefs)
116 | frame_width = 256
117 | frame = pat.CorrelogramFrame(y, frame_width)
118 | self.assertEqual(frame.shape, (num_chan, frame_width))
119 | self.assertGreaterEqual(jnp.min(frame), 0.0)
120 |
121 | # Make sure the top channels have no output.
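# (The harmonic stack above tops out at 9 x 80 Hz = 720 Hz, while the
# first rows of the filterbank have center frequencies well above 1 kHz,
# so those channels should carry almost no energy.)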
122 | spectral_profile = np.sum(frame, 1) 123 | no_output = np.where(spectral_profile < 2) 124 | np.testing.assert_equal(no_output[0], np.arange(31)) 125 | 126 | # Make sure we have spectral peaks at the right locations 127 | spectral_peaks = local_peaks(spectral_profile) 128 | self.assertEqual(spectral_peaks, [42, 44, 46, 48, 50, 53, 56, 60]) 129 | 130 | # Make sure the first peak (after 0 lag) is at the pitch lag 131 | summary_correlogram = jnp.sum(frame, 0) 132 | skip_lags = 100 133 | self.assertEqual(np.argmax(summary_correlogram[skip_lags:]) + skip_lags, 134 | pitch_lag) 135 | 136 | def test_correlogram_pitch(self): 137 | sample_len = 20000 138 | sample_rate = 22254 139 | pitch_center = 120 140 | u = pat.MakeVowel(sample_len, pat.FMPoints(sample_len, pitch_center), 141 | sample_rate, 'u') 142 | 143 | low_freq = 60 144 | num_chan = 100 145 | fcoefs = pat.MakeErbFilters(sample_rate, num_chan, low_freq) 146 | coch = pat.ErbFilterBank(u, fcoefs) 147 | cor = pat.CorrelogramArray(coch,sample_rate,50,256) 148 | [pitch,sal] = pat.CorrelogramPitch(cor, 256, sample_rate,100,200) 149 | 150 | # Make sure center and overall pitch deviation are as expected. 151 | self.assertAlmostEqual(jnp.mean(pitch), pitch_center, delta=2) 152 | self.assertAlmostEqual(jnp.min(pitch), pitch_center-6, delta=2) 153 | self.assertAlmostEqual(jnp.max(pitch), pitch_center+6, delta=2) 154 | np.testing.assert_array_less(0.8, sal[:40]) 155 | 156 | # Now test salience when we add noise 157 | n = np.random.randn(sample_len) * np.arange(sample_len)/sample_len 158 | un=u + n/4 159 | 160 | low_freq = 60 161 | num_chan = 100 162 | fcoefs = pat.MakeErbFilters(sample_rate, num_chan, low_freq) 163 | coch= pat.ErbFilterBank(un, fcoefs) 164 | cor = pat.CorrelogramArray(coch,sample_rate,50,256) 165 | [pitch,sal] = pat.CorrelogramPitch(cor,256,22254,100,200) 166 | 167 | lr = scipy.stats.linregress(range(len(sal)), y=sal, alternative='less') 168 | self.assertAlmostEqual(lr.slope, -0.012, delta=0.01) # Probabilistic data, 169 | self.assertAlmostEqual(lr.rvalue, -0.963, delta=0.03) # so be tolerant. 170 | 171 | def test_mfcc(self): 172 | # Put a tone into MFCC and make sure it's in the right 173 | # spot in the reconstruction. 174 | sample_rate = 16000.0 175 | f0 = 2000 176 | tone = jnp.sin(2*jnp.pi*f0*jnp.arange(4000)/sample_rate) 177 | [_,_,_,_,freqrecon]= pat.Mfcc(tone,sample_rate,100) 178 | 179 | fft_size = 512 # From the MFCC source code 180 | self.assertEqual(f0/sample_rate*fft_size, 181 | jnp.argmax(jnp.sum(freqrecon, axis=1))) 182 | 183 | def test_fm_points (self): 184 | base_pitch = 160 185 | sample_rate = 16000 186 | fmfreq = 10 187 | fmamp = 20 188 | points = pat.FMPoints(100000, base_pitch, fmfreq, fmamp, 16000) 189 | 190 | # Make sure the average glottal pulse locations is 1 over the pitch 191 | d_points = points[1:] - points[:-1] 192 | self.assertAlmostEqual(jnp.mean(d_points), sample_rate/base_pitch, delta=1) 193 | 194 | # Make sure the frequency deviation is as expected. 195 | # ToDo(malcolm): Test the deviation, it's not right! 
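# A sketch of what such a test could look like (not enabled, since the
# deviation does not yet match expectations): with fmamp = 20 Hz around a
# 160 Hz pitch at 16 kHz, the pulse spacing should swing between roughly
# 16000/180 ~= 89 and 16000/140 ~= 114 samples, so one could assert that
# min(d_points) and max(d_points) land near those bounds.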
196 | 197 | def test_make_vowel(self): 198 | def local_peaks(x): 199 | i = jnp.argwhere(jnp.logical_and(x[:-2] < x[1:-1], 200 | x[2:] < x[1:-1])) + 1 201 | return jnp.array([j[0] for j in i]) 202 | 203 | test_seq = local_peaks(jnp.array([1,2,3,2,1,1,2,2,3,4,1])) 204 | self.assertEqual(list(test_seq), [2, 9]) 205 | 206 | def vowel_peaks(vowel): 207 | """Synthesize a vowel and find the frequencies of the spectral peaks""" 208 | sample_rate = 16000 209 | vowel = pat.MakeVowel(1024, [1,], sample_rate, vowel) 210 | spectrum = 20*jnp.log10(jnp.abs(jnp.fft.fft(vowel))) 211 | freqs = jnp.arange(len(vowel))*sample_rate/len(vowel) 212 | return freqs[local_peaks(spectrum)[:3]] 213 | 214 | def peak_widths(vowel, bw=50): 215 | """Synthesize a vowel and find the frequencies of the spectral peaks""" 216 | sample_rate = 16000 217 | vowel = pat.MakeVowel(1024, [1,], sample_rate, vowel, bw=bw) 218 | spectrum = 20*np.log10(np.abs(np.fft.fft(vowel))) 219 | peak_locs = local_peaks(spectrum)[:3] 220 | peak_widths = scipy.signal.peak_widths(spectrum, peak_locs, 221 | rel_height=0.5)[0] 222 | return peak_widths 223 | 224 | # Make sure the spectrum of each vowel has peaks in the right spots. 225 | bin_width = 16000/1024 226 | np.testing.assert_allclose(vowel_peaks('a'), 227 | np.array([730, 1090, 2440]), 228 | atol=bin_width) 229 | np.testing.assert_allclose(vowel_peaks('i'), 230 | np.array([270, 2290, 3010]), 231 | atol=bin_width) 232 | np.testing.assert_allclose(vowel_peaks('u'), 233 | np.array([300, 870, 2240]), 234 | atol=bin_width) 235 | 236 | widths_50 = peak_widths('/a/', 50) 237 | widths_100 = peak_widths('/a/', 100) 238 | np.testing.assert_array_less(widths_50, widths_100) 239 | 240 | 241 | if __name__ == '__main__': 242 | absltest.main() 243 | -------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_test.py: -------------------------------------------------------------------------------- 1 | """Code to test the auditory toolbox.""" 2 | import math 3 | 4 | from absl.testing import absltest 5 | import numpy as np 6 | import scipy 7 | import matplotlib.pyplot as plt 8 | 9 | import auditory_toolbox as pat 10 | 11 | 12 | class AuditoryToolboxTests(absltest.TestCase): 13 | """Test cases for auditory toolbox.""" 14 | def test_erb_space(self): 15 | low_freq = 100.0 16 | high_freq = 44100/4.0 17 | num_channels = 100 18 | cf_array = pat.ErbSpace(low_freq = low_freq, high_freq = high_freq, 19 | n = num_channels) 20 | self.assertLen(cf_array, num_channels) 21 | # Make sure low and high CF's are where we expect them to be. 22 | self.assertAlmostEqual(cf_array[-1], low_freq) 23 | self.assertLess(cf_array[0], high_freq) 24 | 25 | def test_make_erb_filters(self): 26 | # Ten channel ERB Filterbank. Make sure return has the right size. 27 | # Will test coefficients when we test the filterbank. 
28 | fs = 16000 29 | low_freq = 100 30 | num_chan = 10 31 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq) 32 | self.assertLen(fcoefs, 10) 33 | 34 | # Test all the filter coefficient array shapes 35 | a0, a11, a12, a13, a14, a2, b0, b1, b2, gain = fcoefs 36 | self.assertEqual(a0.shape, (num_chan,)) 37 | self.assertEqual(a11.shape, (num_chan,)) 38 | self.assertEqual(a12.shape, (num_chan,)) 39 | self.assertEqual(a13.shape, (num_chan,)) 40 | self.assertEqual(a14.shape, (num_chan,)) 41 | self.assertEqual(a2.shape, (num_chan,)) 42 | self.assertEqual(b0.shape, (num_chan,)) 43 | self.assertEqual(b1.shape, (num_chan,)) 44 | self.assertEqual(b2.shape, (num_chan,)) 45 | self.assertEqual(gain.shape, (num_chan,)) 46 | 47 | def test_erb_filterbank(self): 48 | fs = 16000 49 | low_freq = 100 50 | num_chan = 10 51 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq) 52 | 53 | impulse_len = 512 54 | x = np.zeros(impulse_len) 55 | x[0] = 1 56 | 57 | y = pat.ErbFilterBank(x, fcoefs) 58 | self.assertEqual(y.shape, (num_chan, impulse_len)) 59 | self.assertAlmostEqual(np.max(y), 0.10657410, delta=0.01) 60 | 61 | resp = 20*np.log10(np.abs(np.fft.fft(y.T, axis=0))) 62 | 63 | # Test to make sure spectral peaks are in the right place for each channel 64 | matlab_peak_locs = np.array([184, 132, 94, 66, 46, 32, 21, 14, 8, 4]) 65 | python_peak_locs = np.argmax(resp[:impulse_len//2], axis=0) 66 | 67 | # Add one to python locs because Matlab arrays start at 1 68 | np.testing.assert_equal(matlab_peak_locs, python_peak_locs+1) 69 | 70 | def test_erb_filterbank_example(self): 71 | n = 512 72 | fs = 16000 73 | fcoefs = pat.MakeErbFilters(16000,10,100) 74 | y = pat.ErbFilterBank(np.array([1.0] + [0] * (n-1), dtype=float), fcoefs) 75 | resp = 20*np.log10(np.abs(np.fft.fft(y, axis=1))).T 76 | freq_scale = np.expand_dims(np.linspace(0, 16000, 512), 1) 77 | plt.semilogx(freq_scale[:n//2, :], resp[:n//2, :]) 78 | plt.axis((100, fs/2, -60, 0)) 79 | plt.xlabel('Frequency (Hz)') 80 | plt.ylabel('Filter Response (dB)') 81 | 82 | def test_correlogram_array(self): 83 | def local_peaks(x): 84 | i = np.argwhere(np.logical_and(x[:-2] < x[1:-1], 85 | x[2:] < x[1:-1])) + 1 86 | return [j[0] for j in i] 87 | 88 | test_impulses = np.zeros((1,1024)) 89 | test_impulses[0, range(0, test_impulses.shape[1], 100)] = 1 90 | test_frame = pat.CorrelogramFrame(test_impulses, 256) 91 | np.testing.assert_equal(np.where(test_frame > 0.1)[1], [0, 100, 200]) 92 | 93 | # Now test with cochlear input to correlogram 94 | impulse_len = 512 95 | fs = 16000 96 | low_freq = 100 97 | num_chan = 64 98 | fcoefs = pat.MakeErbFilters(fs, num_chan, low_freq) 99 | 100 | # Make harmonic input signal 101 | s = 0 102 | pitch_lag = 200 103 | for h in range(1, 10): 104 | s = s + np.sin(2*np.pi*np.arange(impulse_len)/pitch_lag*h) 105 | 106 | y = pat.ErbFilterBank(s, fcoefs) 107 | frame_width = 256 108 | frame = pat.CorrelogramFrame(y, frame_width) 109 | self.assertEqual(frame.shape, (num_chan, frame_width)) 110 | self.assertGreaterEqual(np.min(frame), 0.0) 111 | 112 | # Make sure the top channels have no output. 
113 | spectral_profile = np.sum(frame, 1) 114 | no_output = np.where(spectral_profile < 2) 115 | np.testing.assert_equal(no_output[0], np.arange(31)) 116 | 117 | # Make sure we have spectral peaks at the right locations 118 | spectral_peaks = local_peaks(spectral_profile) 119 | self.assertEqual(spectral_peaks, [42, 44, 46, 48, 50, 53, 56, 60]) 120 | 121 | # Make sure the first peak (after 0 lag) is at the pitch lag 122 | summary_correlogram = np.sum(frame, 0) 123 | skip_lags = 100 124 | self.assertEqual(np.argmax(summary_correlogram[skip_lags:]) + skip_lags, 125 | pitch_lag) 126 | 127 | def test_correlogram_pitch(self): 128 | sample_len = 20000 129 | sample_rate = 22254 130 | pitch_center = 120 131 | u = pat.MakeVowel(sample_len, pat.FMPoints(sample_len, pitch_center), 132 | sample_rate, 'u') 133 | 134 | low_freq = 60 135 | num_chan = 100 136 | fcoefs = pat.MakeErbFilters(sample_rate, num_chan, low_freq) 137 | coch = pat.ErbFilterBank(u, fcoefs) 138 | cor = pat.CorrelogramArray(coch,sample_rate,50,256) 139 | [pitch,sal] = pat.CorrelogramPitch(cor, 256, sample_rate,100,200) 140 | 141 | # Make sure center and overall pitch deviation are as expected. 142 | self.assertAlmostEqual(np.mean(pitch), pitch_center, delta=2) 143 | self.assertAlmostEqual(np.min(pitch), pitch_center-6, delta=2) 144 | self.assertAlmostEqual(np.max(pitch), pitch_center+6, delta=2) 145 | np.testing.assert_array_less(0.8, sal[:40]) 146 | 147 | # Now test salience when we add noise 148 | n = np.random.randn(sample_len) * np.arange(sample_len)/sample_len 149 | un=u + n/4 150 | 151 | low_freq = 60 152 | num_chan = 100 153 | fcoefs = pat.MakeErbFilters(sample_rate, num_chan, low_freq) 154 | coch= pat.ErbFilterBank(un, fcoefs) 155 | cor = pat.CorrelogramArray(coch,sample_rate,50,256) 156 | [pitch,sal] = pat.CorrelogramPitch(cor,256,22254,100,200) 157 | 158 | lr = scipy.stats.linregress(range(len(sal)), y=sal, alternative='less') 159 | self.assertAlmostEqual(lr.slope, -0.012, delta=0.001) 160 | self.assertAlmostEqual(lr.rvalue, -0.963, delta=0.02) 161 | 162 | def test_mfcc(self): 163 | # Put a tone into MFCC and make sure it's in the right 164 | # spot in the reconstruction. 165 | sample_rate = 16000.0 166 | f0 = 2000 167 | tone = np.sin(2*np.pi*f0*np.arange(4000)/sample_rate) 168 | [_,_,_,_,freqrecon]= pat.Mfcc(tone,sample_rate,100) 169 | 170 | fft_size = 512 # From the MFCC source code 171 | self.assertEqual(f0/sample_rate*fft_size, 172 | np.argmax(np.sum(freqrecon, axis=1))) 173 | 174 | def test_fm_points (self): 175 | base_pitch = 160 176 | sample_rate = 16000 177 | fmfreq = 10 178 | fmamp = 20 179 | points = pat.FMPoints(100000, base_pitch, fmfreq, fmamp, 16000) 180 | 181 | # Make sure the average glottal pulse locations is 1 over the pitch 182 | d_points = points[1:] - points[:-1] 183 | self.assertAlmostEqual(np.mean(d_points), sample_rate/base_pitch, delta=1) 184 | 185 | # Make sure the frequency deviation is as expected. 186 | # ToDo(malcolm): Test the deviation, it's not right! 
187 | 188 | def test_make_vowel(self): 189 | def local_peaks(x): 190 | i = np.argwhere(np.logical_and(x[:-2] < x[1:-1], 191 | x[2:] < x[1:-1])) + 1 192 | return [j[0] for j in i] 193 | 194 | test_seq = local_peaks(np.array([1,2,3,2,1,1,2,2,3,4,1])) 195 | np.testing.assert_equal(test_seq, np.array([2, 9])) 196 | 197 | def vowel_peaks(vowel): 198 | """Synthesize a vowel and find the frequencies of the spectral peaks""" 199 | sample_rate = 16000 200 | vowel = pat.MakeVowel(1024, [1,], sample_rate, vowel) 201 | spectrum = 20*np.log10(np.abs(np.fft.fft(vowel))) 202 | freqs = np.arange(len(vowel))*sample_rate/len(vowel) 203 | return freqs[local_peaks(spectrum)[:3]] 204 | 205 | def peak_widths(vowel, bw=50): 206 | """Synthesize a vowel and find the frequencies of the spectral peaks""" 207 | sample_rate = 16000 208 | vowel = pat.MakeVowel(1024, [1,], sample_rate, vowel, bw=bw) 209 | spectrum = 20*np.log10(np.abs(np.fft.fft(vowel))) 210 | peak_locs = local_peaks(spectrum)[:3] 211 | peak_widths = scipy.signal.peak_widths(spectrum, peak_locs, 212 | rel_height=0.5)[0] 213 | return peak_widths 214 | 215 | # Make sure the spectrum of each vowel has peaks in the right spots. 216 | bin_width = 16000/1024 217 | np.testing.assert_allclose(vowel_peaks('a'), 218 | np.array([730, 1090, 2440]), 219 | atol=bin_width) 220 | np.testing.assert_allclose(vowel_peaks('i'), 221 | np.array([270, 2290, 3010]), 222 | atol=bin_width) 223 | np.testing.assert_allclose(vowel_peaks('u'), 224 | np.array([300, 870, 2240]), 225 | atol=bin_width) 226 | 227 | widths_50 = peak_widths('/a/', 50) 228 | widths_100 = peak_widths('/a/', 100) 229 | # This is not a good test, as the main peaks do get a bit wider, but the 230 | # important change is that the sidelobes, around the formant peak, get 231 | # bigger. That is harder to test. 232 | np.testing.assert_array_less(widths_50, widths_100) 233 | 234 | def test_vowel_error(self): 235 | # Looking for a crash for one specific pitch. Found. 236 | pat.MakeVowel(1024, 232.02860207189255, 8192, 'a', bw=50) 237 | 238 | def test_spectrogram(self): 239 | fs = 22050 240 | t = np.arange(fs)/fs 241 | f0 = 900 242 | tone = np.sin(2*np.pi*f0*t) 243 | segsize = 128 244 | ntrans = 4 245 | nlap = 8 246 | spec = pat.Spectrogram(tone, segsize=segsize, nlap=nlap, ntrans=ntrans) 247 | self.assertEqual(spec.shape[0], segsize*ntrans//2) 248 | self.assertEqual(spec.shape[1], (len(tone)-segsize)//(segsize//nlap) + 1) 249 | 250 | profile = np.sum(spec, axis=1) 251 | self.assertEqual(np.argmax(profile), 252 | math.floor(f0 / (fs/(segsize*ntrans)) + 0.5)) 253 | 254 | if __name__ == '__main__': 255 | absltest.main() 256 | -------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_torch.py: -------------------------------------------------------------------------------- 1 | """A PyTorch port of portions of the Matlab Auditory Toolbox. 2 | """ 3 | import math 4 | from typing import List, Optional, Tuple 5 | import torch 6 | from torch import nn 7 | from torchaudio.functional import lfilter 8 | 9 | 10 | class ErbFilterBank(nn.Module): 11 | """Applies an Auditory Filterbank to data of dimension of `(..., time)` as 12 | described in the 'Auditory Toolbox - An Efficient Implementation of the 13 | Patterson-Holdsworth Auditory Filter Bank' by Malcolm Slaney, available on: 14 | 15 | 16 | Args: 17 | ---------- 18 | num_channels : int 19 | How many channels in the filterbank. 
Default: 64 20 | lowest_frequency : float 21 | The lowest center frequency of the filterbank. Default: 100. 22 | sampling_rate : float 23 | Sampling rate (in Hz) of the filterbank (needed to determine CFs). 24 | Default: 16000 25 | dtype: (Optional) 26 | Cast coefficients to dtype after instantiation. Default: None 27 | 28 | .. note:: 29 | The implementation does not attempt to account for filtering delays. 30 | Note also that uniform temporal sampling is assumed and that the data 31 | are not mean-centered or zero-padded prior to filtering. 32 | 33 | Examples: 34 | >>> fbank = ErbFilterBank(sampling_rate=16000) 35 | >>> fbank = fbank.to(device=torch.device("cpu"), dtype=torch.float32) 36 | >>> input = torch.zeros(1, 512, dtype=torch.float32) 37 | >>> input[0, 0] = 1. 38 | >>> output = fbank(input) 39 | 40 | Attributes: 41 | fcoefs: Filter coefficients generated by make_erb_filters 42 | sos: Coefficients used for subsequent filtering. 43 | """ 44 | __constants__ = ['sampling_rate', 'num_channels', 'lowest_frequency'] 45 | 46 | def __init__( 47 | self, 48 | sampling_rate: float = 16000., 49 | num_channels: int = 64, 50 | lowest_frequency: float = 100., 51 | dtype: Optional[torch.dtype] = None, 52 | ) -> None: 53 | super().__init__() 54 | if sampling_rate <= 0: 55 | raise ValueError('Sampling rate cannot be negative or zero') 56 | if lowest_frequency <= 0 or lowest_frequency >= sampling_rate/2: 57 | raise ValueError('Misspecified lowest frequency') 58 | 59 | self.sampling_rate = sampling_rate 60 | self.num_channels = num_channels 61 | self.lowest_frequency = lowest_frequency 62 | self.fcoefs = make_erb_filters(self.sampling_rate, 63 | self.num_channels, 64 | self.lowest_frequency) 65 | if dtype: 66 | sos = prepare_coefficients(self.fcoefs).to(dtype=dtype) 67 | else: 68 | sos = prepare_coefficients(self.fcoefs) 69 | 70 | self.register_buffer('sos', sos) 71 | 72 | def forward(self, x: torch.Tensor) -> torch.Tensor: 73 | """Pass audio through the set of filters. 74 | 75 | The code is directly adapted from: 'Auditory Toolbox - An Efficient 76 | Implementation of the Patterson-Holdsworth Auditory Filter Bank' by 77 | Malcolm Slaney. 78 | 79 | Parameters 80 | ---------- 81 | x : torch.Tensor 82 | The input signal of dimension (..., time). 83 | 84 | Returns 85 | ------- 86 | y : torch.Tensor: 87 | Output signal of dimension (..., num_channels, time). 88 | 89 | """ 90 | if x.ndim < 2: 91 | raise TypeError('The input tensor should have size `(..., time)`') 92 | if x.shape[-1] <= 1: 93 | raise TypeError('The input tensor should have more than one time sample') 94 | new_dims = [1 for j in list(x.unsqueeze(-2).shape)] 95 | new_dims[-2] = self.num_channels 96 | # Cascade of four second-order sections: sos[..., 4] holds the shared denominator and sos[..., 0:4] the per-section numerators, with the overall gain folded into the first section. 97 | y = lfilter(x.unsqueeze(-2).tile(new_dims), 98 | self.sos[..., -1], 99 | self.sos[..., 0], 100 | clamp=False, batching=True) 101 | 102 | for j in range(3): 103 | y = lfilter(y, 104 | self.sos[..., -1], 105 | self.sos[..., j+1], 106 | clamp=False, batching=True) 107 | 108 | return y 109 | 110 | 111 | def erb_space(low_freq: float = 100, 112 | high_freq: float = 44100/4, 113 | n: int = 100) -> torch.Tensor: 114 | """Compute frequencies uniformly spaced on an erb scale. 115 | 116 | The code is directly adapted from: 'Auditory Toolbox - An Efficient 117 | Implementation of the Patterson-Holdsworth Auditory Filter Bank' by 118 | Malcolm Slaney. 119 | 120 | For a definition of erb, see Moore, B. C. J., and Glasberg, B. R. (1983). 121 | "Suggested formulae for calculating auditory-filter bandwidths and 122 | excitation patterns," J. Acoust. Soc. Am. 74, 750-753.
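With the constants used below (ear_q = 9.26449, min_bw = 24.7), the bandwidth at center frequency cf is erb(cf) = cf/ear_q + min_bw, and the returned center frequencies are spaced a constant number of erbs apart.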
123 | 124 | 125 | Parameters 126 | ---------- 127 | low_freq : float 128 | The center frequency in Hz of the lowest channel. The default is 100. 129 | high_freq : float 130 | The upper limit in Hz of the channel bank. The center frequency 131 | of the highest channel will be below this frequency. 132 | n : int 133 | Number of channels. The default is 100. 134 | 135 | Returns 136 | ------- 137 | cf_array : torch.Tensor 138 | An array of center frequencies, equally spaced on the ERB scale. 139 | 140 | """ 141 | # Change the following three parameters if you wish to use a different 142 | # erb scale. Must change in make_erb_filters too. 143 | ear_q = 9.26449 # Glasberg and Moore Parameters 144 | min_bw = 24.7 145 | 146 | # All of the following expressions are derived in Apple TR #35, "An 147 | # Efficient Implementation of the Patterson-Holdsworth Cochlear 148 | # Filter Bank." See pages 33-34. 149 | cf_array = (-(ear_q*min_bw) + torch.exp( 150 | torch.arange(1, 1+n, dtype=torch.float64).unsqueeze(1) * 151 | (-math.log(high_freq + ear_q*min_bw) + 152 | math.log(low_freq + ear_q*min_bw))/n) * (high_freq + ear_q*min_bw)) 153 | return cf_array 154 | 155 | 156 | def make_erb_filters(fs: float, num_channels: int, 157 | low_freq: float) -> List[torch.Tensor]: 158 | """Compute filter coefficients for a bank of Gammatone filters. 159 | 160 | The code is directly adapted from: 'Auditory Toolbox - An Efficient 161 | Implementation of the Patterson-Holdsworth Auditory Filter Bank' by 162 | Malcolm Slaney. 163 | 164 | The filter bank contains "num_channels" channels that extend from 165 | half the sampling rate (fs) to "low_freq". Alternatively, if the 166 | num_channels argument is a vector, then the values of this vector are taken 167 | to be the center frequency of each desired filter. 168 | 169 | 170 | Parameters 171 | ---------- 172 | fs : float 173 | Sampling rate (in Hz) of the filterbank (needed to determine CFs). 174 | num_channels : int or list of floats 175 | How many channels in the filterbank. 176 | low_freq : float 177 | The lowest center frequency of the filterbank. 178 | 179 | Returns 180 | ------- 181 | fcoefs : List[torch.Tensor] 182 | A list of 10 tensors of shape (num_channels, 1) containing the filter parameters. 183 | 184 | """ 185 | t = 1/fs 186 | if isinstance(num_channels, int): 187 | cf = erb_space(low_freq, fs/2, num_channels) 188 | else: 189 | cf = num_channels 190 | 191 | # Change the following three parameters if you wish to use a different 192 | # erb scale. Must change in erb_space too.
193 | ear_q = 9.26449 # Glasberg and Moore Parameters 194 | min_bw = 24.7 195 | order = 1 196 | 197 | erb = ((cf/ear_q)**order + min_bw**order)**(1/order) 198 | 199 | b = 1.019*2*math.pi*erb 200 | 201 | a11 = -(2 * t * torch.cos(2 * cf * math.pi * t) / torch.exp(b * t) + 2 * 202 | math.sqrt(3 + 2**1.5) * t * torch.sin(2 * cf * math.pi * t) / 203 | torch.exp(b * t)) / 2 204 | a12 = -(2 * t * torch.cos(2 * cf * math.pi * t) / torch.exp(b * t) - 2 * 205 | math.sqrt(3 + 2**1.5) * t * torch.sin(2 * cf * math.pi * t) / 206 | torch.exp(b * t)) / 2 207 | a13 = -(2 * t * torch.cos(2 * cf * math.pi * t) / torch.exp(b * t) + 2 * 208 | math.sqrt(3 - 2**1.5) * t * torch.sin(2 * cf * math.pi * t) / 209 | torch.exp(b * t)) / 2 210 | a14 = -(2 * t * torch.cos(2 * cf * math.pi * t) / torch.exp(b * t) - 2 * 211 | math.sqrt(3 - 2**1.5) * t * torch.sin(2 * cf * math.pi * t) / 212 | torch.exp(b * t)) / 2 213 | 214 | gain = torch.abs((-2*torch.exp(4*complex(0, 1)*cf*math.pi*t)*t + 215 | 2*torch.exp(-(b*t) + 2*complex(0, 1)*cf*math.pi*t)*t * 216 | (torch.cos(2*cf*math.pi*t) - math.sqrt(3 - 2**(3/2)) * 217 | torch.sin(2*cf*math.pi*t))) * 218 | (-2*torch.exp(4*complex(0, 1)*cf*math.pi*t)*t + 219 | 2*torch.exp(-(b*t) + 2*complex(0, 1)*cf*math.pi*t)*t * 220 | (torch.cos(2*cf*math.pi*t) + math.sqrt(3 - 2**(3/2)) * 221 | torch.sin(2*cf*math.pi*t))) * 222 | (-2*torch.exp(4*complex(0, 1)*cf*math.pi*t)*t + 223 | 2*torch.exp(-(b*t) + 2*complex(0, 1)*cf*math.pi*t)*t * 224 | (torch.cos(2*cf*math.pi*t) - 225 | math.sqrt(3 + 2**(3/2))*torch.sin(2*cf*math.pi*t))) * 226 | (-2*torch.exp(4*complex(0, 1)*cf*math.pi*t)*t + 227 | 2*torch.exp(-(b*t) + 2*complex(0, 1)*cf*math.pi*t)*t * 228 | (torch.cos(2*cf*math.pi*t) + 229 | math.sqrt(3 + 2**(3/2))*torch.sin(2*cf*math.pi*t))) / 230 | (-2 / torch.exp(2*b*t) - 231 | 2*torch.exp(4*complex(0, 1)*cf*math.pi*t) + 232 | 2*(1 + torch.exp(4*complex(0, 1)*cf*math.pi*t)) / 233 | torch.exp(b*t))**4) 234 | 235 | fcoefs = [t * torch.ones(len(cf), 1, dtype=torch.float64), 236 | a11, a12, a13, a14, 237 | 0 * torch.ones(len(cf), 1, dtype=torch.float64), 238 | 1 * torch.ones(len(cf), 1, dtype=torch.float64), 239 | -2*torch.cos(2*cf*math.pi*t)/torch.exp(b*t), 240 | torch.exp(-2*b*t), 241 | gain] 242 | 243 | return fcoefs 244 | 245 | 246 | 247 | 248 | def prepare_coefficients(fcoefs: List[torch.Tensor]) -> torch.Tensor: 249 | r"""Reassemble filter coefficients to realize filters. 250 | 251 | Parameters 252 | ---------- 253 | fcoefs : List[torch.Tensor] 254 | Coefficients prepared by make_erb_filters. 255 | 256 | Returns 257 | ------- 258 | sos : torch.Tensor 259 | Reassembled coefficients. 260 | 261 | """ 262 | [a0, a11, a12, a13, a14, a2, b0, b1, b2, gain] = fcoefs 263 | n_chan = a0.shape[0] 264 | assert n_chan == a11.shape[0] 265 | assert n_chan == a12.shape[0] 266 | assert n_chan == a13.shape[0] 267 | assert n_chan == a14.shape[0] 268 | assert n_chan == b0.shape[0] 269 | assert n_chan == b1.shape[0] 270 | assert n_chan == gain.shape[0] 271 | 272 | sos = torch.cat([ 273 | torch.cat([a0/gain, a0, a0, a0, b0], dim=1).unsqueeze(1), 274 | torch.cat([a11/gain, a12, a13, a14, b1], dim=1).unsqueeze(1), 275 | torch.cat([a2/gain, a2, a2, a2, b2], dim=1).unsqueeze(1), 276 | ], dim=1) 277 | 278 | return sos 279 | 280 | 281 | def make_vowel(sample_len: int, 282 | pitch: float, 283 | sample_rate: float, 284 | f, 285 | bw=50) -> torch.Tensor: 286 | """Synthesize an artificial vowel using formant filters. 
287 | 288 | The code is directly adapted from MakeVowel by Malcolm Slaney 289 | 290 | Make a vowel with "sample_len" samples and the given pitch. The 291 | formant frequencies are f1, f2 & f3. Some common vowels are 292 | Vowel f1 f2 f3 293 | /a/ 730 1090 2440 294 | /i/ 270 2290 3010 295 | /u/ 300 870 2240 296 | 297 | The pitch variable can either be a scalar indicating the actual 298 | pitch frequency, or an array of impulse locations. Using an 299 | array of impulses allows this routine to compute vowels with 300 | varying pitch. 301 | 302 | Alternatively, f1 can be replaced with one of the following strings 303 | 'a', 'i', 'u' and the appropriate formant frequencies are 304 | automatically selected. 305 | 306 | Parameters 307 | ---------- 308 | sample_len : int 309 | How many samples to generate 310 | pitch : float 311 | Either a single floating point value indicating a constant 312 | pitch (in Hz), or a train of impulses generated by fm_points. 313 | sample_rate : float 314 | The sample rate for the output signal (Hz) 315 | f : string or list 316 | Either a vowel spec, one of '/a/', '/i/', or '/u/', or a list of 317 | (f1, f2, f3) where: 318 | f1: Is the frequency of the first formant. 319 | f2: Optional 2nd formant frequency 320 | f3: Optional 3rd formant frequency 321 | bw : float Width (in Hz) of the formant filters. The default is 50. 322 | 323 | Returns 324 | ------- 325 | y : torch.Tensor 326 | Waveform 327 | 328 | """ 329 | f1, f2, f3 = 0., 0., 0. # Keep Lint happy by setting defaults first. 330 | if isinstance(f, str): 331 | if f in ['a', '/a/']: 332 | f1, f2, f3 = (730, 1090, 2440) 333 | elif f in ['i', '/i/']: 334 | f1, f2, f3 = (270, 2290, 3010) 335 | elif f in ['u', '/u/']: 336 | f1, f2, f3 = (300, 870, 2240) 337 | elif isinstance(f, list) and len(f) == 3: 338 | f1, f2, f3 = f[0], f[1], f[2] 339 | elif isinstance(f, list) and len(f) == 2: 340 | f1, f2 = f[0], f[1] 341 | f3 = 0. 342 | elif isinstance(f, list) and len(f) == 1: 343 | f1 = f[0] 344 | f2 = 0. 345 | f3 = 0. 346 | # GlottalPulses(pitch, fs, sample_len) - Generate a stream of 347 | # glottal pulses with the given pitch (in Hz) and sampling 348 | # frequency (sample_rate). A vector of the requested length is 349 | # returned. 350 | y = torch.zeros(sample_len, dtype=torch.float64) 351 | if isinstance(pitch, (int, float)): 352 | points = torch.arange(0, sample_len-1, sample_rate / 353 | pitch, dtype=torch.float64) 354 | else: 355 | points = torch.sort(torch.as_tensor(pitch, dtype=torch.float64))[0] 356 | points = points[points < sample_len-1] 357 | 358 | indices = torch.floor(points).to(torch.int64) # int64: int16 would overflow for signals longer than 32767 samples. 359 | 360 | # Use a triangular approximation to an impulse function. The important 361 | # part is to keep the total amplitude the same. 362 | y[(indices).tolist()] = (indices+1)-points 363 | y[(indices+1).tolist()] = points-indices 364 | 365 | # GlottalFilter(x,fs) - Filter an impulse train and simulate the glottal 366 | # transfer function. The sampling rate (sample_rate) is given in Hz. 367 | # The filtering performed by this function is two first-order filters 368 | # at 250Hz. 369 | y = glottal_filter(sample_rate, y) 370 | 371 | # FormantFilter - Filter an input sequence to model one 372 | # formant in a speech signal. The formant frequency (in Hz) is given 373 | # by f and the bandwidth of the formant is given by bw (50 Hz by 374 | # default). The sampling frequency in Hz is given by fs.
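# Each formant is a two-pole resonator (see formant_filter below): with cft = f/fs and q = f/bw, the pole radius is rho = exp(-pi*cft/q), the pole angle is theta = 2*pi*cft*sqrt(1 - 1/(4*q*q)), and the transfer function is H(z) = (1 + a2 + a3)/(1 + a2*z^-1 + a3*z^-2) with a2 = -2*rho*cos(theta) and a3 = rho*rho.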
375 | if f1 > 0: 376 | y = formant_filter(f1, sample_rate, y, bw) 377 | 378 | if f2 > 0: 379 | y = formant_filter(f2, sample_rate, y, bw) 380 | 381 | if f3 > 0: 382 | y = formant_filter(f3, sample_rate, y, bw) 383 | 384 | return y 385 | 386 | 387 | def glottal_filter(sample_rate, x): 388 | """Glottal filter""" 389 | a = math.exp(-250*2*math.pi/sample_rate) 390 | return lfilter(x, torch.tensor([1, 0, -a*a], dtype=torch.float64), 391 | torch.tensor([1, 0, 0], dtype=torch.float64), clamp=False) 392 | 393 | 394 | def formant_filter(f, sample_rate, x, bw): 395 | """Filter with a formant filter.""" 396 | cft = f/sample_rate 397 | q = f/bw 398 | rho = math.exp(-math.pi * cft / q) 399 | theta = 2 * math.pi * cft * math.sqrt(1-1/(4 * q*q)) 400 | a2 = -2*rho*math.cos(theta) 401 | a3 = rho*rho 402 | a_coeffs = torch.tensor([1, a2, a3], dtype=torch.float64) 403 | b_coeffs = torch.tensor([1+a2+a3, 0, 0], dtype=torch.float64) 404 | return lfilter(x, a_coeffs, b_coeffs, clamp=False) 405 | 406 | 407 | def fm_points(sample_len: int, 408 | freq: float, 409 | fm_freq: float = 6., 410 | fm_amp: float = None, 411 | sampling_rate: float = 22050.) -> torch.Tensor: 412 | """Generate impulse train corresponding to a vibrato. 413 | 414 | The code is directly adapted from FMPoints by Malcolm Slaney 415 | 416 | Basic formula: phase angle = 2*pi*freq*t + 417 | (fm_amp/fm_freq)*sin(2*pi*fm_freq*t) 418 | k-th zero crossing approximately at sample number 419 | (fs/freq)*(k - (fm_amp/(2*pi*fm_freq))*sin(2*pi*k*(fm_freq/freq))) 420 | 421 | Parameters 422 | ---------- 423 | sample_len : int 424 | How much data to generate, in samples 425 | freq : float 426 | Base frequency of the output signal (Hz) 427 | fm_freq : float 428 | Vibrato frequency (in Hz) 429 | fm_amp : float 430 | Magnitude of the FM deviation (in Hz) 431 | sampling_rate : float 432 | Sample rate for the output signal. 433 | 434 | Returns 435 | ------- 436 | y : torch.Tensor 437 | An impulse train, indicating the positive-going zero crossing 438 | of the phase funcion. 439 | 440 | """ 441 | 442 | if fm_amp is None: 443 | fm_amp = 0.05*freq 444 | 445 | kmax = int(math.floor(freq*(sample_len/sampling_rate))) 446 | points = torch.arange(kmax, dtype=torch.float64) 447 | 448 | # The following is shifted back by one sample relative to FMPoints.m in the 449 | # Matlab toolbox. 450 | y = (sampling_rate/freq)*(points-( 451 | fm_amp/(2*math.pi*fm_freq))*torch.sin(2*math.pi*(fm_freq/freq)*points)) 452 | 453 | return y 454 | 455 | 456 | 457 | def correlogram_frame(data: torch.Tensor, pic_width: int, 458 | start: int = 0, win_len: int = 0, 459 | dtype: Optional[torch.dtype] = torch.float64, 460 | ) -> torch.Tensor: 461 | """Generate one frame of a correlogram using FFTs to compute autocorrelation. 462 | 463 | Example: 464 | ---------- 465 | import torch 466 | import math 467 | c = torch.zeros(20,256,dtype=torch.float64) 468 | for j in torch.arange(20,0,-1): 469 | t = torch.arange(1,257,dtype=torch.float64) 470 | c[j-1,:] = torch.nn.ReLU()(torch.sin(t/256*(21-j)*3*2*math.pi)) 471 | picture = correlogram_frame(c,128,0,256) 472 | 473 | 474 | Parameters 475 | ---------- 476 | data : torch.Tensor 477 | A (num_channel x time) or (..., num_channel x time) array of input 478 | waveforms, one time domain signal per channel. 479 | pic_width : int 480 | Number of pixels (time lags) in the final correlogram frame. 481 | start : int 482 | The starting sample 483 | win_len : int 484 | How much data to take from the input signal when computing the 485 | autocorrelation. 
486 | dtype : Optional[torch.dtype], optional 487 | The default is torch.float64. 488 | 489 | Returns 490 | ------- 491 | pic : torch.Tensor 492 | An array of size (num_channels x pic_width) containing one 493 | frame of the correlogram. If input has size (..., num_channel x time) then 494 | output will be of size (..., num_channels x pic_width). 495 | 496 | """ 497 | input_dimensions = list(data.shape) 498 | data_len = input_dimensions[-1] 499 | if not win_len: 500 | win_len = data_len 501 | 502 | # Round up to double the window size, and then the next power of 2. 503 | fft_size = int(2**(math.ceil(math.log2(2*max(pic_width, win_len))))) 504 | 505 | start = max(0, start) 506 | last = min(data_len, start+win_len) 507 | 508 | # Generate a Hamming-style raised-cosine window that is win_len long 509 | a = .54 510 | b = -.46 511 | wr = math.sqrt(64/256) 512 | phi = math.pi/win_len 513 | ws = 2*wr/math.sqrt(4*a*a+2*b*b)*( 514 | a + b*torch.cos(2*math.pi*(torch.arange(win_len, dtype=dtype))/win_len 515 | + phi)) 516 | 517 | # Initialize output 518 | output_dimensions = list(data.shape) 519 | output_dimensions[-1] = fft_size 520 | f = torch.zeros(output_dimensions, dtype=dtype) 521 | # Window the data; the autocorrelation below follows the Wiener-Khinchin theorem (the inverse FFT of the power spectrum). 522 | f[..., :(last-start)] = data[..., start:last] * ws[:(last-start)] 523 | # pylint: disable=not-callable 524 | f = torch.fft.fft(f, axis=-1) 525 | # pylint: disable=not-callable 526 | f = torch.fft.ifft(f * torch.conj(f), axis=-1) 527 | 528 | # Output pic 529 | pic = torch.maximum(torch.tensor(0.0), torch.real(f[..., :pic_width])) 530 | 531 | # Make sure first column is bigger than the rest 532 | good_rows = torch.logical_and((pic[..., 0] > 0), 533 | torch.logical_and((pic[..., 0] > pic[..., 1]), 534 | (pic[..., 0] > pic[..., 2]))) 535 | 536 | # Normalize each lag by sqrt(pic[..., 0]), the zero-lag energy; zero 537 | # entries and bad rows are masked out. 538 | norm_factor = torch.zeros_like(pic) 539 | norm_factor[good_rows] = 1./torch.sqrt(pic[good_rows][...,[0]]) 540 | pic = pic * norm_factor 541 | 542 | return pic 543 | 544 | 545 | 546 | def correlogram_array(data: torch.Tensor, sampling_rate: float, 547 | frame_rate: int = 12, width: int = 256, 548 | dtype: Optional[torch.dtype] = torch.float64, 549 | ) -> torch.Tensor: 550 | """Generate an array of correlogram frames. 551 | 552 | Parameters 553 | ---------- 554 | data : torch.Tensor 555 | The filterbank's output, size (num_channel x time) or 556 | (..., num_channel x time) 557 | sampling_rate : float 558 | The sample rate for the data (needed when computing the frame times) 559 | frame_rate : int 560 | How often (in Hz) correlogram frames should be generated. 561 | width: int 562 | The width (in lags) of the correlogram 563 | dtype : Optional[torch.dtype], optional 564 | The default is torch.float64. 565 | 566 | Returns 567 | ------- 568 | movie : torch.Tensor 569 | A (num_frames x num_channels x width) tensor or a 570 | (..., num_frames x num_channels x width) tensor of correlogram frames.
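Example: ---------- A minimal sketch (a random tensor stands in for real filterbank output; the leading batch dimension matches the usage in the tests): import torch coch = torch.randn(1, 64, 4000, dtype=torch.float64) movie = correlogram_array(coch, sampling_rate=16000., frame_rate=50, width=256) # movie[0] has size (num_frames, 64, 256)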
571 | 572 | """ 573 | if data.ndim==2: 574 | data = data.unsqueeze(-2) 575 | sample_len = data.shape[-1] 576 | frame_increment = int(sampling_rate/frame_rate) 577 | frame_count = int((sample_len-width)/frame_increment) + 1 578 | 579 | movie = [] 580 | for i in range(frame_count): 581 | start = i*frame_increment 582 | frame = correlogram_frame(data, 583 | pic_width = width, 584 | start = start, 585 | win_len = frame_increment*4, 586 | dtype = dtype).unsqueeze(-3) 587 | movie.append(frame) 588 | movie = torch.cat(movie,dim=-3) 589 | return movie 590 | 591 | 592 | 593 | 594 | def correlogram_pitch(correlogram: torch.Tensor, 595 | width: int = 256, 596 | sr: float = 22254.54, 597 | low_pitch: float = 0., 598 | high_pitch: float = 20000., 599 | dtype: Optional[torch.dtype] = torch.float64, 600 | ) -> Tuple[torch.Tensor,torch.Tensor]: 601 | """Compute the summary of a correlogram to find the pitch. 602 | 603 | Computes the pitch of a correlogram sequence by finding the time lag 604 | with the largest correlation energy. 605 | 606 | The correlogram_pitch function uses optional low_pitch and high_pitch 607 | arguments to limit the range of legal pitch values. It is important to 608 | note that correlogram_pitch does not include any other higher-level 609 | knowledge about pitch. Notably, it does not enforce any frame-to-frame 610 | continuity in the pitch: each pitch estimate is independent, and 611 | nothing prevents the estimate from changing instantaneously from 612 | frame to frame. 613 | 614 | Parameters 615 | ---------- 616 | correlogram : torch.Tensor 617 | A 3D correlogram array, output from correlogram_array of size 618 | (num_frames x num_channels x num_times) 619 | width : int 620 | Width of the correlogram. Historical parameter. Should be 621 | equal to correlogram.shape[-1]. The default is 256 622 | sr : float 623 | The sample rate. The default is 22254.54. 624 | low_pitch : float 625 | Lowest allowable pitch (Hz). Pitch peaks are only searched 626 | within the region low_pitch to high_pitch. The default is 0. 627 | high_pitch : float 628 | Highest allowable pitch (Hz). The default is 20000. 629 | dtype : Optional[torch.dtype], optional 630 | The default is torch.float64. 631 | 632 | Raises 633 | ------ 634 | TypeError 635 | The input data has to be of size (num_frames x num_channels x num_times). 636 | 637 | Returns 638 | ------- 639 | pitch : torch.Tensor 640 | A one-dimensional tensor of length num_frames indicating the pitch 641 | or 0 if no pitch is found. 642 | salience : torch.Tensor 643 | A one-dimensional tensor indicating the pitch salience on a scale 644 | from 0 (no pitch found) to 1 (clear pitch). 645 | 646 | """ 647 | if not correlogram.ndim == 3: 648 | raise TypeError('Input should be (num_frames x num_channels x num_times)') 649 | 650 | drop_low = int(sr/high_pitch) 651 | if low_pitch > 0: 652 | drop_high = int(min(width, math.ceil(sr/low_pitch))) 653 | else: 654 | drop_high = width 655 | 656 | frames = correlogram.shape[-3] 657 | 658 | pitch = torch.zeros(frames, dtype=dtype) 659 | salience = torch.zeros(frames, dtype=dtype) 660 | for j in range(frames): 661 | # Get one frame from the correlogram and compute 662 | # the sum (as a function of time lag) across all channels. 663 | summary = torch.sum(correlogram[j, :, :], axis=0) 664 | zero_lag = torch.sum(correlogram[j, :, :], axis=0)[0] 665 | # Now we need to find the first pitch past the peak at zero 666 | # lag. The following lines smooth the summary a bit, then 667 | # look for the first point where the summary goes back up.
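# (That first upturn marks the end of the zero-lag peak.)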
668 | # Everything up to this point is zeroed out. 669 | window_length = 16 670 | b_coefs = torch.ones(window_length, dtype=dtype) 671 | a_coefs = torch.zeros(window_length, dtype=dtype) 672 | a_coefs[0] = 1. 673 | sumfilt = lfilter(summary, a_coefs, b_coefs, clamp=False, batching=True) 674 | 675 | sumdif = sumfilt[..., 1:width] - sumfilt[..., :width-1] 676 | sumdif[:window_length] = 0 677 | valleys = torch.argwhere(sumdif > 0) 678 | summary[:int(valleys[0, 0])] = 0 679 | summary[1:drop_low] = 0 680 | summary[drop_high:] = 0 681 | 682 | # Now find the location of the biggest peak and call this the pitch 683 | p = torch.argmax(summary) 684 | if p > 0: 685 | pitch[j] = sr/float(p) 686 | 687 | salience[j] = summary[p]/zero_lag 688 | 689 | return pitch, salience 690 | -------------------------------------------------------------------------------- /python_auditory_toolbox/auditory_toolbox_torch_test.py: -------------------------------------------------------------------------------- 1 | """Code to test the auditory toolbox.""" 2 | 3 | from absl.testing import absltest 4 | import torch 5 | import torch.fft 6 | import numpy as np 7 | import auditory_toolbox_torch as pat 8 | 9 | class AuditoryToolboxTests(absltest.TestCase): 10 | """Test cases for the filterbank.""" 11 | 12 | def test_erb_space(self): 13 | """Test ERB space.""" 14 | low_freq = 100.0 15 | high_freq = 44100/4.0 16 | num_channels = 100 17 | cf_array = pat.erb_space(low_freq=low_freq, high_freq=high_freq, 18 | n=num_channels) 19 | cf_array = cf_array.numpy().squeeze() 20 | self.assertLen(cf_array, num_channels) 21 | # Make sure low and high CF's are where we expect them to be. 22 | self.assertAlmostEqual(cf_array[-1], low_freq) 23 | self.assertLess(cf_array[0], high_freq) 24 | 25 | def test_make_erb_filters(self): 26 | """Test fcoefs""" 27 | # Ten channel ERB Filterbank. Make sure return has the right size. 28 | # Will test coefficients when we test the filterbank. 
29 | fs = 16000 30 | low_freq = 100 31 | num_chan = 10 32 | fcoefs = pat.make_erb_filters(fs, num_chan, low_freq) 33 | self.assertLen(fcoefs, 10) 34 | # Test all the filter coefficient array shapes 35 | a0, a11, a12, a13, a14, a2, b0, b1, b2, gain = fcoefs 36 | self.assertLen(fcoefs, 10) 37 | self.assertEqual(a0.numpy().shape, (num_chan, 1)) 38 | self.assertEqual(a11.numpy().shape, (num_chan, 1)) 39 | self.assertEqual(a12.numpy().shape, (num_chan, 1)) 40 | self.assertEqual(a13.numpy().shape, (num_chan, 1)) 41 | self.assertEqual(a14.numpy().shape, (num_chan, 1)) 42 | self.assertEqual(a2.numpy().shape, (num_chan, 1)) 43 | self.assertEqual(b0.numpy().shape, (num_chan, 1)) 44 | self.assertEqual(b1.numpy().shape, (num_chan, 1)) 45 | self.assertEqual(b2.numpy().shape, (num_chan, 1)) 46 | self.assertEqual(gain.numpy().shape, (num_chan, 1)) 47 | 48 | 49 | def test_erb_filterbank_peaks(self): 50 | """Test peaks.""" 51 | 52 | 53 | impulse_len = 512 54 | x = torch.zeros(1, impulse_len, dtype=torch.float64) 55 | x[:, 0] = 1.0 56 | 57 | fbank = pat.ErbFilterBank(sampling_rate=16000, 58 | num_channels=10, 59 | lowest_frequency=100) 60 | y = fbank(x).numpy() 61 | 62 | self.assertEqual(y.shape, (1, 10, impulse_len)) 63 | self.assertAlmostEqual(np.max(y), 0.10657410, delta=0.01) 64 | 65 | resp = 20 * np.log10(np.abs(np.fft.fft(y, axis=-1))) 66 | resp = resp.squeeze() 67 | 68 | # Test to make sure spectral peaks are in the right place for each channel 69 | matlab_peak_locs = [184, 132, 94, 66, 46, 32, 21, 14, 8, 4] 70 | python_peak_locs = np.argmax(resp[:, :impulse_len // 2], axis=-1) 71 | np.testing.assert_equal(matlab_peak_locs, python_peak_locs+1) 72 | 73 | self.assertEqual(resp.shape, torch.Size([10, 512])) 74 | self.assertEqual(list(python_peak_locs+1), matlab_peak_locs) 75 | 76 | matlab_out_peak_locs = [12, 13, 23, 32, 46, 51, 77, 122, 143, 164] 77 | python_out_peak_locs = np.argmax(y.squeeze(), axis=-1) 78 | self.assertEqual(list(python_out_peak_locs + 1), matlab_out_peak_locs) 79 | 80 | def test_fm_points(self): 81 | """Test fm points""" 82 | base_pitch = 160 83 | sample_rate = 16000 84 | fmfreq = 10 85 | fmamp = 20 86 | points = pat.fm_points(100000, base_pitch, fmfreq, fmamp, 16000) 87 | 88 | # Make sure the average glottal pulse locations is 1 over the pitch 89 | d_points = points[1:] - points[:-1] 90 | d_points = d_points.numpy() 91 | self.assertAlmostEqual(np.mean(d_points),sample_rate/base_pitch, delta=1) 92 | 93 | def test_make_vowel(self): 94 | """Test make vowels.""" 95 | 96 | def local_peaks(x): 97 | i = np.argwhere(np.logical_and(x[:-2] < x[1:-1], 98 | x[2:] < x[1:-1])) + 1 99 | return [j[0] for j in i] 100 | 101 | test_seq = local_peaks(np.array([1, 2, 3, 2, 1, 1, 2, 2, 3, 4, 1])) 102 | np.testing.assert_equal(test_seq, np.array([2, 9])) 103 | 104 | def vowel_peaks(vowel): 105 | """Find the frequencies of the spectral peaks.""" 106 | sample_rate = 16000 107 | vowel = pat.make_vowel(1024, [1,], sample_rate, vowel) 108 | vowel = vowel.numpy() 109 | spectrum = 20*np.log10(np.abs(np.fft.fft(vowel))) 110 | freqs = np.arange(len(vowel))*sample_rate/len(vowel) 111 | return freqs[local_peaks(spectrum)[:3]] 112 | 113 | # Make sure the spectrum of each vowel has peaks in the right spots. 
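# (The tolerance below is one FFT bin: a 1024-point transform at 16 kHz resolves frequency to 16000/1024 = 15.625 Hz.)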
114 | bin_width = 16000/1024 115 | np.testing.assert_allclose(vowel_peaks('a'), 116 | np.array([730, 1090, 2440]), 117 | atol=bin_width) 118 | np.testing.assert_allclose(vowel_peaks('i'), 119 | np.array([270, 2290, 3010]), 120 | atol=bin_width) 121 | np.testing.assert_allclose(vowel_peaks('u'), 122 | np.array([300, 870, 2240]), 123 | atol=bin_width) 124 | 125 | def test_erb_filterbank_output_shapes(self): 126 | """Test output shapes.""" 127 | x1 = torch.zeros(64, 512, dtype=torch.float64) 128 | x1[:, 0] = 1.0 129 | x2 = torch.zeros(1, 512, dtype=torch.float64) 130 | x2[0, 0] = 1.0 131 | 132 | fbank = pat.ErbFilterBank(sampling_rate=16000, 133 | num_channels=10, 134 | lowest_frequency=100) 135 | 136 | y1 = fbank(x1).numpy() 137 | y2 = fbank(x2).numpy() 138 | 139 | assert np.isclose(y1,y2).all() 140 | self.assertEqual(list(y1.shape), [64, 10, 512]) 141 | self.assertEqual(list(y2.shape), [1, 10, 512]) 142 | self.assertAlmostEqual(np.abs(y1-y2).mean(), 0.0) 143 | 144 | x = torch.zeros(5, 2, 3, 10000, dtype=torch.float64) 145 | y = fbank(x) 146 | self.assertEqual(list(y.shape), [5, 2, 3, 10, 10000]) 147 | 148 | def test_erb_filterbank_dtype(self): 149 | """Test data type.""" 150 | x = torch.rand(5, 2, 3, 1000, dtype=torch.float32) 151 | fbank = pat.ErbFilterBank(sampling_rate=44100, 152 | num_channels=10, 153 | lowest_frequency=100) 154 | fbank.to(dtype=torch.float32) 155 | y = fbank(x) 156 | self.assertEqual(list(y.shape), [5, 2, 3, 10, 1000]) 157 | 158 | fbank = pat.ErbFilterBank(sampling_rate=44100, 159 | num_channels=10, 160 | lowest_frequency=100, 161 | dtype = torch.float32) 162 | y = fbank(x) 163 | self.assertEqual(list(y.shape), [5, 2, 3, 10, 1000]) 164 | 165 | def test_erb_filterbank_num_channels(self): 166 | """Test shapes with different number of channels.""" 167 | x = torch.randn(64, 512, dtype=torch.float64) 168 | 169 | fbank1 = pat.ErbFilterBank(sampling_rate=16000, 170 | num_channels=10, 171 | lowest_frequency=100) 172 | fbank2 = pat.ErbFilterBank(sampling_rate=16000, 173 | num_channels=32, 174 | lowest_frequency=100) 175 | fbank3 = pat.ErbFilterBank(sampling_rate=16000, 176 | num_channels=64, 177 | lowest_frequency=100) 178 | 179 | self.assertEqual(list(fbank1(x).shape), [64, 10, 512]) 180 | self.assertEqual(list(fbank2(x).shape), [64, 32, 512]) 181 | self.assertEqual(list(fbank3(x).shape), [64, 64, 512]) 182 | 183 | def test_make_vowels_peaks_i(self): 184 | """Test peaks /i/""" 185 | wav_len = 8000 186 | loc = np.zeros(4, dtype=np.int16) 187 | for p, pitch in enumerate([50., 100., 512., 1024.]): 188 | y = pat.make_vowel(wav_len, pitch, sample_rate=16000, f='i').numpy() 189 | y = y - np.mean(y) 190 | y_fft = np.fft.fft(y) 191 | loc[p] = np.argmax(20*np.log10(np.abs(y_fft[:wav_len//2]))) 192 | 193 | self.assertEqual(list(loc+1), [126, 151, 257, 1537]) 194 | 195 | def test_make_vowels_peaks_u(self): 196 | """Test peaks /u/""" 197 | wav_len = 8000 198 | loc = np.zeros(4, dtype=np.int16) 199 | for p, pitch in enumerate([50., 100., 512., 1024.]): 200 | y = pat.make_vowel(wav_len, pitch, sample_rate=16000, f='u').numpy() 201 | y = y - np.mean(y) 202 | y_fft = np.fft.fft(y) 203 | loc[p] = np.argmax(20*np.log10(np.abs(y_fft[:wav_len//2]))) 204 | 205 | self.assertEqual(list(loc+1), [151, 151, 257, 513]) 206 | 207 | def test_make_vowels_peaks_a(self): 208 | """Test peaks /a/""" 209 | wav_len = 8000 210 | loc = np.zeros(4, dtype=np.int16) 211 | for p, pitch in enumerate([50., 100., 512., 1024.]): 212 | y = pat.make_vowel(wav_len, pitch, sample_rate=16000, f='a').numpy() 213 | y = y - 
np.mean(y) 214 | y_fft = np.fft.fft(y) 215 | loc[p] = np.argmax(20*np.log10(np.abs(y_fft[:wav_len//2]))) 216 | 217 | self.assertEqual(list(loc+1), [376, 351, 513, 513]) 218 | 219 | def test_make_vowels_bw(self): 220 | # Need to write tests. 221 | pass 222 | 223 | def test_correlogram_array(self): 224 | """Test correlogram_frame.""" 225 | def local_peaks(x): 226 | i = np.argwhere(np.logical_and(x[:-2] < x[1:-1], 227 | x[2:] < x[1:-1])) + 1 228 | return [j[0] for j in i] 229 | 230 | test_impulses = torch.zeros((1, 1024), dtype=torch.float64) 231 | test_impulses[0, range(0, test_impulses.shape[1], 100)] = 1 232 | 233 | test_frame = pat.correlogram_frame(test_impulses, 256, 0, 0) 234 | locs = list(torch.where(test_frame > 0.1)[1]) 235 | self.assertEqual(locs, [0, 100, 200]) 236 | 237 | # Now test with cochlear input to correlogram 238 | impulse_len = 512 239 | 240 | fbank = pat.ErbFilterBank(sampling_rate=16000, 241 | num_channels=64, 242 | lowest_frequency=100) 243 | 244 | # Make harmonic input signal 245 | s = 0 246 | pitch_lag = 200 247 | for h in range(1, 10): 248 | t_vec = torch.arange(impulse_len,dtype=torch.float64) 249 | s = s + torch.sin(2*torch.pi*t_vec/pitch_lag*h) 250 | s = s.unsqueeze(0) 251 | y = fbank(s) 252 | 253 | frame_width = 256 254 | frame = pat.correlogram_frame(y, frame_width) 255 | 256 | self.assertEqual(frame.shape, (1, 64, frame_width)) 257 | 258 | # Make sure the top channels have no output. 259 | spectral_profile = torch.sum(frame, dim=-1) 260 | no_output = torch.where(spectral_profile < 2)[-1] 261 | self.assertEqual(list(no_output.numpy()),list(np.arange(31))) 262 | 263 | # Make sure we have spectral peaks at the right locations 264 | spectral_peaks = local_peaks(spectral_profile.numpy()[0]) 265 | self.assertEqual(spectral_peaks, [42, 44, 46, 48, 50, 53, 56, 60]) 266 | 267 | # Make sure the first peak (after 0 lag) is at the pitch lag 268 | summary_correlogram = torch.sum(frame.squeeze(0), 0) 269 | skip_lags = 100 270 | self.assertEqual(torch.argmax(summary_correlogram[skip_lags:]).numpy() + 271 | skip_lags, 272 | pitch_lag) 273 | 274 | def test_correlogram_pitch(self): 275 | """Test correlogram_pitch.""" 276 | sample_len = 20000 277 | sample_rate = 22254 278 | pitch_center = 120 279 | u = pat.make_vowel(sample_len, pat.fm_points(sample_len, pitch_center), 280 | sample_rate, 'u') 281 | u = u.unsqueeze(0) 282 | low_freq = 60 283 | num_chan = 100 284 | 285 | fbank = pat.ErbFilterBank(sampling_rate=sample_rate, 286 | num_channels=num_chan, 287 | lowest_frequency=low_freq) 288 | 289 | 290 | 291 | coch = fbank(u) 292 | cor = pat.correlogram_array(coch,sample_rate,50,256) 293 | 294 | 295 | cor = cor[0] 296 | [pitch,sal] = pat.correlogram_pitch(cor, 256, sample_rate,100,200) 297 | 298 | # Make sure center and overall pitch deviation are as expected. 
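# (fm_points defaults to fm_freq=6 Hz and fm_amp=0.05*freq, i.e. 6 Hz at a 120 Hz center, so the pitch track should swing roughly +/-6 Hz.)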
299 | self.assertAlmostEqual(torch.mean(pitch).numpy(), pitch_center, delta=2) 300 | self.assertAlmostEqual(torch.min(pitch).numpy(), pitch_center-6, delta=2) 301 | self.assertAlmostEqual(torch.max(pitch).numpy(), pitch_center+6, delta=2) 302 | np.testing.assert_array_less(0.8, sal.numpy()[:40]) 303 | 304 | # Now test salience when we add noise 305 | grid = torch.arange(sample_len,dtype=torch.float64) 306 | n = torch.randn(sample_len, dtype=torch.float64)*grid/sample_len 307 | un=u + n/4 308 | 309 | low_freq = 60 310 | num_chan = 100 311 | 312 | fbank2 = pat.ErbFilterBank(sampling_rate=sample_rate, 313 | num_channels=num_chan, 314 | lowest_frequency=low_freq) 315 | 316 | 317 | coch = fbank2(un) 318 | cor = pat.correlogram_array(coch,sample_rate,50,256) 319 | # Remove first dim 320 | cor = cor[0] 321 | [pitch,sal] = pat.correlogram_pitch(cor,256,22254,100,200) 322 | 323 | sal = sal.numpy() 324 | 325 | 326 | # Avoid scipy dependency 327 | design = np.ones((len(sal),2)) 328 | design[:,1] = np.arange(len(sal)) 329 | lr = np.linalg.lstsq(design,sal[:,None],rcond=None) 330 | r_value = np.corrcoef(design[:,1],sal)[0,1] 331 | 332 | self.assertAlmostEqual(lr[0][1][0], -0.012, delta=0.01) 333 | self.assertAlmostEqual(r_value, -0.963, delta=0.03) 334 | 335 | # lr = scipy.stats.linregress(range(len(sal)), y=sal, alternative='less') 336 | # self.assertAlmostEqual(lr.slope, -0.012, delta=0.01) 337 | # self.assertAlmostEqual(lr.rvalue, -0.963, delta=0.03) 338 | 339 | if __name__ == '__main__': 340 | absltest.main() 341 | -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/CorrelogramPitchExample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/CorrelogramPitchExample.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/DudaTones.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/DudaTones.wav -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/DudaVowelsCorrelogram.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/DudaVowelsCorrelogram.mp4 -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/GammatoneFilterResponse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/GammatoneFilterResponse.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/LeonPitch.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/LeonPitch.wav -------------------------------------------------------------------------------- 
/python_auditory_toolbox/examples/LeonVowels.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/LeonVowels.wav -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/TapestryFilterbank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/TapestryFilterbank.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/TapestryGammatoneFeatures.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/TapestryGammatoneFeatures.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/TapestryReconstruction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/TapestryReconstruction.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/TapestrySpectrogram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/TapestrySpectrogram.png -------------------------------------------------------------------------------- /python_auditory_toolbox/examples/tapestry.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MalcolmSlaney/python_auditory_toolbox/ead46d835117047a4c5e01634358f54dc4937ad9/python_auditory_toolbox/examples/tapestry.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | jax 3 | jaxlib 4 | matplotlib 5 | numpy 6 | scipy 7 | torch 8 | torchaudio 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The python_auditory_toolbox Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2023 Google Inc. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 
19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | # ============================================================================== 29 | 30 | # Malcolm Notes: 31 | # Cleanup: remove previous dist/* and update version # below. 32 | # To create distributions: python3 -m build 33 | # To upload distribution: python3 -m twine upload dist/* 34 | # Use __token__ as the User name and an API key for the password 35 | # Probably need to generate a new API token at pypi.org/manage/account/token 36 | # To test in a new environment: conda create --name test 37 | # Then: conda activate test 38 | # Then: pip install python_auditory_toolbox 39 | # Then: python3 40 | # from python_auditory_toolbox import auditory_toolbox_jax as pat 41 | # pat.ErbSpace() 42 | 43 | """Create the python_auditory_toolbox package files. 44 | """ 45 | 46 | import setuptools 47 | 48 | with open('README.md', 'r', encoding='utf-8') as fh: 49 | long_description = fh.read() 50 | 51 | setuptools.setup( 52 | name='python_auditory_toolbox', 53 | version='1.0.5', 54 | author='Malcolm Slaney', 55 | author_email='malcolm@ieee.org', 56 | description='Several simple auditory models in JAX, Numpy and Torch', 57 | long_description=long_description, 58 | long_description_content_type='text/markdown', 59 | url='https://github.com/MalcolmSlaney/python_auditory_toolbox', 60 | packages=['python_auditory_toolbox'], 61 | classifiers=[ 62 | 'Programming Language :: Python :: 3', 63 | 'License :: OSI Approved :: Apache Software License', 64 | 'Operating System :: OS Independent', 65 | 'Topic :: Multimedia :: Sound/Audio :: Analysis', 66 | ], 67 | python_requires='>=3.6', 68 | install_requires=[ 69 | 'absl-py', 70 | 'numpy', 71 | 'jax', 72 | 'jaxlib', 73 | 'matplotlib', 74 | 'scipy', 75 | 'torch', 76 | 'torchaudio', 77 | ], 78 | include_package_data=True, # Using the files specified in MANIFEST.in 79 | ) 80 | --------------------------------------------------------------------------------