├── .gitattributes ├── .github └── workflows │ ├── requirements-dev.txt │ └── static_checks.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── .reuse └── dep5 ├── CONTRIBUTING.md ├── LICENSE.md ├── LICENSES └── Apache-2.0.md ├── README.md ├── copy_files_to_espnet.sh ├── egs2 ├── dns_ins20 │ └── enh1 │ │ ├── conf │ │ └── tuning │ │ │ └── train_enh_tflocoformer.yaml │ │ └── exp │ │ └── enh_train_enh_tflocoformer_raw │ │ ├── RESULTS.md │ │ ├── config.yaml │ │ ├── images │ │ ├── l1_timedomain+magspec_loss.png │ │ ├── loss.png │ │ └── si_snr_loss.png │ │ └── valid.loss.ave_5best.pth ├── librimix │ └── enh1 │ │ ├── conf │ │ └── tuning │ │ │ └── train_enh_tflocoformer.yaml │ │ ├── exp │ │ └── enh_train_enh_tflocoformer_raw │ │ │ ├── RESULTS.md │ │ │ ├── config.yaml │ │ │ ├── images │ │ │ ├── loss.png │ │ │ └── si_snr_loss.png │ │ │ └── valid.loss.ave_5best.pth │ │ └── local │ │ └── data.patch ├── whamr │ └── enh1 │ │ ├── conf │ │ └── tuning │ │ │ ├── train_enh_tflocoformer.yaml │ │ │ └── train_enh_tflocoformer_small.yaml │ │ ├── exp │ │ ├── enh_train_enh_tflocoformer_raw │ │ │ ├── RESULTS.md │ │ │ ├── config.yaml │ │ │ ├── images │ │ │ │ ├── loss.png │ │ │ │ └── si_snr_loss.png │ │ │ └── valid.loss.ave_5best.pth │ │ └── enh_train_enh_tflocoformer_small_raw │ │ │ ├── RESULTS.md │ │ │ ├── config.yaml │ │ │ ├── images │ │ │ ├── loss.png │ │ │ └── si_snr_loss.png │ │ │ └── valid.loss.ave_5best.pth │ │ └── local │ │ └── whamr_data_prep.patch └── wsj0_2mix │ └── enh1 │ ├── conf │ └── tuning │ │ └── train_enh_tflocoformer.yaml │ ├── exp │ └── enh_train_enh_tflocoformer_raw │ │ ├── RESULTS.md │ │ ├── config.yaml │ │ ├── images │ │ ├── loss.png │ │ └── si_snr_loss.png │ │ └── valid.loss.ave_5best.pth │ └── separate.py ├── espnet2 ├── enh │ └── separator │ │ └── tflocoformer_separator.py └── tasks │ └── enh.patch └── requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric 
Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | *.pth filter=lfs diff=lfs merge=lfs -text 6 | -------------------------------------------------------------------------------- /.github/workflows/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | pre-commit 6 | -------------------------------------------------------------------------------- /.github/workflows/static_checks.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | name: Static code checks 6 | 7 | on: # yamllint disable-line rule:truthy 8 | pull_request: 9 | push: 10 | branches: 11 | - '**' 12 | tags-ignore: 13 | - '**' 14 | 15 | env: 16 | LICENSE: Apache-2.0 17 | FETCH_DEPTH: 1 18 | FULL_HISTORY: 0 19 | SKIP_WORD_PRESENCE_CHECK: 0 20 | 21 | jobs: 22 | static-code-check: 23 | if: endsWith(github.event.repository.name, 'private') 24 | 25 | name: Run static code checks 26 | runs-on: ubuntu-latest 27 | defaults: 28 | run: 29 | shell: bash -l {0} 30 | 31 | steps: 32 | - name: Setup history 33 | if: github.ref == 'refs/heads/oss' 34 | run: | 35 | echo "FETCH_DEPTH=0" >> $GITHUB_ENV 36 | echo "FULL_HISTORY=1" >> $GITHUB_ENV 37 | 38 | - name: Setup version 39 | if: github.ref == 'refs/heads/melco' 40 | run: | 41 | echo "SKIP_WORD_PRESENCE_CHECK=1" >> $GITHUB_ENV 42 | 43 | - name: Check out code 44 | uses: actions/checkout@v4 45 | with: 46 | fetch-depth: ${{ env.FETCH_DEPTH }} # '0' to check full history 47 | 48 | - name: Set up environment 49 | run: git config user.email github-bot@merl.com 50 | 51 | - name: Set up python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: '3.10' 55 | cache: 'pip' 56 | 
cache-dependency-path: '.github/workflows/requirements-dev.txt' 57 | 58 | - name: Install python packages 59 | run: pip install -r .github/workflows/requirements-dev.txt 60 | 61 | - name: Ensure lint and pre-commit steps have been run 62 | uses: pre-commit/action@v3.0.1 63 | 64 | - name: Check files 65 | uses: merl-oss-private/merl-file-check-action@v1 66 | with: 67 | license: ${{ env.LICENSE }} 68 | full-history: ${{ env.FULL_HISTORY }} # If true, use fetch-depth 0 above 69 | skip-word-presence-check: ${{ env.SKIP_WORD_PRESENCE_CHECK }} 70 | 71 | - name: Check license compatibility 72 | if: github.ref != 'refs/heads/melco' 73 | uses: merl-oss-private/merl_license_compatibility_checker@v1 74 | with: 75 | input-filename: requirements.txt 76 | license: ${{ env.LICENSE }} 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 113 | # https://pdm.fming.dev/#use-with-ide 114 | .pdm.toml 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | 166 | # Experiment 167 | *.ckpt 168 | *.wav 169 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Pre-commit configuration. 
See https://pre-commit.com 6 | 7 | default_language_version: 8 | python: python3 9 | 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v4.4.0 13 | hooks: 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | - id: check-yaml 17 | - id: check-added-large-files 18 | args: ['--maxkb=5000'] 19 | 20 | - repo: https://gitlab.com/bmares/check-json5 21 | rev: v1.0.0 22 | hooks: 23 | - id: check-json5 24 | 25 | - repo: https://github.com/homebysix/pre-commit-macadmin 26 | rev: v1.12.3 27 | hooks: 28 | - id: check-git-config-email 29 | args: ['--domains', 'merl.com'] 30 | 31 | - repo: https://github.com/psf/black 32 | rev: 22.12.0 33 | hooks: 34 | - id: black 35 | args: 36 | - --line-length=120 37 | 38 | - repo: https://github.com/pycqa/isort 39 | rev: 5.12.0 40 | hooks: 41 | - id: isort 42 | args: ["--profile", "black", "--filter-files", "--line-length", "120", "--skip-gitignore"] 43 | 44 | # Uncomment to use pyupgrade (https://github.com/asottile/pyupgrade) to automatically upgrade syntax for newer python 45 | # - repo: https://github.com/asottile/pyupgrade 46 | # rev: v3.3.1 47 | # hooks: 48 | # - id: pyupgrade 49 | 50 | # To stop flake8 error from causing a failure, use --exit-zero. By default, pre-commit will not show the warnings, 51 | # so use verbose: true to see them. 
52 | - repo: https://github.com/pycqa/flake8 53 | rev: 5.0.4 54 | hooks: 55 | - id: flake8 56 | # Black compatibility, Eradicate options 57 | args: ["--max-line-length=120", "--extend-ignore=E203", 58 | "--eradicate-whitelist-extend", "eradicate:\\s*no", 59 | "--exit-zero"] 60 | verbose: true 61 | additional_dependencies: [ 62 | # https://github.com/myint/eradicate, https://github.com/wemake-services/flake8-eradicate 63 | "flake8-eradicate" 64 | ] 65 | -------------------------------------------------------------------------------- /.reuse/dep5: -------------------------------------------------------------------------------- 1 | Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ 2 | 3 | Files: *.png 4 | Copyright: 2024 Mitsubishi Electric Research Laboratories (MERL) 5 | License: Apache-2.0 6 | 7 | Files: *.pth 8 | Copyright: 2024 Mitsubishi Electric Research Laboratories (MERL) 9 | License: Apache-2.0 10 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | 6 | # Contributing 7 | 8 | Sorry, but we do not currently accept contributions in the form of pull requests to this repository. 9 | However, you are welcome to post issues (bug reports, feature requests, questions, etc). 10 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2024 Mitsubishi Electric Research Laboratories (MERL) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /LICENSES/Apache-2.0.md: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2024 Mitsubishi Electric Research Laboratories (MERL) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | # TF-Locoformer: Transformer with Local Modeling by Convolution for Speech Separation and Enhancement 8 | 9 | This repository includes source code of the TF-Locoformer model proposed in the following paper: 10 | 11 | ``` 12 | @InProceedings{Saijo2024_TFLoco, 13 | author = {Saijo, Kohei and Wichern, Gordon and Germain, Fran\c{c}ois G. and Pan, Zexu and {Le Roux}, Jonathan}, 14 | title = {TF-Locoformer: Transformer with Local Modeling by Convolution for Speech Separation and Enhancement}, 15 | booktitle = {Proc. International Workshop on Acoustic Signal Enhancement (IWAENC)}, 16 | year = 2024, 17 | month = sep 18 | } 19 | ``` 20 | 21 | ## Table of contents 22 | 23 | 1. [Environmental setup: Installing ESPnet from source and injecting the TF-Locoformer code](#environmental-setup-installing-espnet-from-source-and-injecting-the-tf-locoformer-code) 24 | 2. [Using a pre-trained model](#using-a-pre-trained-model) 25 | 3. 
[Example of training and inference](#example-of-training-and-inference) 26 | 4. [Instructions for running training on each dataset in the ESPnet pipeline (Librimix, WHAMR! and DNS)](#instructions-for-running-training-on-each-dataset-in-the-espnet-pipeline) 27 | 5. [Contributing](#contributing) 28 | 6. [Copyright and license](#copyright-and-license) 29 | 30 | ## Environmental setup: Installing ESPnet from source and injecting the TF-Locoformer code 31 | 32 | In this repo, we provide the code for TF-Locoformer along with scripts to run training and inference in ESPnet. 33 | The following commands install ESPnet from source and copy the TF-Locoformer code to the appropriate directories in ESPnet. 34 | 35 | For more details on installing ESPnet, please refer to https://espnet.github.io/espnet/installation.html. 36 | 37 | ```sh 38 | # Clone espnet code. 39 | git clone https://github.com/espnet/espnet.git 40 | 41 | # Checkout the commit where we tested our code. 42 | cd ./espnet && git checkout 90eed8e53498e7af682bc6ff39d9067ae440d6a4 43 | 44 | # Set up conda environment. 45 | # ./setup_anaconda /path/to/conda environment-name python-version 46 | cd ./tools && ./setup_anaconda.sh /path/to/conda tflocoformer 3.10.8 47 | 48 | # Install espnet from source with other dependencies. We used torch 2.1.0 and cuda 11.8. 49 | # NOTE: torch version must be 2.x.x for other dependencies. 50 | make TH_VERSION=2.1.0 CUDA_VERSION=11.8 51 | 52 | # Install the RoPE package. 53 | conda activate tflocoformer && pip install rotary-embedding-torch==0.6.1 54 | 55 | # Copy the TF-Locoformer code to ESPnet. 56 | # NOTE: ./copy_files_to_espnet.sh changes `espnet2/tasks/enh.py`. Please be careful when using your existing ESPnet environment. 
57 | cd ../../ && git clone https://github.com/merlresearch/tf-locoformer.git && cd tf_locoformer 58 | ./copy_files_to_espnet.sh /path/to/espnet-root 59 | ``` 60 | 61 | ## Using a pre-trained model 62 | 63 | This repo supports speech separation/enhancement on 4 datasets: 64 | 65 | - WSJ0-2mix (`egs2/wsj0_2mix/enh1`) 66 | - Libri2mix (`egs2/librimix/enh1`) 67 | - WHAMR! (`egs2/whamr/enh1`) 68 | - DNS-Interspeech2020 dataset (`egs2/dns_ins20/enh1`) 69 | 70 | In each `egs2` directory, you can find the pre-trained model under the `exp` directory. 71 | 72 | One can easily use the pre-trained model to separate an audio mixture as follows: 73 | 74 | ```sh 75 | # assuming you are now at ./egs2/wsj0_2mix/enh1 76 | python separate.py \ 77 | --model_path ./exp/enh_train_enh_tflocoformer_pretrained/valid.loss.ave_5best.pth \ 78 | --audio_path /path/to/input_audio \ 79 | --audio_output_dir /path/to/output_directory 80 | ``` 81 | 82 | ## Example of training and inference 83 | 84 | Here are example commands to run the WSJ0-2mix recipe. 85 | Other dataset recipes are similar, but require additional steps (refer to the next section). 86 | 87 | ```sh 88 | # Go to the corresponding example directory. 89 | cd ../espnet/egs2/wsj0_2mix/enh1 90 | 91 | # Data preparation and stats collection if necessary. 92 | # NOTE: please fill the corresponding part of db.sh for data preparation. 93 | ./run.sh --stage 1 --stop_stage 5 94 | 95 | # Training. We used 4 GPUs for training (batch size was 1 on each GPU; GPU RAM depends on dataset). 96 | ./run.sh --stage 6 --stop_stage 6 --enh_config conf/tuning/train_enh_tflocoformer.yaml --ngpu 4 97 | 98 | # Inference. 99 | ./run.sh --stage 7 --stop_stage 7 --enh_config conf/tuning/train_enh_tflocoformer.yaml --ngpu 1 --gpu_inference true --inference_model valid.loss.ave_5best.pth 100 | 101 | # Scoring. Scores are written in RESULT.md. 
102 | ./run.sh --stage 8 --stop_stage 8 --enh_config conf/tuning/train_enh_tflocoformer.yaml 103 | ``` 104 | 105 | ## Instructions for running training on each dataset in the ESPnet pipeline 106 | 107 | Some recipe changes are required to run the experiments as in the paper. 108 | After finishing the processes below, you can run the recipe in a normal way as described above. 109 | 110 | ### WHAMR! 111 | 112 | First, please install pyroomacoustics: `pip install pyroomacoustics==0.2.0`. 113 | 114 | The default task in ESPnet is noisy reverberant *speech enhancement without dereverberation* (using mix_single_reverb subset), while we did noisy reverberant *speech separation with dereverberation*. 115 | To do the same task as in the paper, please run the following commands in `egs2/whamr/enh1`: 116 | 117 | ```sh 118 | # Speech enhancement task -> speech separation task. 119 | sed -i '13,15s|single|both|' run.sh 120 | sed -i '23s|1|2|' run.sh 121 | 122 | # Modify the url of the WHAM! noise and the WHAMR! script. 123 | sed -i '42s|.*| wham_noise_url=https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/wham_noise.zip|' local/whamr_create_mixture.sh 124 | sed -i '52s|.*|script_url=https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/whamr_scripts.tar.gz|' local/whamr_create_mixture.sh 125 | 126 | cd local && patch -b < whamr_data_prep.patch && cd .. 127 | ``` 128 | 129 | Then, you can start running the recipe from stage 1. 130 | 131 | ### Libri2mix 132 | 133 | The default task in ESPnet is *noisy speech separation*, while we did *noise-free speech separation*. 134 | To do the same task as in the paper, run the following commands in `egs2/librimix/enh1`: 135 | 136 | ```sh 137 | # Apply the patch file to data.sh. 138 | cd local && patch -b < data.patch && cd .. 139 | 140 | # Use only train-360. By default, both train-100 and train-360 are used. 141 | sed -i '12s|"train"|"train-360"|' run.sh 142 | 143 | # Noisy separation -> noise-free separation. 
144 | sed -i '17s|true|false|' run.sh 145 | 146 | # Data preparation in the "clean" condition (noise-free separation). 147 | ./run.sh --stage 1 --stop_stage 5 --local_data_opts "--sample_rate 8k --min_or_max min --cond clean" 148 | ``` 149 | 150 | ### DNS interspeech2020 dataset 151 | 152 | In the paper, we simulated 3000 hours of noisy speech: 2700 h for training and 300 h for validation. 153 | To reproduce the paper's result, run the following commands in `egs2/dns_ins20/enh1`: 154 | 155 | ```sh 156 | sed -i '18s|.*|total_hours=3000|' local/dns_create_mixture.sh 157 | sed -i '19s|.*|snr_lower=-5|' local/dns_create_mixture.sh 158 | sed -i '20s|.*|snr_upper=15|' local/dns_create_mixture.sh 159 | ``` 160 | 161 | We recommend reducing the size of the validation data to save training time since the validation loop with 300 h takes a very long time. 162 | 163 | ## Contributing 164 | 165 | See [CONTRIBUTING.md](CONTRIBUTING.md) for our policy on contributions. 166 | 167 | ## Copyright and license 168 | 169 | Released under Apache-2.0 license, as found in the [LICENSE.md](LICENSE.md) file. 
170 | 171 | All files, except as noted below: 172 | 173 | ``` 174 | Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 175 | 176 | SPDX-License-Identifier: Apache-2.0 177 | ``` 178 | 179 | The following patch files: 180 | 181 | - `espnet2/tasks/enh.patch` 182 | - `egs2/librimix/local/data.patch` 183 | - `egs2/whamr/local/whamr_data_prep.patch` 184 | 185 | include code from (license included in [LICENSES/Apache-2.0.md](LICENSES/Apache-2.0.md)) 186 | 187 | ``` 188 | Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 189 | Copyright (C) 2017 ESPnet Developers 190 | 191 | SPDX-License-Identifier: Apache-2.0 192 | SPDX-License-Identifier: Apache-2.0 193 | ``` 194 | -------------------------------------------------------------------------------- /copy_files_to_espnet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 3 | # 4 | # SPDX-License-Identifier: Apache-2.0 5 | 6 | espnet_root=$1 7 | cdir=${PWD} 8 | 9 | # Copy the TF-Locoformer code. 10 | cp ${cdir}/espnet2/enh/separator/tflocoformer_separator.py ${espnet_root}/espnet2/enh/separator 11 | 12 | # Apply the patch file to enh.py to import TF-Locoformer. 13 | # enh.py is directly modified by patching and the original enh.py is saved as enh.py.orig. 14 | # .orig file can be deleted if it is not necessary. 15 | cd ${espnet_root}/espnet2/tasks && patch -b < ${cdir}/espnet2/tasks/enh.patch && cd ${cdir} 16 | 17 | # Copy other files. 
18 | for dset in wsj0_2mix whamr librimix dns_ins20; do 19 | # Copy separate.py on each egs2 20 | cp ${cdir}/egs2/wsj0_2mix/enh1/separate.py ${espnet_root}/egs2/${dset}/enh1 21 | 22 | # Copy the config file 23 | cp ${cdir}/egs2/${dset}/enh1/conf/tuning/train_enh_tflocoformer.yaml ${espnet_root}/egs2/${dset}/enh1/conf/tuning 24 | 25 | # Copy the pre-trained model 26 | mkdir -p ${espnet_root}/egs2/${dset}/enh1/exp 27 | cp -r ${cdir}/egs2/${dset}/enh1/exp/enh_train_enh_tflocoformer_raw ${espnet_root}/egs2/${dset}/enh1/exp/enh_train_enh_tflocoformer_pretrained 28 | 29 | # whamr has small and medium models 30 | if [ $dset = "whamr" ]; then 31 | cp ${cdir}/egs2/${dset}/enh1/conf/tuning/train_enh_tflocoformer_small.yaml ${espnet_root}/egs2/${dset}/enh1/conf/tuning 32 | cp -r ${cdir}/egs2/${dset}/enh1/exp/enh_train_enh_tflocoformer_small_raw ${espnet_root}/egs2/${dset}/enh1/exp/enh_train_enh_tflocoformer_small_pretrained 33 | fi 34 | done 35 | 36 | # Copy patch files. 37 | cp ${cdir}/egs2/whamr/enh1/local/whamr_data_prep.patch ${espnet_root}/egs2/whamr/enh1/local 38 | cp ${cdir}/egs2/librimix/enh1/local/data.patch ${espnet_root}/egs2/librimix/enh1/local 39 | -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/conf/tuning/train_enh_tflocoformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | init: xavier_uniform 7 | max_epoch: 150 8 | use_amp: false 9 | batch_type: folded 10 | batch_size: 4 # batch size 4 on 4 Quadro RTX 6000 (24GiB) 11 | num_workers: 4 12 | 13 | # preprocessor 14 | preprocessor: enh 15 | num_spk: &num_spk 1 16 | iterator_type: chunk 17 | chunk_length: 64000 18 | sample_rate: 16000 19 | num_iters_per_epoch: 1000 20 | force_single_channel: true 21 | 22 | # espnet model configuration 23 | model_conf: 24 | normalize_variance: true 25 | 26 | 
# optimizer and scheduler 27 | optim: adamw 28 | optim_conf: 29 | lr: 1.0e-03 30 | eps: 1.0e-08 31 | weight_decay: 1.0e-02 32 | patience: 10 33 | val_scheduler_criterion: 34 | - valid 35 | - loss 36 | best_model_criterion: 37 | - - valid 38 | - si_snr 39 | - max 40 | - - valid 41 | - loss 42 | - min 43 | keep_nbest_models: 5 44 | scheduler: warmupreducelronplateau 45 | scheduler_conf: 46 | warmup_steps: 4000 47 | mode: min 48 | factor: 0.5 49 | patience: 3 50 | 51 | # model configuration 52 | encoder: &encoder stft 53 | encoder_conf: 54 | n_fft: &n_fft 256 55 | hop_length: &hop_length 128 56 | decoder: *encoder 57 | decoder_conf: 58 | n_fft: *n_fft 59 | hop_length: *hop_length 60 | separator: tflocoformer 61 | separator_conf: 62 | num_spk: *num_spk 63 | n_layers: 6 64 | # general setup 65 | emb_dim: 128 66 | norm_type: rmsgroupnorm 67 | num_groups: 4 68 | tf_order: ft 69 | # self-attention 70 | n_heads: 4 71 | flash_attention: false 72 | # ffn 73 | ffn_type: 74 | - swiglu_conv1d 75 | - swiglu_conv1d 76 | ffn_hidden_dim: 77 | - 384 78 | - 384 # list order must be the same as ffn_type 79 | conv1d_kernel: 4 80 | conv1d_shift: 1 81 | dropout: 0.0 82 | # others 83 | eps: 1.0e-5 84 | 85 | criterions: 86 | # The first criterion 87 | - name: mr_l1_tfd 88 | conf: 89 | window_sz: [256, 512, 768, 1024] 90 | time_domain_weight: 0.5 91 | reduction: sum 92 | eps: 1.0e-8 93 | wrapper: fixed_order 94 | wrapper_conf: 95 | weight: 1.0 96 | 97 | - name: si_snr 98 | conf: 99 | eps: 1.0e-7 100 | wrapper: fixed_order 101 | wrapper_conf: 102 | weight: 0.0 103 | -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/RESULTS.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | # RESULTS 9 | ## Environments 10 | - date: `Tue Jun 25 15:59:55 EDT 2024` 11 | - python version: `3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]` 12 | - espnet version: `espnet 202402` 
13 | - pytorch version: `pytorch 2.1.0` 14 | - Git hash: `90eed8e53498e7af682bc6ff39d9067ae440d6a4` 15 | - Commit date: `Mon May 27 22:42:15 2024 -0700` 16 | 17 | 18 | ## enh_train_enh_tflocoformer_raw 19 | 20 | config: conf/tuning/train_enh_tflocoformer.yaml 21 | 22 | |dataset|STOI|SAR|SDR|SIR|SI_SNR| 23 | |---|---|---|---|---|---| 24 | |enhanced_tt_synthetic_no_reverb|98.79|23.35|23.35|0.00|23.23| 25 | |enhanced_tt_synthetic_with_reverb|83.19|13.17|13.17|0.00|10.97| 26 | -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | config: conf/tuning/train_enh_tflocoformer.yaml 7 | print_config: false 8 | log_level: INFO 9 | drop_last_iter: false 10 | dry_run: false 11 | iterator_type: chunk 12 | valid_iterator_type: null 13 | output_dir: exp/enh_train_enh_tflocoformer_raw 14 | ngpu: 1 15 | seed: 0 16 | num_workers: 4 17 | num_att_plot: 3 18 | dist_backend: nccl 19 | dist_init_method: env:// 20 | dist_world_size: 4 21 | dist_rank: 0 22 | local_rank: 0 23 | dist_master_addr: localhost 24 | dist_master_port: 56555 25 | dist_launcher: null 26 | multiprocessing_distributed: true 27 | unused_parameters: false 28 | sharded_ddp: false 29 | cudnn_enabled: true 30 | cudnn_benchmark: false 31 | cudnn_deterministic: true 32 | collect_stats: false 33 | write_collected_feats: false 34 | max_epoch: 150 35 | patience: 10 36 | val_scheduler_criterion: 37 | - valid 38 | - loss 39 | early_stopping_criterion: 40 | - valid 41 | - loss 42 | - min 43 | best_model_criterion: 44 | - - valid 45 | - si_snr 46 | - max 47 | - - valid 48 | - loss 49 | - min 50 | keep_nbest_models: 5 51 | nbest_averaging_interval: 0 52 | grad_clip: 5.0 53 | grad_clip_type: 2.0 54 | grad_noise: false 55 | 
accum_grad: 1 56 | no_forward_run: false 57 | resume: true 58 | train_dtype: float32 59 | use_amp: false 60 | log_interval: null 61 | use_matplotlib: true 62 | use_tensorboard: true 63 | create_graph_in_tensorboard: false 64 | use_wandb: false 65 | wandb_project: null 66 | wandb_id: null 67 | wandb_entity: null 68 | wandb_name: null 69 | wandb_model_log_interval: -1 70 | detect_anomaly: false 71 | use_adapter: false 72 | adapter: lora 73 | save_strategy: all 74 | adapter_conf: {} 75 | pretrain_path: null 76 | init_param: [] 77 | ignore_init_mismatch: false 78 | freeze_param: [] 79 | num_iters_per_epoch: 1000 80 | batch_size: 4 81 | valid_batch_size: null 82 | batch_bins: 1000000 83 | valid_batch_bins: null 84 | train_shape_file: 85 | - exp/enh_stats_16k/train/speech_mix_shape 86 | - exp/enh_stats_16k/train/speech_ref1_shape 87 | valid_shape_file: 88 | - exp/enh_stats_16k/valid/speech_mix_shape 89 | - exp/enh_stats_16k/valid/speech_ref1_shape 90 | batch_type: folded 91 | valid_batch_type: null 92 | fold_length: 93 | - 80000 94 | - 80000 95 | sort_in_batch: descending 96 | shuffle_within_batch: false 97 | sort_batch: descending 98 | multiple_iterator: false 99 | chunk_length: 64000 100 | chunk_shift_ratio: 0.5 101 | num_cache_chunks: 1024 102 | chunk_excluded_key_prefixes: [] 103 | chunk_default_fs: null 104 | train_data_path_and_name_and_type: 105 | - - dump/raw/tr_synthetic/wav.scp 106 | - speech_mix 107 | - sound 108 | - - dump/raw/tr_synthetic/spk1.scp 109 | - speech_ref1 110 | - sound 111 | valid_data_path_and_name_and_type: 112 | - - dump/raw/cv_synthetic_small/wav.scp 113 | - speech_mix 114 | - sound 115 | - - dump/raw/cv_synthetic_small/spk1.scp 116 | - speech_ref1 117 | - sound 118 | allow_variable_data_keys: false 119 | max_cache_size: 0.0 120 | max_cache_fd: 32 121 | allow_multi_rates: false 122 | valid_max_cache_size: null 123 | exclude_weight_decay: false 124 | exclude_weight_decay_conf: {} 125 | optim: adamw 126 | optim_conf: 127 | lr: 0.001 128 | eps: 
1.0e-08 129 | weight_decay: 0.01 130 | scheduler: warmupreducelronplateau 131 | scheduler_conf: 132 | warmup_steps: 4000 133 | mode: min 134 | factor: 0.5 135 | patience: 3 136 | init: xavier_uniform 137 | model_conf: 138 | normalize_variance: true 139 | criterions: 140 | - name: mr_l1_tfd 141 | conf: 142 | window_sz: 143 | - 256 144 | - 512 145 | - 768 146 | - 1024 147 | time_domain_weight: 0.5 148 | reduction: sum 149 | eps: 1.0e-08 150 | wrapper: fixed_order 151 | wrapper_conf: 152 | weight: 1.0 153 | - name: si_snr 154 | conf: 155 | eps: 1.0e-07 156 | wrapper: fixed_order 157 | wrapper_conf: 158 | weight: 0.0 159 | speech_volume_normalize: null 160 | rir_scp: null 161 | rir_apply_prob: 1.0 162 | noise_scp: null 163 | noise_apply_prob: 1.0 164 | noise_db_range: '13_15' 165 | short_noise_thres: 0.5 166 | use_reverberant_ref: false 167 | num_spk: 1 168 | num_noise_type: 1 169 | sample_rate: 16000 170 | force_single_channel: true 171 | channel_reordering: false 172 | categories: [] 173 | speech_segment: null 174 | avoid_allzero_segment: true 175 | flexible_numspk: false 176 | dynamic_mixing: false 177 | utt2spk: null 178 | dynamic_mixing_gain_db: 0.0 179 | encoder: stft 180 | encoder_conf: 181 | n_fft: 256 182 | hop_length: 128 183 | separator: tflocoformer 184 | separator_conf: 185 | num_spk: 1 186 | n_layers: 6 187 | emb_dim: 128 188 | norm_type: rmsgroupnorm 189 | num_groups: 4 190 | tf_order: ft 191 | n_heads: 4 192 | flash_attention: false 193 | ffn_type: 194 | - swiglu_conv1d 195 | - swiglu_conv1d 196 | ffn_hidden_dim: 197 | - 384 198 | - 384 199 | conv1d_kernel: 4 200 | conv1d_shift: 1 201 | dropout: 0.0 202 | eps: 1.0e-05 203 | decoder: stft 204 | decoder_conf: 205 | n_fft: 256 206 | hop_length: 128 207 | mask_module: multi_mask 208 | mask_module_conf: {} 209 | preprocessor: enh 210 | preprocessor_conf: {} 211 | diffusion_model: null 212 | diffusion_model_conf: {} 213 | required: 214 | - output_dir 215 | version: '202402' 216 | distributed: true 217 | 
-------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/l1_timedomain+magspec_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/l1_timedomain+magspec_loss.png -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png -------------------------------------------------------------------------------- /egs2/dns_ins20/enh1/exp/enh_train_enh_tflocoformer_raw/valid.loss.ave_5best.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6ffa518f2b73bed289e38349366eb777c94b9130acc32bf4c085cb049d9d1666 3 | size 59972116 4 | -------------------------------------------------------------------------------- /egs2/librimix/enh1/conf/tuning/train_enh_tflocoformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # 
SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | init: xavier_uniform 7 | max_epoch: 150 8 | use_amp: false 9 | batch_type: folded 10 | batch_size: 4 # batch size 4 on 4 RTX 2080Ti 11 | num_workers: 4 12 | 13 | # preprocessor 14 | preprocessor: enh 15 | num_spk: &num_spk 2 16 | iterator_type: sequence # not to discard short samples 17 | speech_segment: 32000 18 | shuffle_within_batch: true 19 | 20 | # espnet model configuration 21 | model_conf: 22 | normalize_variance: true 23 | 24 | # optimizer and scheduler 25 | optim: adamw 26 | optim_conf: 27 | lr: 1.0e-03 28 | eps: 1.0e-08 29 | weight_decay: 1.0e-02 30 | patience: 10 31 | val_scheduler_criterion: 32 | - valid 33 | - loss 34 | best_model_criterion: 35 | - - valid 36 | - si_snr 37 | - max 38 | - - valid 39 | - loss 40 | - min 41 | keep_nbest_models: 5 42 | scheduler: warmupreducelronplateau 43 | scheduler_conf: 44 | warmup_steps: 4000 45 | mode: min 46 | factor: 0.5 47 | patience: 3 48 | 49 | # model configuration 50 | encoder: &encoder stft 51 | encoder_conf: 52 | n_fft: &n_fft 128 53 | hop_length: &hop_length 64 54 | decoder: *encoder 55 | decoder_conf: 56 | n_fft: *n_fft 57 | hop_length: *hop_length 58 | separator: tflocoformer 59 | separator_conf: 60 | num_spk: *num_spk 61 | n_layers: 6 62 | # general setup 63 | emb_dim: 128 64 | norm_type: rmsgroupnorm 65 | num_groups: 4 66 | tf_order: ft 67 | # self-attention 68 | n_heads: 4 69 | flash_attention: false 70 | # ffn 71 | ffn_type: 72 | - swiglu_conv1d 73 | - swiglu_conv1d 74 | ffn_hidden_dim: 75 | - 384 76 | - 384 # list order must be the same as ffn_type 77 | conv1d_kernel: 4 78 | conv1d_shift: 1 79 | dropout: 0.0 80 | # others 81 | eps: 1.0e-5 82 | 83 | criterions: 84 | # The first criterion 85 | - name: si_snr 86 | conf: 87 | eps: 1.0e-7 88 | wrapper: pit 89 | wrapper_conf: 90 | weight: 1.0 91 | independent_perm: true 92 | -------------------------------------------------------------------------------- 
/egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/RESULTS.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | # RESULTS 8 | ## Environments 9 | - date: `Mon Jan 25 19:16:45 CST 2021` 10 | - python version: `3.6.3 |Anaconda, Inc.| (default, Nov 20 2017, 20:41:42) [GCC 7.2.0]` 11 | - espnet version: `espnet 0.9.7` 12 | - pytorch version: `pytorch 1.6.0` 13 | - Git hash: `dcaba2585e28b85c815807165ba9953565ee8694` 14 | - Commit date: `Thu Jan 21 21:26:59 2021 +0800` 15 | 16 | ## enh_train_raw 17 | - Model link: https://zenodo.org/record/4480771/files/enh_train_raw_valid.si_snr.ave.zip?download=1 18 | - config: ./conf/train.yaml 19 | - sample_rate: 8k 20 | - min_or_max: min 21 | 22 | |dataset|STOI|SAR|SDR|SIR| 23 | |---|---|---|---|---| 24 | |enhanced_dev|0.85|11.10|10.67|22.65| 25 | |enhanced_test|0.85|10.92|10.42|22.08| 26 | -------------------------------------------------------------------------------- /egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | config: conf/tuning/train_enh_tflocoformer.yaml 7 | print_config: false 8 | log_level: INFO 9 | drop_last_iter: false 10 | dry_run: false 11 | iterator_type: sequence 12 | valid_iterator_type: null 13 | output_dir: ./exp_train360/enh_train_enh_tflocoformer_raw 14 | ngpu: 1 15 | seed: 0 16 | num_workers: 4 17 | num_att_plot: 3 18 | dist_backend: nccl 19 | dist_init_method: env:// 20 | dist_world_size: 4 21 | dist_rank: 0 22 | local_rank: 0 23 | dist_master_addr: localhost 24 | dist_master_port: 47197 25 | dist_launcher: null 26 | multiprocessing_distributed: true 27 | unused_parameters: false 28 | sharded_ddp: false 29 | cudnn_enabled: true 30 | cudnn_benchmark: false 31 | cudnn_deterministic: true 32 | collect_stats: false 33 | 
write_collected_feats: false 34 | max_epoch: 150 35 | patience: 10 36 | val_scheduler_criterion: 37 | - valid 38 | - loss 39 | early_stopping_criterion: 40 | - valid 41 | - loss 42 | - min 43 | best_model_criterion: 44 | - - valid 45 | - si_snr 46 | - max 47 | - - valid 48 | - loss 49 | - min 50 | keep_nbest_models: 5 51 | nbest_averaging_interval: 0 52 | grad_clip: 5.0 53 | grad_clip_type: 2.0 54 | grad_noise: false 55 | accum_grad: 1 56 | no_forward_run: false 57 | resume: true 58 | train_dtype: float32 59 | use_amp: false 60 | log_interval: null 61 | use_matplotlib: true 62 | use_tensorboard: true 63 | create_graph_in_tensorboard: false 64 | use_wandb: false 65 | wandb_project: null 66 | wandb_id: null 67 | wandb_entity: null 68 | wandb_name: null 69 | wandb_model_log_interval: -1 70 | detect_anomaly: false 71 | use_adapter: false 72 | adapter: lora 73 | save_strategy: all 74 | adapter_conf: {} 75 | pretrain_path: null 76 | init_param: [] 77 | ignore_init_mismatch: false 78 | freeze_param: [] 79 | num_iters_per_epoch: null 80 | batch_size: 4 81 | valid_batch_size: null 82 | batch_bins: 1000000 83 | valid_batch_bins: null 84 | train_shape_file: 85 | - ./exp_train360/enh_stats_8k/train/speech_mix_shape 86 | - ./exp_train360/enh_stats_8k/train/speech_ref1_shape 87 | - ./exp_train360/enh_stats_8k/train/speech_ref2_shape 88 | - ./exp_train360/enh_stats_8k/train/noise_ref1_shape 89 | valid_shape_file: 90 | - ./exp_train360/enh_stats_8k/valid/speech_mix_shape 91 | - ./exp_train360/enh_stats_8k/valid/speech_ref1_shape 92 | - ./exp_train360/enh_stats_8k/valid/speech_ref2_shape 93 | - ./exp_train360/enh_stats_8k/valid/noise_ref1_shape 94 | batch_type: folded 95 | valid_batch_type: null 96 | fold_length: 97 | - 80000 98 | - 80000 99 | - 80000 100 | - 80000 101 | sort_in_batch: descending 102 | shuffle_within_batch: true 103 | sort_batch: descending 104 | multiple_iterator: false 105 | chunk_length: 500 106 | chunk_shift_ratio: 0.5 107 | num_cache_chunks: 1024 108 | 
chunk_excluded_key_prefixes: [] 109 | chunk_default_fs: null 110 | train_data_path_and_name_and_type: 111 | - - dump/raw/train-360_local/wav.scp 112 | - speech_mix 113 | - sound 114 | - - dump/raw/train-360_local/spk1.scp 115 | - speech_ref1 116 | - sound 117 | - - dump/raw/train-360_local/spk2.scp 118 | - speech_ref2 119 | - sound 120 | - - dump/raw/train-360_local/noise1.scp 121 | - noise_ref1 122 | - sound 123 | valid_data_path_and_name_and_type: 124 | - - dump/raw/dev_local/wav.scp 125 | - speech_mix 126 | - sound 127 | - - dump/raw/dev_local/spk1.scp 128 | - speech_ref1 129 | - sound 130 | - - dump/raw/dev_local/spk2.scp 131 | - speech_ref2 132 | - sound 133 | - - dump/raw/dev_local/noise1.scp 134 | - noise_ref1 135 | - sound 136 | allow_variable_data_keys: false 137 | max_cache_size: 0.0 138 | max_cache_fd: 32 139 | allow_multi_rates: false 140 | valid_max_cache_size: null 141 | exclude_weight_decay: false 142 | exclude_weight_decay_conf: {} 143 | optim: adamw 144 | optim_conf: 145 | lr: 0.001 146 | eps: 1.0e-08 147 | weight_decay: 0.01 148 | scheduler: warmupreducelronplateau 149 | scheduler_conf: 150 | warmup_steps: 4000 151 | mode: min 152 | factor: 0.5 153 | patience: 3 154 | init: xavier_uniform 155 | model_conf: 156 | normalize_variance: true 157 | criterions: 158 | - name: si_snr 159 | conf: 160 | eps: 1.0e-07 161 | wrapper: pit 162 | wrapper_conf: 163 | weight: 1.0 164 | independent_perm: true 165 | speech_volume_normalize: null 166 | rir_scp: null 167 | rir_apply_prob: 1.0 168 | noise_scp: null 169 | noise_apply_prob: 1.0 170 | noise_db_range: '13_15' 171 | short_noise_thres: 0.5 172 | use_reverberant_ref: false 173 | num_spk: 2 174 | num_noise_type: 1 175 | sample_rate: 8000 176 | force_single_channel: false 177 | channel_reordering: false 178 | categories: [] 179 | speech_segment: 32000 180 | avoid_allzero_segment: true 181 | flexible_numspk: false 182 | dynamic_mixing: false 183 | utt2spk: null 184 | dynamic_mixing_gain_db: 0.0 185 | encoder: stft 
186 | encoder_conf: 187 | n_fft: 128 188 | hop_length: 64 189 | separator: tflocoformer 190 | separator_conf: 191 | num_spk: 2 192 | n_layers: 6 193 | emb_dim: 128 194 | norm_type: rmsgroupnorm 195 | num_groups: 4 196 | tf_order: ft 197 | n_heads: 4 198 | flash_attention: false 199 | ffn_type: 200 | - swiglu_conv1d 201 | - swiglu_conv1d 202 | ffn_hidden_dim: 203 | - 384 204 | - 384 205 | conv1d_kernel: 4 206 | conv1d_shift: 1 207 | dropout: 0.0 208 | eps: 1.0e-05 209 | decoder: stft 210 | decoder_conf: 211 | n_fft: 128 212 | hop_length: 64 213 | mask_module: multi_mask 214 | mask_module_conf: {} 215 | preprocessor: enh 216 | preprocessor_conf: {} 217 | diffusion_model: null 218 | diffusion_model_conf: {} 219 | required: 220 | - output_dir 221 | version: '202402' 222 | distributed: true 223 | -------------------------------------------------------------------------------- /egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png -------------------------------------------------------------------------------- /egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png -------------------------------------------------------------------------------- /egs2/librimix/enh1/exp/enh_train_enh_tflocoformer_raw/valid.loss.ave_5best.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9c691cee4bb0d9664a3fac024b031dfdf8ca3e66fc22a66f2f17119767b49129 
3 | size 59979488 4 | -------------------------------------------------------------------------------- /egs2/librimix/enh1/local/data.patch: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # Copyright (C) 2017 ESPnet Developers 3 | # 4 | # SPDX-License-Identifier: Apache-2.0 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | 8 | --- local/data.sh 2024-06-26 11:15:24.682654080 -0400 9 | +++ local/data.sh.new 2024-06-26 11:15:31.698605036 -0400 10 | @@ -30,6 +30,8 @@ min_or_max=max 11 | sample_rate=16k 12 | num_spk=2 13 | 14 | +cond=noisy # noisy or clean 15 | + 16 | stage=0 17 | stop_stage=100 18 | 19 | @@ -70,7 +72,7 @@ if [ ${stage} -le 0 ] && [ ${stop_stage} 20 | if [ -z "${wham_noise}" ]; then 21 | # 17.65 GB unzipping to 35 GB 22 | mkdir -p ${cdir}/data/wham_noise 23 | - wham_noise_url=https://storage.googleapis.com/whisper-public/wham_noise.zip 24 | + wham_noise_url=https://my-bucket-a8b4b49c25c811ee9a7e8bba05fa24c7.s3.amazonaws.com/wham_noise.zip 25 | wget --continue -O "${cdir}/data/wham_noise.zip" ${wham_noise_url} 26 | num_wavs=$(find "${cdir}/data/wham_noise" -iname "*.wav" | wc -l) 27 | if [ "${num_wavs}" = "4" ]; then 28 | @@ -116,21 +118,35 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} 29 | librimix="data/LibriMix/libri_mix/Libri2Mix" 30 | for dset in dev test train; do 31 | mkdir -p "data/${dset}" 32 | - if [ "$dset" = "train" ]; then 33 | - cat ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_train-*_mix_both.csv | grep -v mixture_ID | sort -u > "data/${dset}/tmp" 34 | + 35 | + if [ "$cond" = "noisy" ]; then 36 | + if [ "$dset" = "train" ]; then 37 | + cat ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_train-*_mix_both.csv | grep -v mixture_ID | sort -u > "data/${dset}/tmp" 38 | + else 39 | + grep -v mixture_ID ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_${dset}_mix_both.csv | sort -u > 
"data/${dset}/tmp" 40 | + fi 41 | else 42 | - grep -v mixture_ID ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_${dset}_mix_both.csv | sort -u > "data/${dset}/tmp" 43 | + if [ "$dset" = "train" ]; then 44 | + cat ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_train-*_mix_clean.csv | grep -v mixture_ID | sort -u > "data/${dset}/tmp" 45 | + else 46 | + grep -v mixture_ID ${librimix}/wav${sample_rate}/${min_or_max}/metadata/mixture_${dset}_mix_clean.csv | sort -u > "data/${dset}/tmp" 47 | + fi 48 | fi 49 | + 50 | awk -F ',' '{print $1, $1}' "data/${dset}/tmp" > "data/${dset}/utt2spk" 51 | awk -F ',' '{print $1, $1}' "data/${dset}/tmp" > "data/${dset}/spk2utt" 52 | awk -F ',' '{print $1, $2}' "data/${dset}/tmp" > "data/${dset}/wav.scp" 53 | awk -F ',' '{print $1, $3}' "data/${dset}/tmp" > "data/${dset}/spk1.scp" 54 | awk -F ',' '{print $1, $4}' "data/${dset}/tmp" > "data/${dset}/spk2.scp" 55 | if [ $num_spk -eq 2 ]; then 56 | - awk -F ',' '{print $1, $5}' "data/${dset}/tmp" > "data/${dset}/noise1.scp" 57 | + if [ "$cond" = "noisy" ]; then 58 | + awk -F ',' '{print $1, $5}' "data/${dset}/tmp" > "data/${dset}/noise1.scp" 59 | + fi 60 | else 61 | awk -F ',' '{print $1, $5}' "data/${dset}/tmp" > "data/${dset}/spk3.scp" 62 | - awk -F ',' '{print $1, $6}' "data/${dset}/tmp" > "data/${dset}/noise1.scp" 63 | + if [ "$cond" = "noisy" ]; then 64 | + awk -F ',' '{print $1, $6}' "data/${dset}/tmp" > "data/${dset}/noise1.scp" 65 | + fi 66 | fi 67 | rm "data/${dset}/tmp" 68 | done 69 | @@ -147,6 +163,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} 70 | [ "$f" = "data/train/wav.scp" ] || utils/filter_scp.pl "data/${subset}/wav.scp" "$f" > "data/${subset}/$(basename $f)" 71 | done 72 | utils/filter_scp.pl "data/${subset}/wav.scp" data/train/utt2spk > data/${subset}/utt2spk 73 | + utils/utt2spk_to_spk2utt.pl data/${subset}/utt2spk > data/${subset}/spk2utt 74 | done 75 | fi 76 | -------------------------------------------------------------------------------- 
/egs2/whamr/enh1/conf/tuning/train_enh_tflocoformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | init: xavier_uniform 7 | max_epoch: 150 8 | use_amp: false 9 | batch_type: folded 10 | batch_size: 4 # batch size 4 on 4 Quadro RTX 6000 (24GiB) 11 | num_workers: 4 12 | 13 | # preprocessor 14 | preprocessor: enh 15 | num_spk: &num_spk 2 16 | iterator_type: sequence # not to discard short samples 17 | speech_segment: 32000 18 | shuffle_within_batch: true 19 | force_single_channel: true 20 | 21 | # espnet model configuration 22 | model_conf: 23 | normalize_variance: true 24 | 25 | # optimizer and scheduler 26 | optim: adamw 27 | optim_conf: 28 | lr: 1.0e-03 29 | eps: 1.0e-08 30 | weight_decay: 1.0e-02 31 | patience: 10 32 | val_scheduler_criterion: 33 | - valid 34 | - loss 35 | best_model_criterion: 36 | - - valid 37 | - si_snr 38 | - max 39 | - - valid 40 | - loss 41 | - min 42 | keep_nbest_models: 5 43 | scheduler: warmupreducelronplateau 44 | scheduler_conf: 45 | warmup_steps: 4000 46 | mode: min 47 | factor: 0.5 48 | patience: 3 49 | 50 | # model configuration 51 | encoder: &encoder stft 52 | encoder_conf: 53 | n_fft: &n_fft 256 54 | hop_length: &hop_length 64 55 | decoder: *encoder 56 | decoder_conf: 57 | n_fft: *n_fft 58 | hop_length: *hop_length 59 | separator: tflocoformer 60 | separator_conf: 61 | num_spk: *num_spk 62 | n_layers: 6 63 | # general setup 64 | emb_dim: 128 65 | norm_type: rmsgroupnorm 66 | num_groups: 4 67 | tf_order: ft 68 | # self-attention 69 | n_heads: 4 70 | flash_attention: false 71 | # ffn 72 | ffn_type: 73 | - swiglu_conv1d 74 | - swiglu_conv1d 75 | ffn_hidden_dim: 76 | - 192 77 | - 192 # list order must be the same as ffn_type 78 | conv1d_kernel: 8 79 | conv1d_shift: 1 80 | dropout: 0.0 81 | # others 82 | eps: 1.0e-5 83 | 84 | criterions: 85 | # The first criterion 
86 | - name: si_snr 87 | conf: 88 | eps: 1.0e-7 89 | wrapper: pit 90 | wrapper_conf: 91 | weight: 1.0 92 | independent_perm: true 93 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/conf/tuning/train_enh_tflocoformer_small.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | init: xavier_uniform 7 | max_epoch: 150 8 | use_amp: false 9 | batch_type: folded 10 | batch_size: 4 # batch size 4 on 4 RTX 2080Ti 11 | num_workers: 4 12 | 13 | # preprocessor 14 | preprocessor: enh 15 | num_spk: &num_spk 2 16 | iterator_type: sequence # not to discard short samples 17 | speech_segment: 32000 18 | shuffle_within_batch: true 19 | force_single_channel: true 20 | 21 | # espnet model configuration 22 | model_conf: 23 | normalize_variance: true 24 | 25 | # optimizer and scheduler 26 | optim: adamw 27 | optim_conf: 28 | lr: 1.0e-03 29 | eps: 1.0e-08 30 | weight_decay: 1.0e-02 31 | patience: 10 32 | val_scheduler_criterion: 33 | - valid 34 | - loss 35 | best_model_criterion: 36 | - - valid 37 | - si_snr 38 | - max 39 | - - valid 40 | - loss 41 | - min 42 | keep_nbest_models: 5 43 | scheduler: warmupreducelronplateau 44 | scheduler_conf: 45 | warmup_steps: 4000 46 | mode: min 47 | factor: 0.5 48 | patience: 3 49 | 50 | # model configuration 51 | encoder: &encoder stft 52 | encoder_conf: 53 | n_fft: &n_fft 256 54 | hop_length: &hop_length 64 55 | decoder: *encoder 56 | decoder_conf: 57 | n_fft: *n_fft 58 | hop_length: *hop_length 59 | separator: tflocoformer 60 | separator_conf: 61 | num_spk: *num_spk 62 | n_layers: 4 63 | # general setup 64 | emb_dim: 96 65 | norm_type: rmsgroupnorm 66 | num_groups: 4 67 | tf_order: ft 68 | # self-attention 69 | n_heads: 4 70 | flash_attention: false 71 | # ffn 72 | ffn_type: 73 | - swiglu_conv1d 74 | - swiglu_conv1d 75 | 
ffn_hidden_dim: 76 | - 128 77 | - 128 # list order must be the same as ffn_type 78 | conv1d_kernel: 8 79 | conv1d_shift: 1 80 | dropout: 0.0 81 | # others 82 | eps: 1.0e-5 83 | 84 | criterions: 85 | # The first criterion 86 | - name: si_snr 87 | conf: 88 | eps: 1.0e-7 89 | wrapper: pit 90 | wrapper_conf: 91 | weight: 1.0 92 | independent_perm: true 93 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/RESULTS.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | # RESULTS 9 | ## Environments 10 | - date: `Mon Jun 24 10:56:31 EDT 2024` 11 | - python version: `3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]` 12 | - espnet version: `espnet 202402` 13 | - pytorch version: `pytorch 2.1.0` 14 | - Git hash: `90eed8e53498e7af682bc6ff39d9067ae440d6a4` 15 | - Commit date: `Mon May 27 22:42:15 2024 -0700` 16 | 17 | 18 | ## enh_train_enh_tflocoformer_raw 19 | 20 | config: ./conf/tuning/train_enh_tflocoformer.yaml 21 | 22 | |dataset|STOI|SAR|SDR|SIR|SI_SNR| 23 | |---|---|---|---|---|---| 24 | |enhanced_cv_mix_both_reverb_min_8k|91.20|12.91|12.76|28.78|11.78| 25 | |enhanced_tt_mix_both_reverb_min_8k|93.27|13.29|13.14|29.18|12.15| 26 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | config: ./conf/tuning/train_enh_tflocoformer.yaml 7 | print_config: false 8 | log_level: INFO 9 | drop_last_iter: false 10 | dry_run: false 11 | iterator_type: sequence 12 | valid_iterator_type: null 13 | output_dir: exp/enh_train_enh_tflocoformer_raw 14 | ngpu: 1 15 | seed: 0 16 | num_workers: 4 17 | num_att_plot: 3 18 | dist_backend: nccl 19 | dist_init_method: 
env:// 20 | dist_world_size: 4 21 | dist_rank: 0 22 | local_rank: 0 23 | dist_master_addr: localhost 24 | dist_master_port: 49627 25 | dist_launcher: null 26 | multiprocessing_distributed: true 27 | unused_parameters: false 28 | sharded_ddp: false 29 | cudnn_enabled: true 30 | cudnn_benchmark: false 31 | cudnn_deterministic: true 32 | collect_stats: false 33 | write_collected_feats: false 34 | max_epoch: 150 35 | patience: 10 36 | val_scheduler_criterion: 37 | - valid 38 | - loss 39 | early_stopping_criterion: 40 | - valid 41 | - loss 42 | - min 43 | best_model_criterion: 44 | - - valid 45 | - si_snr 46 | - max 47 | - - valid 48 | - loss 49 | - min 50 | keep_nbest_models: 5 51 | nbest_averaging_interval: 0 52 | grad_clip: 5.0 53 | grad_clip_type: 2.0 54 | grad_noise: false 55 | accum_grad: 1 56 | no_forward_run: false 57 | resume: true 58 | train_dtype: float32 59 | use_amp: false 60 | log_interval: null 61 | use_matplotlib: true 62 | use_tensorboard: true 63 | create_graph_in_tensorboard: false 64 | use_wandb: false 65 | wandb_project: null 66 | wandb_id: null 67 | wandb_entity: null 68 | wandb_name: null 69 | wandb_model_log_interval: -1 70 | detect_anomaly: false 71 | use_adapter: false 72 | adapter: lora 73 | save_strategy: all 74 | adapter_conf: {} 75 | pretrain_path: null 76 | init_param: [] 77 | ignore_init_mismatch: false 78 | freeze_param: [] 79 | num_iters_per_epoch: null 80 | batch_size: 4 81 | valid_batch_size: null 82 | batch_bins: 1000000 83 | valid_batch_bins: null 84 | train_shape_file: 85 | - exp/enh_stats_8k/train/speech_mix_shape 86 | - exp/enh_stats_8k/train/speech_ref1_shape 87 | - exp/enh_stats_8k/train/speech_ref2_shape 88 | - exp/enh_stats_8k/train/noise_ref1_shape 89 | valid_shape_file: 90 | - exp/enh_stats_8k/valid/speech_mix_shape 91 | - exp/enh_stats_8k/valid/speech_ref1_shape 92 | - exp/enh_stats_8k/valid/speech_ref2_shape 93 | - exp/enh_stats_8k/valid/noise_ref1_shape 94 | batch_type: folded 95 | valid_batch_type: null 96 | 
fold_length: 97 | - 80000 98 | - 80000 99 | - 80000 100 | - 80000 101 | sort_in_batch: descending 102 | shuffle_within_batch: true 103 | sort_batch: descending 104 | multiple_iterator: false 105 | chunk_length: 500 106 | chunk_shift_ratio: 0.5 107 | num_cache_chunks: 1024 108 | chunk_excluded_key_prefixes: [] 109 | chunk_default_fs: null 110 | train_data_path_and_name_and_type: 111 | - - dump/raw/tr_mix_both_reverb_min_8k/wav.scp 112 | - speech_mix 113 | - sound 114 | - - dump/raw/tr_mix_both_reverb_min_8k/spk1.scp 115 | - speech_ref1 116 | - sound 117 | - - dump/raw/tr_mix_both_reverb_min_8k/spk2.scp 118 | - speech_ref2 119 | - sound 120 | - - dump/raw/tr_mix_both_reverb_min_8k/noise1.scp 121 | - noise_ref1 122 | - sound 123 | valid_data_path_and_name_and_type: 124 | - - dump/raw/cv_mix_both_reverb_min_8k/wav.scp 125 | - speech_mix 126 | - sound 127 | - - dump/raw/cv_mix_both_reverb_min_8k/spk1.scp 128 | - speech_ref1 129 | - sound 130 | - - dump/raw/cv_mix_both_reverb_min_8k/spk2.scp 131 | - speech_ref2 132 | - sound 133 | - - dump/raw/cv_mix_both_reverb_min_8k/noise1.scp 134 | - noise_ref1 135 | - sound 136 | allow_variable_data_keys: false 137 | max_cache_size: 0.0 138 | max_cache_fd: 32 139 | allow_multi_rates: false 140 | valid_max_cache_size: null 141 | exclude_weight_decay: false 142 | exclude_weight_decay_conf: {} 143 | optim: adamw 144 | optim_conf: 145 | lr: 0.001 146 | eps: 1.0e-08 147 | weight_decay: 0.01 148 | scheduler: warmupreducelronplateau 149 | scheduler_conf: 150 | warmup_steps: 4000 151 | mode: min 152 | factor: 0.5 153 | patience: 3 154 | init: xavier_uniform 155 | model_conf: 156 | normalize_variance: true 157 | criterions: 158 | - name: si_snr 159 | conf: 160 | eps: 1.0e-07 161 | wrapper: pit 162 | wrapper_conf: 163 | weight: 1.0 164 | independent_perm: true 165 | speech_volume_normalize: null 166 | rir_scp: null 167 | rir_apply_prob: 1.0 168 | noise_scp: null 169 | noise_apply_prob: 1.0 170 | noise_db_range: '13_15' 171 | 
short_noise_thres: 0.5 172 | use_reverberant_ref: false 173 | num_spk: 2 174 | num_noise_type: 1 175 | sample_rate: 8000 176 | force_single_channel: true 177 | channel_reordering: false 178 | categories: [] 179 | speech_segment: 32000 180 | avoid_allzero_segment: true 181 | flexible_numspk: false 182 | dynamic_mixing: false 183 | utt2spk: null 184 | dynamic_mixing_gain_db: 0.0 185 | encoder: stft 186 | encoder_conf: 187 | n_fft: 256 188 | hop_length: 64 189 | separator: tflocoformer 190 | separator_conf: 191 | num_spk: 2 192 | n_layers: 6 193 | emb_dim: 128 194 | norm_type: rmsgroupnorm 195 | num_groups: 4 196 | tf_order: ft 197 | n_heads: 4 198 | flash_attention: false 199 | ffn_type: 200 | - swiglu_conv1d 201 | - swiglu_conv1d 202 | ffn_hidden_dim: 203 | - 192 204 | - 192 205 | conv1d_kernel: 8 206 | conv1d_shift: 1 207 | dropout: 0.0 208 | eps: 1.0e-05 209 | decoder: stft 210 | decoder_conf: 211 | n_fft: 256 212 | hop_length: 64 213 | mask_module: multi_mask 214 | mask_module_conf: {} 215 | preprocessor: enh 216 | preprocessor_conf: {} 217 | diffusion_model: null 218 | diffusion_model_conf: {} 219 | required: 220 | - output_dir 221 | version: '202402' 222 | distributed: true 223 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_raw/valid.loss.ave_5best.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3e14c736d19cd2158e09f0446fe214be4b47e28ed6d6c11d037cf2d4dae82d6a 3 | size 59942624 4 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/RESULTS.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | # RESULTS 9 | ## Environments 10 | - date: `Sat Jun 22 14:29:00 EDT 2024` 11 | - python version: `3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]` 12 | - espnet version: `espnet 202402` 13 | - pytorch version: `pytorch 2.1.0` 14 | - Git hash: `90eed8e53498e7af682bc6ff39d9067ae440d6a4` 15 | - Commit date: `Mon May 27 22:42:15 2024 -0700` 16 | 17 | 18 | ## enh_train_enh_tflocoformer_small_raw 19 | 20 | config: ./conf/tuning/train_enh_tflocoformer_small.yaml 21 | 22 | |dataset|STOI|SAR|SDR|SIR|SI_SNR| 23 | |---|---|---|---|---|---| 24 | |enhanced_cv_mix_both_reverb_min_8k|89.95|12.21|12.02|27.52|10.98| 25 | |enhanced_tt_mix_both_reverb_min_8k|92.31|12.61|12.44|27.97|11.43| 26 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | config: ./conf/tuning/train_enh_tflocoformer_small.yaml 7 | print_config: false 8 | log_level: INFO 9 | drop_last_iter: false 10 | 
dry_run: false 11 | iterator_type: sequence 12 | valid_iterator_type: null 13 | output_dir: exp/enh_train_enh_tflocoformer_small_raw 14 | ngpu: 1 15 | seed: 0 16 | num_workers: 4 17 | num_att_plot: 3 18 | dist_backend: nccl 19 | dist_init_method: env:// 20 | dist_world_size: 4 21 | dist_rank: 0 22 | local_rank: 0 23 | dist_master_addr: localhost 24 | dist_master_port: 48565 25 | dist_launcher: null 26 | multiprocessing_distributed: true 27 | unused_parameters: false 28 | sharded_ddp: false 29 | cudnn_enabled: true 30 | cudnn_benchmark: false 31 | cudnn_deterministic: true 32 | collect_stats: false 33 | write_collected_feats: false 34 | max_epoch: 150 35 | patience: 10 36 | val_scheduler_criterion: 37 | - valid 38 | - loss 39 | early_stopping_criterion: 40 | - valid 41 | - loss 42 | - min 43 | best_model_criterion: 44 | - - valid 45 | - si_snr 46 | - max 47 | - - valid 48 | - loss 49 | - min 50 | keep_nbest_models: 5 51 | nbest_averaging_interval: 0 52 | grad_clip: 5.0 53 | grad_clip_type: 2.0 54 | grad_noise: false 55 | accum_grad: 1 56 | no_forward_run: false 57 | resume: true 58 | train_dtype: float32 59 | use_amp: false 60 | log_interval: null 61 | use_matplotlib: true 62 | use_tensorboard: true 63 | create_graph_in_tensorboard: false 64 | use_wandb: false 65 | wandb_project: null 66 | wandb_id: null 67 | wandb_entity: null 68 | wandb_name: null 69 | wandb_model_log_interval: -1 70 | detect_anomaly: false 71 | use_adapter: false 72 | adapter: lora 73 | save_strategy: all 74 | adapter_conf: {} 75 | pretrain_path: null 76 | init_param: [] 77 | ignore_init_mismatch: false 78 | freeze_param: [] 79 | num_iters_per_epoch: null 80 | batch_size: 4 81 | valid_batch_size: null 82 | batch_bins: 1000000 83 | valid_batch_bins: null 84 | train_shape_file: 85 | - exp/enh_stats_8k/train/speech_mix_shape 86 | - exp/enh_stats_8k/train/speech_ref1_shape 87 | - exp/enh_stats_8k/train/speech_ref2_shape 88 | - exp/enh_stats_8k/train/noise_ref1_shape 89 | valid_shape_file: 90 | - 
exp/enh_stats_8k/valid/speech_mix_shape 91 | - exp/enh_stats_8k/valid/speech_ref1_shape 92 | - exp/enh_stats_8k/valid/speech_ref2_shape 93 | - exp/enh_stats_8k/valid/noise_ref1_shape 94 | batch_type: folded 95 | valid_batch_type: null 96 | fold_length: 97 | - 80000 98 | - 80000 99 | - 80000 100 | - 80000 101 | sort_in_batch: descending 102 | shuffle_within_batch: true 103 | sort_batch: descending 104 | multiple_iterator: false 105 | chunk_length: 500 106 | chunk_shift_ratio: 0.5 107 | num_cache_chunks: 1024 108 | chunk_excluded_key_prefixes: [] 109 | chunk_default_fs: null 110 | train_data_path_and_name_and_type: 111 | - - dump/raw/tr_mix_both_reverb_min_8k/wav.scp 112 | - speech_mix 113 | - sound 114 | - - dump/raw/tr_mix_both_reverb_min_8k/spk1.scp 115 | - speech_ref1 116 | - sound 117 | - - dump/raw/tr_mix_both_reverb_min_8k/spk2.scp 118 | - speech_ref2 119 | - sound 120 | - - dump/raw/tr_mix_both_reverb_min_8k/noise1.scp 121 | - noise_ref1 122 | - sound 123 | valid_data_path_and_name_and_type: 124 | - - dump/raw/cv_mix_both_reverb_min_8k/wav.scp 125 | - speech_mix 126 | - sound 127 | - - dump/raw/cv_mix_both_reverb_min_8k/spk1.scp 128 | - speech_ref1 129 | - sound 130 | - - dump/raw/cv_mix_both_reverb_min_8k/spk2.scp 131 | - speech_ref2 132 | - sound 133 | - - dump/raw/cv_mix_both_reverb_min_8k/noise1.scp 134 | - noise_ref1 135 | - sound 136 | allow_variable_data_keys: false 137 | max_cache_size: 0.0 138 | max_cache_fd: 32 139 | allow_multi_rates: false 140 | valid_max_cache_size: null 141 | exclude_weight_decay: false 142 | exclude_weight_decay_conf: {} 143 | optim: adamw 144 | optim_conf: 145 | lr: 0.001 146 | eps: 1.0e-08 147 | weight_decay: 0.01 148 | scheduler: warmupreducelronplateau 149 | scheduler_conf: 150 | warmup_steps: 4000 151 | mode: min 152 | factor: 0.5 153 | patience: 3 154 | init: xavier_uniform 155 | model_conf: 156 | normalize_variance: true 157 | criterions: 158 | - name: si_snr 159 | conf: 160 | eps: 1.0e-07 161 | wrapper: pit 162 | 
wrapper_conf: 163 | weight: 1.0 164 | independent_perm: true 165 | speech_volume_normalize: null 166 | rir_scp: null 167 | rir_apply_prob: 1.0 168 | noise_scp: null 169 | noise_apply_prob: 1.0 170 | noise_db_range: '13_15' 171 | short_noise_thres: 0.5 172 | use_reverberant_ref: false 173 | num_spk: 2 174 | num_noise_type: 1 175 | sample_rate: 8000 176 | force_single_channel: true 177 | channel_reordering: false 178 | categories: [] 179 | speech_segment: 32000 180 | avoid_allzero_segment: true 181 | flexible_numspk: false 182 | dynamic_mixing: false 183 | utt2spk: null 184 | dynamic_mixing_gain_db: 0.0 185 | encoder: stft 186 | encoder_conf: 187 | n_fft: 256 188 | hop_length: 64 189 | separator: tflocoformer 190 | separator_conf: 191 | num_spk: 2 192 | n_layers: 4 193 | emb_dim: 96 194 | norm_type: rmsgroupnorm 195 | num_groups: 4 196 | tf_order: ft 197 | n_heads: 4 198 | flash_attention: false 199 | ffn_type: 200 | - swiglu_conv1d 201 | - swiglu_conv1d 202 | ffn_hidden_dim: 203 | - 128 204 | - 128 205 | conv1d_kernel: 8 206 | conv1d_shift: 1 207 | dropout: 0.0 208 | eps: 1.0e-05 209 | decoder: stft 210 | decoder_conf: 211 | n_fft: 256 212 | hop_length: 64 213 | mask_module: multi_mask 214 | mask_module_conf: {} 215 | preprocessor: enh 216 | preprocessor_conf: {} 217 | diffusion_model: null 218 | diffusion_model_conf: {} 219 | required: 220 | - output_dir 221 | version: '202402' 222 | distributed: true 223 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/images/loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/images/loss.png -------------------------------------------------------------------------------- 
/egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/images/si_snr_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/images/si_snr_loss.png -------------------------------------------------------------------------------- /egs2/whamr/enh1/exp/enh_train_enh_tflocoformer_small_raw/valid.loss.ave_5best.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4ba8ffd35258687aeb6f4dd0cf30622728d440b2581e35ae6b4e1f8274dc5dd8 3 | size 20553016 4 | -------------------------------------------------------------------------------- /egs2/whamr/enh1/local/whamr_data_prep.patch: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # Copyright (C) 2017 ESPnet Developers 3 | # 4 | # SPDX-License-Identifier: Apache-2.0 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | --- whamr_data_prep.sh.orig 2024-06-25 16:06:05.031799614 -0400 8 | +++ whamr_data_prep.sh 2024-06-25 17:05:10.303535373 -0400 9 | @@ -72,13 +72,38 @@ for x in tr cv tt; do 10 | > ${data}/${ddir}/noise1.scp 11 | fi 12 | 13 | - spk1_wav_dir=${rootdir}/s1_${cond} 14 | - sed -e "s#${mixwav_dir}#${spk1_wav_dir}#g" ${data}/${ddir}/wav.scp \ 15 | - > ${data}/${ddir}/spk1.scp 16 | - if [[ "$mixtype" != "single" ]]; then 17 | - spk2_wav_dir=${rootdir}/s2_${cond} 18 | - sed -e "s#${mixwav_dir}#${spk2_wav_dir}#g" ${data}/${ddir}/wav.scp \ 19 | - > ${data}/${ddir}/spk2.scp 20 | + 21 | + # NOTE: modified to do dereverberation and separation 22 | + if [[ "$cond" = "reverb" ]]; then 23 | + # make anechoic spk scp files 24 | + spk1_wav_dir=${rootdir}/s1_anechoic 25 | + sed -e "s#${mixwav_dir}#${spk1_wav_dir}#g" ${data}/${ddir}/wav.scp \ 26 
| + > ${data}/${ddir}/spk1.scp 27 | + if [[ "$mixtype" != "single" ]]; then 28 | + spk2_wav_dir=${rootdir}/s2_anechoic 29 | + sed -e "s#${mixwav_dir}#${spk2_wav_dir}#g" ${data}/${ddir}/wav.scp \ 30 | + > ${data}/${ddir}/spk2.scp 31 | + fi 32 | + 33 | + # reverb scps 34 | + spk1_wav_dir=${rootdir}/s1_${cond} 35 | + sed -e "s#${mixwav_dir}#${spk1_wav_dir}#g" ${data}/${ddir}/wav.scp \ 36 | + > ${data}/${ddir}/spk1_reverb.scp 37 | + if [[ "$mixtype" != "single" ]]; then 38 | + spk2_wav_dir=${rootdir}/s2_${cond} 39 | + sed -e "s#${mixwav_dir}#${spk2_wav_dir}#g" ${data}/${ddir}/wav.scp \ 40 | + > ${data}/${ddir}/spk2_reverb.scp 41 | + fi 42 | + else 43 | + # original code 44 | + spk1_wav_dir=${rootdir}/s1_${cond} 45 | + sed -e "s#${mixwav_dir}#${spk1_wav_dir}#g" ${data}/${ddir}/wav.scp \ 46 | + > ${data}/${ddir}/spk1.scp 47 | + if [[ "$mixtype" != "single" ]]; then 48 | + spk2_wav_dir=${rootdir}/s2_${cond} 49 | + sed -e "s#${mixwav_dir}#${spk2_wav_dir}#g" ${data}/${ddir}/wav.scp \ 50 | + > ${data}/${ddir}/spk2.scp 51 | + fi 52 | fi 53 | 54 | if [[ "$cond" = "reverb" ]]; then 55 | -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/conf/tuning/train_enh_tflocoformer.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | # Takes ~4.5 days using 4 RTX 2080Ti 7 | init: xavier_uniform 8 | max_epoch: 150 9 | use_amp: false 10 | batch_type: folded 11 | batch_size: 4 # batch size 4 on 4 RTX 2080Ti 12 | num_workers: 4 13 | 14 | # preprocessor 15 | preprocessor: enh 16 | num_spk: &num_spk 2 17 | iterator_type: sequence # not to discard short samples 18 | speech_segment: 32000 19 | shuffle_within_batch: true 20 | 21 | # espnet model configuration 22 | model_conf: 23 | normalize_variance: true 24 | 25 | # optimizer and scheduler 26 | optim: adamw 27 | optim_conf: 28 | 
lr: 1.0e-03 29 | eps: 1.0e-08 30 | weight_decay: 1.0e-02 31 | patience: 10 32 | val_scheduler_criterion: 33 | - valid 34 | - loss 35 | best_model_criterion: 36 | - - valid 37 | - si_snr 38 | - max 39 | - - valid 40 | - loss 41 | - min 42 | keep_nbest_models: 5 43 | scheduler: warmupreducelronplateau 44 | scheduler_conf: 45 | warmup_steps: 4000 46 | mode: min 47 | factor: 0.5 48 | patience: 3 49 | 50 | # model configuration 51 | encoder: &encoder stft 52 | encoder_conf: 53 | n_fft: &n_fft 128 54 | hop_length: &hop_length 64 55 | decoder: *encoder 56 | decoder_conf: 57 | n_fft: *n_fft 58 | hop_length: *hop_length 59 | separator: tflocoformer 60 | separator_conf: 61 | num_spk: *num_spk 62 | n_layers: 6 63 | # general setup 64 | emb_dim: 128 65 | norm_type: rmsgroupnorm 66 | num_groups: 4 67 | tf_order: ft 68 | # self-attention 69 | n_heads: 4 70 | flash_attention: false 71 | # ffn 72 | ffn_type: 73 | - swiglu_conv1d 74 | - swiglu_conv1d 75 | ffn_hidden_dim: 76 | - 384 77 | - 384 # list order must be the same as ffn_type 78 | conv1d_kernel: 4 79 | conv1d_shift: 1 80 | dropout: 0.0 81 | # others 82 | eps: 1.0e-5 83 | 84 | criterions: 85 | # The first criterion 86 | - name: si_snr 87 | conf: 88 | eps: 1.0e-7 89 | wrapper: pit 90 | wrapper_conf: 91 | weight: 1.0 92 | independent_perm: true 93 | -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/RESULTS.md: -------------------------------------------------------------------------------- 1 | 6 | 7 | 8 | # RESULTS 9 | ## Environments 10 | - date: `Mon Jun 3 10:48:32 EDT 2024` 11 | - python version: `3.10.8 (main, Nov 24 2022, 14:13:03) [GCC 11.2.0]` 12 | - espnet version: `espnet 202402` 13 | - pytorch version: `pytorch 2.1.0` 14 | - Git hash: `90eed8e53498e7af682bc6ff39d9067ae440d6a4` 15 | - Commit date: `Mon May 27 22:42:15 2024 -0700` 16 | 17 | 18 | ## enh_train_enh_tflocoformer_raw 19 | 20 | config: 
conf/tuning/train_enh_tflocoformer.yaml 21 | 22 | |dataset|STOI|SAR|SDR|SIR|SI_SNR| 23 | |---|---|---|---|---|---| 24 | |enhanced_cv_min_8k|98.18|23.61|23.28|35.28|22.98| 25 | |enhanced_tt_min_8k|98.91|24.25|23.93|36.18|23.64| 26 | -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | config: conf/tuning/train_enh_tflocoformer.yaml 7 | print_config: false 8 | log_level: INFO 9 | drop_last_iter: false 10 | dry_run: false 11 | iterator_type: sequence 12 | valid_iterator_type: null 13 | output_dir: exp/enh_train_enh_tflocoformer_raw 14 | ngpu: 1 15 | seed: 0 16 | num_workers: 4 17 | num_att_plot: 3 18 | dist_backend: nccl 19 | dist_init_method: env:// 20 | dist_world_size: 4 21 | dist_rank: 0 22 | local_rank: 0 23 | dist_master_addr: localhost 24 | dist_master_port: 37889 25 | dist_launcher: null 26 | multiprocessing_distributed: true 27 | unused_parameters: false 28 | sharded_ddp: false 29 | cudnn_enabled: true 30 | cudnn_benchmark: false 31 | cudnn_deterministic: true 32 | collect_stats: false 33 | write_collected_feats: false 34 | max_epoch: 150 35 | patience: 10 36 | val_scheduler_criterion: 37 | - valid 38 | - loss 39 | early_stopping_criterion: 40 | - valid 41 | - loss 42 | - min 43 | best_model_criterion: 44 | - - valid 45 | - si_snr 46 | - max 47 | - - valid 48 | - loss 49 | - min 50 | keep_nbest_models: 5 51 | nbest_averaging_interval: 0 52 | grad_clip: 5.0 53 | grad_clip_type: 2.0 54 | grad_noise: false 55 | accum_grad: 1 56 | no_forward_run: false 57 | resume: true 58 | train_dtype: float32 59 | use_amp: false 60 | log_interval: null 61 | use_matplotlib: true 62 | use_tensorboard: true 63 | create_graph_in_tensorboard: false 64 | use_wandb: false 65 | 
wandb_project: null 66 | wandb_id: null 67 | wandb_entity: null 68 | wandb_name: null 69 | wandb_model_log_interval: -1 70 | detect_anomaly: false 71 | use_adapter: false 72 | adapter: lora 73 | save_strategy: all 74 | adapter_conf: {} 75 | pretrain_path: null 76 | init_param: [] 77 | ignore_init_mismatch: false 78 | freeze_param: [] 79 | num_iters_per_epoch: null 80 | batch_size: 4 81 | valid_batch_size: null 82 | batch_bins: 1000000 83 | valid_batch_bins: null 84 | train_shape_file: 85 | - exp/enh_stats_8k/train/speech_mix_shape 86 | - exp/enh_stats_8k/train/speech_ref1_shape 87 | - exp/enh_stats_8k/train/speech_ref2_shape 88 | valid_shape_file: 89 | - exp/enh_stats_8k/valid/speech_mix_shape 90 | - exp/enh_stats_8k/valid/speech_ref1_shape 91 | - exp/enh_stats_8k/valid/speech_ref2_shape 92 | batch_type: folded 93 | valid_batch_type: null 94 | fold_length: 95 | - 80000 96 | - 80000 97 | - 80000 98 | sort_in_batch: descending 99 | shuffle_within_batch: true 100 | sort_batch: descending 101 | multiple_iterator: false 102 | chunk_length: 500 103 | chunk_shift_ratio: 0.5 104 | num_cache_chunks: 1024 105 | chunk_excluded_key_prefixes: [] 106 | chunk_default_fs: null 107 | train_data_path_and_name_and_type: 108 | - - dump/raw/tr_min_8k/wav.scp 109 | - speech_mix 110 | - sound 111 | - - dump/raw/tr_min_8k/spk1.scp 112 | - speech_ref1 113 | - sound 114 | - - dump/raw/tr_min_8k/spk2.scp 115 | - speech_ref2 116 | - sound 117 | valid_data_path_and_name_and_type: 118 | - - dump/raw/cv_min_8k/wav.scp 119 | - speech_mix 120 | - sound 121 | - - dump/raw/cv_min_8k/spk1.scp 122 | - speech_ref1 123 | - sound 124 | - - dump/raw/cv_min_8k/spk2.scp 125 | - speech_ref2 126 | - sound 127 | allow_variable_data_keys: false 128 | max_cache_size: 0.0 129 | max_cache_fd: 32 130 | allow_multi_rates: false 131 | valid_max_cache_size: null 132 | exclude_weight_decay: false 133 | exclude_weight_decay_conf: {} 134 | optim: adamw 135 | optim_conf: 136 | lr: 0.001 137 | eps: 1.0e-08 138 | 
weight_decay: 0.01 139 | scheduler: warmupreducelronplateau 140 | scheduler_conf: 141 | warmup_steps: 4000 142 | mode: min 143 | factor: 0.5 144 | patience: 3 145 | init: xavier_uniform 146 | model_conf: 147 | normalize_variance: true 148 | criterions: 149 | - name: si_snr 150 | conf: 151 | eps: 1.0e-07 152 | wrapper: pit 153 | wrapper_conf: 154 | weight: 1.0 155 | independent_perm: true 156 | speech_volume_normalize: null 157 | rir_scp: null 158 | rir_apply_prob: 1.0 159 | noise_scp: null 160 | noise_apply_prob: 1.0 161 | noise_db_range: '13_15' 162 | short_noise_thres: 0.5 163 | use_reverberant_ref: false 164 | num_spk: 2 165 | num_noise_type: 1 166 | sample_rate: 8000 167 | force_single_channel: false 168 | channel_reordering: false 169 | categories: [] 170 | speech_segment: 32000 171 | avoid_allzero_segment: true 172 | flexible_numspk: false 173 | dynamic_mixing: false 174 | utt2spk: null 175 | dynamic_mixing_gain_db: 0.0 176 | encoder: stft 177 | encoder_conf: 178 | n_fft: 128 179 | hop_length: 64 180 | separator: tflocoformer 181 | separator_conf: 182 | num_spk: 2 183 | n_layers: 6 184 | emb_dim: 128 185 | norm_type: rmsgroupnorm 186 | num_groups: 4 187 | tf_order: ft 188 | n_heads: 4 189 | flash_attention: false 190 | ffn_type: 191 | - swiglu_conv1d 192 | - swiglu_conv1d 193 | ffn_hidden_dim: 194 | - 384 195 | - 384 196 | conv1d_kernel: 4 197 | conv1d_shift: 1 198 | dropout: 0.0 199 | eps: 1.0e-05 200 | decoder: stft 201 | decoder_conf: 202 | n_fft: 128 203 | hop_length: 64 204 | mask_module: multi_mask 205 | mask_module_conf: {} 206 | preprocessor: enh 207 | preprocessor_conf: {} 208 | diffusion_model: null 209 | diffusion_model_conf: {} 210 | required: 211 | - output_dir 212 | version: '202402' 213 | distributed: true 214 | -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/images/loss.png -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/merlresearch/tf-locoformer/b76d38baad9428629cb2fbf23a951e4c76290015/egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/images/si_snr_loss.png -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/exp/enh_train_enh_tflocoformer_raw/valid.loss.ave_5best.pth: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c536fa8499b28e2cf812b1906d899ae6337665232ca523dbe628682dd7df983c 3 | size 59979488 4 | -------------------------------------------------------------------------------- /egs2/wsj0_2mix/enh1/separate.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | import soundfile as sf 10 | from espnet2.bin.enh_inference import SeparateSpeech 11 | 12 | if __name__ == "__main__": 13 | 14 | parser = argparse.ArgumentParser() 15 | 16 | parser.add_argument( 17 | "--model_path", type=Path, required=True, help="Path to pre-trained model parameters (.pth file)." 18 | ) 19 | parser.add_argument("--audio_path", type=Path, required=True, help="Path to the audio file to separate.") 20 | parser.add_argument( 21 | "--audio_output_dir", type=Path, default="./audio_outputs", help="Directory to save the separated audios." 
class TFLocoformerSeparator(AbsSeparator):
    """TF-Locoformer model presented in [1].

    Reference:
        [1] Kohei Saijo, Gordon Wichern, François G. Germain, Zexu Pan, and Jonathan Le Roux,
        "TF-Locoformer: Transformer with Local Modeling by Convolution for Speech Separation
        and Enhancement," in Proc. International Workshop on Acoustic Signal Enhancement (IWAENC),
        Sep. 2024.

    Args:
        input_dim: int
            placeholder, not used
        num_spk: int
            number of output sources/speakers.
        n_layers: int
            number of Locoformer blocks.
        emb_dim: int
            Size of hidden dimension in the encoding Conv2D.
        norm_type: str
            Normalization layer. Must be either "layernorm" or "rmsgroupnorm".
        num_groups: int
            Number of groups in RMSGroupNorm layer.
        tf_order: str
            Order of frequency and temporal modeling. Must be either "ft" or "tf".
        n_heads: int
            Number of heads in multi-head self-attention.
        flash_attention: bool
            Whether to use flash attention. Only compatible with half precision.
        ffn_type: str or list
            Feed-forward network (FFN)-type chosen from "conv1d" or "swiglu_conv1d".
            Giving the list (e.g., ["conv1d", "conv1d"]) makes the model Macaron-style.
        ffn_hidden_dim: int or list
            Number of hidden dimensions in FFN.
            Giving the list (e.g., [256, 256]) makes the model Macaron-style.
        conv1d_kernel: int
            Kernel size in Conv1d.
        conv1d_shift: int
            Shift size of Conv1d kernel.
        dropout: float
            Dropout probability.
        eps: float
            Small constant for normalization layer.
    """

    def __init__(
        self,
        input_dim,
        num_spk: int = 2,
        n_layers: int = 6,
        # general setup
        emb_dim: int = 128,
        # Fixed typo: default was "rmsgrouporm", which fails LocoformerBlock's
        # `assert norm_type in Norm` and contradicts the docstring above.
        norm_type: str = "rmsgroupnorm",
        num_groups: int = 4,  # used only in RMSGroupNorm
        tf_order: str = "ft",
        # self-attention related
        n_heads: int = 4,
        flash_attention: bool = False,  # available when using mixed precision
        attention_dim: int = 128,
        # ffn related
        ffn_type: Union[str, list] = "swiglu_conv1d",
        ffn_hidden_dim: Union[int, list] = 384,
        conv1d_kernel: int = 4,
        conv1d_shift: int = 1,
        dropout: float = 0.0,
        # others
        eps: float = 1.0e-5,
    ):
        super().__init__()
        assert is_torch_2_0_plus, "Support only pytorch >= 2.0.0"

        self._num_spk = num_spk
        self.n_layers = n_layers

        # Conv2D encoder over stacked (real, imag) channels, followed by gLN.
        t_ksize = 3
        ks, padding = (t_ksize, 3), (t_ksize // 2, 1)
        self.conv = nn.Sequential(
            nn.Conv2d(2, emb_dim, ks, padding=padding),
            nn.GroupNorm(1, emb_dim, eps=eps),  # gLN
        )

        assert attention_dim % n_heads == 0, (attention_dim, n_heads)
        # One rotary embedding per axis, shared by all blocks.
        rope_freq = RotaryEmbedding(attention_dim // n_heads)
        rope_time = RotaryEmbedding(attention_dim // n_heads)

        self.blocks = nn.ModuleList([])
        for _ in range(n_layers):
            self.blocks.append(
                TFLocoformerBlock(
                    rope_freq,
                    rope_time,
                    # general setup
                    emb_dim=emb_dim,
                    norm_type=norm_type,
                    num_groups=num_groups,
                    tf_order=tf_order,
                    # self-attention related
                    n_heads=n_heads,
                    flash_attention=flash_attention,
                    attention_dim=attention_dim,
                    # ffn related
                    ffn_type=ffn_type,
                    ffn_hidden_dim=ffn_hidden_dim,
                    conv1d_kernel=conv1d_kernel,
                    conv1d_shift=conv1d_shift,
                    dropout=dropout,
                    eps=eps,
                )
            )

        # Decoder maps back to num_spk * (real, imag) masksless spectra.
        self.deconv = nn.ConvTranspose2d(emb_dim, num_spk * 2, ks, padding=padding)

    def forward(
        self,
        input: torch.Tensor,
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]:
        """Forward.

        Args:
            input (torch.Tensor): batched single-channel audio tensor with
                in TF-domain [B, T, F]
            ilens (torch.Tensor): input lengths [B]
            additional (Dict or None): other data, currently unused in this model.

        Returns:
            enhanced (List[Union(torch.Tensor)]):
                [(B, T), ...] list of len num_spk
                of mono audio tensors with T samples.
            ilens (torch.Tensor): (B,)
            additional (Dict or None): other data, currently unused in this model,
                we return it also in the output.
        """
        if input.ndim == 3:
            # in case the input does not have channel dimension
            batch0 = input.unsqueeze(1)  # [B, 1, T, F]
        elif input.ndim == 4:
            # BUGFIX: the original asserted on `batch0` before assigning it
            # (NameError); assignment must come first.
            batch0 = input.transpose(1, 2)  # [B, M, T, F]
            assert batch0.shape[1] == 1, "Only monaural input is supported."
        else:
            raise ValueError(f"Expected a 3-dim or 4-dim input, got {input.ndim} dims")

        batch = torch.cat((batch0.real, batch0.imag), dim=1)  # [B, 2*M, T, F]
        n_batch, _, n_frames, n_freqs = batch.shape

        # Encoder/decoder convolutions run in full precision even under AMP.
        with torch.cuda.amp.autocast(enabled=False):
            batch = self.conv(batch)  # [B, -1, T, F]

        # separation
        for ii in range(self.n_layers):
            batch = self.blocks[ii](batch)  # [B, -1, T, F]

        with torch.cuda.amp.autocast(enabled=False):
            batch = self.deconv(batch)  # [B, num_spk*2, T, F]
        batch = batch.view([n_batch, self.num_spk, 2, n_frames, n_freqs])

        # Re-assemble complex spectra and split per speaker.
        batch = new_complex_like(batch0, (batch[:, :, 0], batch[:, :, 1]))
        batch = [batch[:, src] for src in range(self.num_spk)]

        return batch, ilens, OrderedDict()

    @property
    def num_spk(self):
        return self._num_spk
class TFLocoformerBlock(nn.Module):
    """One TF-Locoformer block: a frequency-axis and a time-axis LocoformerBlock,
    applied in the order selected by ``tf_order`` ("ft" = frequency first)."""

    def __init__(
        self,
        rope_freq,
        rope_time,
        # general setup
        emb_dim=128,
        # Fixed typo: default was "rmsgrouporm", which LocoformerBlock rejects.
        norm_type="rmsgroupnorm",
        num_groups=4,
        tf_order="ft",
        # self-attention related
        n_heads=4,
        flash_attention=False,
        attention_dim=128,
        # ffn related
        ffn_type="swiglu_conv1d",
        ffn_hidden_dim=384,
        conv1d_kernel=4,
        conv1d_shift=1,
        dropout=0.0,
        eps=1.0e-5,
    ):
        super().__init__()

        assert tf_order in ["tf", "ft"], tf_order
        self.tf_order = tf_order
        self.conv1d_kernel = conv1d_kernel
        self.conv1d_shift = conv1d_shift

        self.freq_path = LocoformerBlock(
            rope_freq,
            # general setup
            emb_dim=emb_dim,
            norm_type=norm_type,
            num_groups=num_groups,
            # self-attention related
            n_heads=n_heads,
            flash_attention=flash_attention,
            attention_dim=attention_dim,
            # ffn related
            ffn_type=ffn_type,
            ffn_hidden_dim=ffn_hidden_dim,
            conv1d_kernel=conv1d_kernel,
            conv1d_shift=conv1d_shift,
            dropout=dropout,
            eps=eps,
        )
        self.frame_path = LocoformerBlock(
            rope_time,
            # general setup
            emb_dim=emb_dim,
            norm_type=norm_type,
            num_groups=num_groups,
            # self-attention related
            n_heads=n_heads,
            flash_attention=flash_attention,
            attention_dim=attention_dim,
            # ffn related
            ffn_type=ffn_type,
            ffn_hidden_dim=ffn_hidden_dim,
            conv1d_kernel=conv1d_kernel,
            conv1d_shift=conv1d_shift,
            dropout=dropout,
            eps=eps,
        )

    def forward(self, input):
        """TF-Locoformer forward.

        input: torch.Tensor
            Input tensor, (n_batch, channel, n_frame, n_freq)
        """

        if self.tf_order == "ft":
            output = self.freq_frame_process(input)
        else:
            output = self.frame_freq_process(input)

        return output

    def freq_frame_process(self, input):
        # (B, H, T, F) -> (B, T, F, H): model the frequency axis first.
        # (Fixed stale shape comment; movedim(1, -1) yields (B, T, F, H).)
        output = input.movedim(1, -1)
        output = self.freq_path(output)

        output = output.transpose(1, 2)  # (B, F, T, H)
        output = self.frame_path(output)
        return output.transpose(-1, 1)  # back to (B, H, T, F)

    def frame_freq_process(self, input):
        # (B, H, T, F) -> (B, F, T, H): model the time axis first.
        output = input.transpose(1, -1)
        output = self.frame_path(output)

        output = output.transpose(1, 2)  # (B, T, F, H)
        output = self.freq_path(output)
        return output.movedim(-1, 1)  # back to (B, H, T, F)
class LocoformerBlock(nn.Module):
    """Locoformer block along one axis: (optional Macaron-style) pre-FFN,
    multi-head self-attention, post-FFN, each with pre-norm and residual."""

    def __init__(
        self,
        rope,
        # general setup
        emb_dim=128,
        # Fixed typo: default was "rmsgrouporm", which fails the `assert norm_type in Norm` below.
        norm_type="rmsgroupnorm",
        num_groups=4,
        # self-attention related
        n_heads=4,
        flash_attention=False,
        attention_dim=128,
        # ffn related
        ffn_type="swiglu_conv1d",
        ffn_hidden_dim=384,
        conv1d_kernel=4,
        conv1d_shift=1,
        dropout=0.0,
        eps=1.0e-5,
    ):
        super().__init__()

        FFN = {
            "conv1d": ConvDeconv1d,
            "swiglu_conv1d": SwiGLUConvDeconv1d,
        }
        Norm = {
            "layernorm": nn.LayerNorm,
            "rmsgroupnorm": RMSGroupNorm,
        }
        assert norm_type in Norm, norm_type

        self.macaron_style = isinstance(ffn_type, list) and len(ffn_type) == 2
        if self.macaron_style:
            assert (
                isinstance(ffn_hidden_dim, list) and len(ffn_hidden_dim) == 2
            ), "Two FFNs required when using Macaron-style model"
        else:
            # BUGFIX: the docstring allows scalar `ffn_type`/`ffn_hidden_dim`
            # (the class defaults are scalars), but the zip below needs lists:
            # a plain str iterates per character and int has no [::-1].
            if not isinstance(ffn_type, list):
                ffn_type = [ffn_type]
            if not isinstance(ffn_hidden_dim, list):
                ffn_hidden_dim = [ffn_hidden_dim]

        # initialize FFN(s); reversed so that index 0 is the post-attention FFN
        # and index -1 (Macaron style only) is the pre-attention FFN
        self.ffn_norm = nn.ModuleList([])
        self.ffn = nn.ModuleList([])
        for f_type, f_dim in zip(ffn_type[::-1], ffn_hidden_dim[::-1]):
            assert f_type in FFN, f_type
            if norm_type == "rmsgroupnorm":
                self.ffn_norm.append(Norm[norm_type](num_groups, emb_dim, eps=eps))
            else:
                self.ffn_norm.append(Norm[norm_type](emb_dim, eps=eps))
            self.ffn.append(
                FFN[f_type](
                    emb_dim,
                    f_dim,
                    conv1d_kernel,
                    conv1d_shift,
                    dropout=dropout,
                )
            )

        # initialize self-attention
        if norm_type == "rmsgroupnorm":
            self.attn_norm = Norm[norm_type](num_groups, emb_dim, eps=eps)
        else:
            self.attn_norm = Norm[norm_type](emb_dim, eps=eps)
        self.attn = MultiHeadSelfAttention(
            emb_dim,
            attention_dim=attention_dim,
            n_heads=n_heads,
            rope=rope,
            dropout=dropout,
            flash_attention=flash_attention,
        )

    def forward(self, x):
        """Locoformer block Forward.

        Args:
            x: torch.Tensor
                Input tensor, (n_batch, seq1, seq2, channel)
                seq1 (or seq2) is either the number of frames or freqs
        """
        B, T, F, C = x.shape

        if self.macaron_style:
            # FFN before self-attention
            # Note that this implementation does not include the 1/2 factor described in the paper.
            # Experiments in the paper did use the 1/2 factor, but we removed it by mistake in this
            # implementation. We found that the 1/2 factor does not impact final performance, and
            # thus decided to keep the current implementation for consistency with the pre-trained
            # models that we provide.
            input_ = x
            output = self.ffn_norm[-1](x)  # [B, T, F, C]
            output = self.ffn[-1](output)  # [B, T, F, C]
            output = output + input_
        else:
            output = x

        # Self-attention over the seq2 axis (frames or frequencies)
        input_ = output
        output = self.attn_norm(output)
        output = output.view([B * T, F, C])
        output = self.attn(output)
        output = output.view([B, T, F, C]) + input_

        # FFN after self-attention
        input_ = output
        output = self.ffn_norm[0](output)  # [B, T, F, C]
        output = self.ffn[0](output)  # [B, T, F, C]
        output = output + input_

        return output
372 | input_ = x 373 | output = self.ffn_norm[-1](x) # [B, T, F, C] 374 | output = self.ffn[-1](output) # [B, T, F, C] 375 | output = output + input_ 376 | else: 377 | output = x 378 | 379 | # Self-attention 380 | input_ = output 381 | output = self.attn_norm(output) 382 | output = output.view([B * T, F, C]) 383 | output = self.attn(output) 384 | output = output.view([B, T, F, C]) + input_ 385 | 386 | # FFN after self-attention 387 | input_ = output 388 | output = self.ffn_norm[0](output) # [B, T, F, C] 389 | output = self.ffn[0](output) # [B, T, F, C] 390 | output = output + input_ 391 | 392 | return output 393 | 394 | 395 | class MultiHeadSelfAttention(nn.Module): 396 | def __init__( 397 | self, 398 | emb_dim, 399 | attention_dim, 400 | n_heads=8, 401 | dropout=0.0, 402 | rope=None, 403 | flash_attention=False, 404 | ): 405 | super().__init__() 406 | 407 | self.n_heads = n_heads 408 | self.dropout = dropout 409 | 410 | self.rope = rope 411 | self.qkv = nn.Linear(emb_dim, attention_dim * 3, bias=False) 412 | self.aggregate_heads = nn.Sequential(nn.Linear(attention_dim, emb_dim, bias=False), nn.Dropout(dropout)) 413 | 414 | if flash_attention: 415 | self.flash_attention_config = dict(enable_flash=True, enable_math=False, enable_mem_efficient=False) 416 | else: 417 | self.flash_attention_config = dict(enable_flash=False, enable_math=True, enable_mem_efficient=True) 418 | 419 | def forward(self, input): 420 | # get query, key, and value 421 | query, key, value = self.get_qkv(input) 422 | 423 | # rotary positional encoding 424 | query, key = self.apply_rope(query, key) 425 | 426 | # pytorch 2.0 flash attention: q, k, v, mask, dropout, softmax_scale 427 | with torch.backends.cuda.sdp_kernel(**self.flash_attention_config): 428 | output = F.scaled_dot_product_attention( 429 | query=query, 430 | key=key, 431 | value=value, 432 | attn_mask=None, 433 | dropout_p=self.dropout if self.training else 0.0, 434 | ) # (batch, head, seq_len, -1) 435 | 436 | output = 
output.transpose(1, 2) # (batch, seq_len, head, -1) 437 | output = output.reshape(output.shape[:2] + (-1,)) 438 | return self.aggregate_heads(output) 439 | 440 | def get_qkv(self, input): 441 | n_batch, seq_len = input.shape[:2] 442 | x = self.qkv(input).reshape(n_batch, seq_len, 3, self.n_heads, -1) 443 | x = x.movedim(-2, 1) # (batch, head, seq_len, 3, -1) 444 | query, key, value = x[..., 0, :], x[..., 1, :], x[..., 2, :] 445 | return query, key, value 446 | 447 | @torch.cuda.amp.autocast(enabled=False) 448 | def apply_rope(self, query, key): 449 | query = self.rope.rotate_queries_or_keys(query) 450 | key = self.rope.rotate_queries_or_keys(key) 451 | return query, key 452 | 453 | 454 | class ConvDeconv1d(nn.Module): 455 | def __init__(self, dim, dim_inner, conv1d_kernel, conv1d_shift, dropout=0.0, **kwargs): 456 | super().__init__() 457 | 458 | self.diff_ks = conv1d_kernel - conv1d_shift 459 | 460 | self.net = nn.Sequential( 461 | nn.Conv1d(dim, dim_inner, conv1d_kernel, stride=conv1d_shift), 462 | nn.SiLU(inplace=True), 463 | nn.Dropout(dropout), 464 | nn.ConvTranspose1d(dim_inner, dim, conv1d_kernel, stride=conv1d_shift), 465 | nn.Dropout(dropout), 466 | ) 467 | 468 | def forward(self, x): 469 | """ConvDeconv1d forward 470 | 471 | Args: 472 | x: torch.Tensor 473 | Input tensor, (n_batch, seq1, seq2, channel) 474 | seq1 (or seq2) is either the number of frames or freqs 475 | """ 476 | b, s1, s2, h = x.shape 477 | x = x.view(b * s1, s2, h) 478 | x = x.transpose(-1, -2) 479 | x = self.net(x).transpose(-1, -2) 480 | x = x[..., self.diff_ks // 2 : self.diff_ks // 2 + s2, :] 481 | return x.view(b, s1, s2, h) 482 | 483 | 484 | class SwiGLUConvDeconv1d(nn.Module): 485 | def __init__(self, dim, dim_inner, conv1d_kernel, conv1d_shift, dropout=0.0, **kwargs): 486 | super().__init__() 487 | 488 | self.conv1d = nn.Conv1d(dim, dim_inner * 2, conv1d_kernel, stride=conv1d_shift) 489 | 490 | self.swish = nn.SiLU() 491 | self.deconv1d = nn.ConvTranspose1d(dim_inner, dim, 
conv1d_kernel, stride=conv1d_shift) 492 | self.dropout = nn.Dropout(dropout) 493 | self.dim_inner = dim_inner 494 | self.diff_ks = conv1d_kernel - conv1d_shift 495 | self.conv1d_kernel = conv1d_kernel 496 | self.conv1d_shift = conv1d_shift 497 | 498 | def forward(self, x): 499 | """SwiGLUConvDeconv1d forward 500 | 501 | Args: 502 | x: torch.Tensor 503 | Input tensor, (n_batch, seq1, seq2, channel) 504 | seq1 (or seq2) is either the number of frames or freqs 505 | """ 506 | b, s1, s2, h = x.shape 507 | x = x.contiguous().view(b * s1, s2, h) 508 | x = x.transpose(-1, -2) 509 | 510 | # padding 511 | seq_len = ( 512 | math.ceil((s2 + 2 * self.diff_ks - self.conv1d_kernel) / self.conv1d_shift) * self.conv1d_shift 513 | + self.conv1d_kernel 514 | ) 515 | x = F.pad(x, (self.diff_ks, seq_len - s2 - self.diff_ks)) 516 | 517 | # conv-deconv1d 518 | x = self.conv1d(x) 519 | gate = self.swish(x[..., self.dim_inner :, :]) 520 | x = x[..., : self.dim_inner, :] * gate 521 | x = self.dropout(x) 522 | x = self.deconv1d(x).transpose(-1, -2) 523 | 524 | # cut necessary part 525 | x = x[..., self.diff_ks : self.diff_ks + s2, :] 526 | return self.dropout(x).view(b, s1, s2, h) 527 | 528 | 529 | class RMSGroupNorm(nn.Module): 530 | def __init__(self, num_groups, dim, eps=1e-8, bias=False): 531 | """ 532 | Root Mean Square Group Normalization (RMSGroupNorm). 533 | Unlike Group Normalization in vision, RMSGroupNorm 534 | is applied to each TF bin. 535 | 536 | Args: 537 | num_groups: int 538 | Number of groups 539 | dim: int 540 | Number of dimensions 541 | eps: float 542 | Small constant to avoid division by zero. 543 | bias: bool 544 | Whether to add a bias term. RMSNorm does not use bias. 
545 | 546 | """ 547 | super().__init__() 548 | 549 | assert dim % num_groups == 0, (dim, num_groups) 550 | self.num_groups = num_groups 551 | self.dim_per_group = dim // self.num_groups 552 | 553 | self.gamma = nn.Parameter(torch.Tensor(dim).to(torch.float32)) 554 | nn.init.ones_(self.gamma) 555 | 556 | self.bias = bias 557 | if self.bias: 558 | self.beta = nn.Parameter(torch.Tensor(dim).to(torch.float32)) 559 | nn.init.zeros_(self.beta) 560 | self.eps = eps 561 | self.num_groups = num_groups 562 | 563 | @torch.cuda.amp.autocast(enabled=False) 564 | def forward(self, input): 565 | others = input.shape[:-1] 566 | input = input.view(others + (self.num_groups, self.dim_per_group)) 567 | 568 | # normalization 569 | norm_ = input.norm(2, dim=-1, keepdim=True) 570 | rms = norm_ * self.dim_per_group ** (-1.0 / 2) 571 | output = input / (rms + self.eps) 572 | 573 | # reshape and affine transformation 574 | output = output.view(others + (-1,)) 575 | output = output * self.gamma 576 | if self.bias: 577 | output = output + self.beta 578 | 579 | return output 580 | -------------------------------------------------------------------------------- /espnet2/tasks/enh.patch: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # Copyright (C) 2017 ESPnet Developers 3 | # 4 | # SPDX-License-Identifier: Apache-2.0 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | 8 | --- ./espnet/espnet2/tasks/enh_org.py 2024-05-30 13:30:00.865662269 -0400 9 | +++ ./espnet/espnet2/tasks/enh.py 2024-05-30 13:30:56.161269346 -0400 10 | @@ -63,6 +63,7 @@ from espnet2.enh.separator.svoice_separa 11 | from espnet2.enh.separator.tcn_separator import TCNSeparator 12 | from espnet2.enh.separator.tfgridnet_separator import TFGridNet 13 | from espnet2.enh.separator.tfgridnetv2_separator import TFGridNetV2 14 | +from espnet2.enh.separator.tflocoformer_separator import TFLocoformerSeparator 15 | from 
espnet2.enh.separator.transformer_separator import TransformerSeparator 16 | from espnet2.enh.separator.uses_separator import USESSeparator 17 | from espnet2.iterators.abs_iter_factory import AbsIterFactory 18 | @@ -112,6 +113,7 @@ separator_choices = ClassChoices( 19 | tfgridnet=TFGridNet, 20 | tfgridnetv2=TFGridNetV2, 21 | uses=USESSeparator, 22 | + tflocoformer=TFLocoformerSeparator, 23 | ), 24 | type_check=AbsSeparator, 25 | default="rnn", 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024 Mitsubishi Electric Research Laboratories (MERL) 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | espnet 6 | rotary-embedding-torch==0.6.1 7 | --------------------------------------------------------------------------------