├── .bumpversion.cfg ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── .pylintrc ├── .travis.yml ├── LICENSE ├── README.rst ├── data ├── AMI_WSJ20-Array1-1_T10c0201.wav ├── AMI_WSJ20-Array1-2_T10c0201.wav ├── AMI_WSJ20-Array1-3_T10c0201.wav ├── AMI_WSJ20-Array1-4_T10c0201.wav ├── AMI_WSJ20-Array1-5_T10c0201.wav ├── AMI_WSJ20-Array1-6_T10c0201.wav ├── AMI_WSJ20-Array1-7_T10c0201.wav └── AMI_WSJ20-Array1-8_T10c0201.wav ├── docs ├── Makefile ├── conf.py ├── index.rst ├── make.bat ├── make_apidoc.sh ├── modules.rst ├── nara_wpe.benchmark_online_wpe.rst ├── nara_wpe.gradient_overrides.rst ├── nara_wpe.rst ├── nara_wpe.test_utils.rst ├── nara_wpe.tf_wpe.rst ├── nara_wpe.utils.rst └── nara_wpe.wpe.rst ├── examples ├── NTT_wrapper_offline.ipynb ├── WPE_Numpy_offline.ipynb ├── WPE_Numpy_online.ipynb ├── WPE_Tensorflow_offline.ipynb ├── WPE_Tensorflow_online.ipynb └── examples.rst ├── maintenance.md ├── nara_wpe ├── __init__.py ├── benchmark_online_wpe.py ├── ntt_wpe.py ├── test_utils.py ├── tf_wpe.py ├── torch_wpe.py ├── utils.py └── wpe.py ├── pyproject.toml ├── pytest.ini ├── setup.py └── tests ├── test_notebooks.py ├── test_tf_wpe.py └── test_wpe.py /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.11 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | [push, pull_request] 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] 13 | os: [ubuntu-latest] 14 | include: 15 | - os: ubuntu-22.04 16 | python-version: '3.7' 17 | - os: macos-latest 18 | python-version: '3.12' 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | 23 | - name: Set 
up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install linux dependencies 29 | run: | 30 | sudo apt-get update 31 | sudo apt-get install libsndfile1 sox 32 | sudo apt-get install libzmq3-dev 33 | pip install numpy pyzmq # pymatbridge needs numpy and pyzmq preinstalled (i.e. does not work to list in setup.py) 34 | # Install of current version of pymatbridge on pypi does not work, using the git-repository instead 35 | pip install git+https://github.com/arokem/python-matlab-bridge.git@master 36 | if: matrix.os != 'macos-latest' 37 | - name: Install macos dependencies 38 | run: | 39 | brew install libsndfile 40 | pip install git+https://github.com/arokem/python-matlab-bridge.git@master 41 | if: matrix.os == 'macos-latest' 42 | - name: Install nara_wpe 43 | run: | 44 | pip install -e .[test] 45 | 46 | - name: Test with pytest 47 | run: | 48 | pytest "tests/" "nara_wpe/" 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.egg-info 3 | *.pyc 4 | .idea/ 5 | docs/_build 6 | docs/_templates 7 | .ipynb_checkpoints/ 8 | cache 9 | .pytest_cache 10 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MESSAGES CONTROL] 2 | 3 | # Enable the message, report, category or checker with the given id(s). You can 4 | # either give multiple identifier separated by comma (,) or put this option 5 | # multiple time. 6 | #enable= 7 | 8 | # Disable the message, report, category or checker with the given id(s). You 9 | # can either give multiple identifier separated by comma (,) or put this option 10 | # multiple time (only on the command line, not in the configuration file where 11 | # it should appear only once). 
12 | # C0103:Invalid variable name 13 | disable=C0103, 14 | 15 | extension-pkg-whitelist=numpy,scipy 16 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 2.7 4 | - 3.5 5 | - 3.6 6 | - 3.7 7 | - 3.8 8 | - 3.9 9 | 10 | cache: pip 11 | 12 | install: 13 | - pip install -e . 14 | - pip install tensorflow==1.12.0 15 | - pip install coverage 16 | - pip install jupyter 17 | - pip install matplotlib 18 | - pip install scipy 19 | 20 | script: 21 | - pytest tests/ 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Communications Engineering Group, Paderborn University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | nara_wpe 3 | ======== 4 | 5 | .. image:: https://readthedocs.org/projects/nara-wpe/badge/?version=latest 6 | :target: http://nara-wpe.readthedocs.io/en/latest/ 7 | :alt: Documentation Status 8 | 9 | .. image:: https://github.com/fgnt/nara_wpe/actions/workflows/tests.yml/badge.svg?branch=master 10 | :target: https://github.com/fgnt/nara_wpe/actions/workflows/tests.yml 11 | :alt: Tests 12 | 13 | .. image:: https://img.shields.io/pypi/v/nara-wpe.svg 14 | :target: https://pypi.org/project/nara-wpe/ 15 | :alt: PyPI 16 | 17 | .. image:: https://img.shields.io/pypi/dm/nara-wpe.svg 18 | :target: https://pypi.org/project/nara-wpe/ 19 | :alt: PyPI 20 | 21 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg 22 | :target: https://raw.githubusercontent.com/fgnt/nara_wpe/master/LICENSE 23 | :alt: MIT License 24 | 25 | Weighted Prediction Error for speech dereverberation 26 | ==================================================== 27 | 28 | Background noise and signal reverberation due to reflections in an enclosure are the two main impairments in acoustic 29 | signal processing and far-field speech recognition. This work addresses signal dereverberation techniques based on WPE for speech recognition and other far-field applications. 30 | WPE is a compelling algorithm to blindly dereverberate acoustic signals based on long-term linear prediction. 31 | 32 | The main algorithm is based on the following paper: 33 | Yoshioka, Takuya, and Tomohiro Nakatani. 
"Generalization of multi-channel linear prediction methods for blind MIMO impulse response shortening." IEEE Transactions on Audio, Speech, and Language Processing 20.10 (2012): 2707-2720. 34 | 35 | 36 | Content 37 | ======= 38 | 39 | - Iterative offline WPE/ block-online WPE/ recursive frame-online WPE 40 | - All algorithms implemented both in Numpy and in TensorFlow (works with version `1.12.0`). 41 | - Continuously tested with Python 3.7 through 3.12. 42 | - Automatically built documentation: `nara-wpe.readthedocs.io <http://nara-wpe.readthedocs.io/en/latest/>`_ 43 | - Modular design to facilitate changes for further research 44 | 45 | Installation 46 | ============ 47 | 48 | Install it directly with Pip, if you just want to use it: 49 | 50 | .. code-block:: bash 51 | 52 | pip install nara_wpe 53 | 54 | If you want to make changes or want the most recent version: Clone the repository and install it as follows: 55 | 56 | .. code-block:: bash 57 | 58 | git clone https://github.com/fgnt/nara_wpe.git 59 | cd nara_wpe 60 | pip install --editable . 61 | 62 | Check the `example notebook `_ for further details. 63 | If you download the example notebook, you can listen to the input audio examples and to the dereverberated output too. 64 | 65 | 66 | Citation 67 | ======== 68 | 69 | To cite this implementation, you can cite the following paper:: 70 | 71 | @InProceedings{Drude2018NaraWPE, 72 | Title = {{NARA-WPE}: A Python package for weighted prediction error dereverberation in {Numpy} and {Tensorflow} for online and offline processing}, 73 | Author = {Drude, Lukas and Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold}, 74 | Booktitle = {13. ITG Fachtagung Sprachkommunikation (ITG 2018)}, 75 | Year = {2018}, 76 | Month = {Oct}, 77 | } 78 | 79 | 80 | To view the paper see 81 | `IEEE Xplore `__ (`PDF `__) 82 | or for a preview see `Paderborn University RIS `__ (`PDF `__). 
83 | 84 | 85 | 86 | Comparison with the NTT WPE implementation 87 | =========================================== 88 | 89 | The fairly recent Johns Hopkins University paper (Manohar, Vimal: `Acoustic Modeling for Overlapping Speech Recognition: JHU CHiME-5 Challenge System `_, ICASSP 2019) reporting on their CHiME 5 challenge results dedicates an entire table to the comparison of the Nara-WPE implementation and the NTT WPE implementation. 90 | Their result is that the Nara-WPE implementation is at least as good as the NTT WPE implementation in all their reported conditions. 91 | 92 | 93 | Development history 94 | ==================== 95 | 96 | Since 2017-09-05 a TensorFlow implementation has been added to `nara_wpe`. 97 | It has been tested with a few test cases against the Numpy implementation. 98 | 99 | The first version of the Numpy implementation was written in June 2017 while 100 | Lukas Drude and Kateřina Žmolíková resided in Nara, Japan. The aim was to have 101 | a publicly available implementation of Takuya Yoshioka's 2012 paper. 
102 | -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-1_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-1_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-2_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-2_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-3_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-3_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-4_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-4_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-5_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-5_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-6_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-6_T10c0201.wav 
-------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-7_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-7_T10c0201.wav -------------------------------------------------------------------------------- /data/AMI_WSJ20-Array1-8_T10c0201.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-8_T10c0201.wav -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
21 | 22 | .PHONY: help 23 | help: 24 | @echo "Please use \`make ' where is one of" 25 | @echo " html to make standalone HTML files" 26 | @echo " dirhtml to make HTML files named index.html in directories" 27 | @echo " singlehtml to make a single large HTML file" 28 | @echo " pickle to make pickle files" 29 | @echo " json to make JSON files" 30 | @echo " htmlhelp to make HTML files and a HTML help project" 31 | @echo " qthelp to make HTML files and a qthelp project" 32 | @echo " applehelp to make an Apple Help Book" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | @echo " coverage to run coverage check of the documentation (if enabled)" 49 | 50 | .PHONY: clean 51 | clean: 52 | rm -rf $(BUILDDIR)/* 53 | 54 | .PHONY: html 55 | html: 56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 57 | @echo 58 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 59 | 60 | .PHONY: dirhtml 61 | dirhtml: 62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 63 | @echo 64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
65 | 66 | .PHONY: singlehtml 67 | singlehtml: 68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 69 | @echo 70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 71 | 72 | .PHONY: pickle 73 | pickle: 74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 75 | @echo 76 | @echo "Build finished; now you can process the pickle files." 77 | 78 | .PHONY: json 79 | json: 80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 81 | @echo 82 | @echo "Build finished; now you can process the JSON files." 83 | 84 | .PHONY: htmlhelp 85 | htmlhelp: 86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 87 | @echo 88 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 89 | ".hhp project file in $(BUILDDIR)/htmlhelp." 90 | 91 | .PHONY: qthelp 92 | qthelp: 93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 94 | @echo 95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/nara_wpe.qhcp" 98 | @echo "To view the help file:" 99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/nara_wpe.qhc" 100 | 101 | .PHONY: applehelp 102 | applehelp: 103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 104 | @echo 105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 106 | @echo "N.B. You won't be able to view it unless you put it in" \ 107 | "~/Library/Documentation/Help or install it in your application" \ 108 | "bundle." 109 | 110 | .PHONY: devhelp 111 | devhelp: 112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 113 | @echo 114 | @echo "Build finished." 
115 | @echo "To view the help file:" 116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/nara_wpe" 117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/nara_wpe" 118 | @echo "# devhelp" 119 | 120 | .PHONY: epub 121 | epub: 122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 123 | @echo 124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 125 | 126 | .PHONY: latex 127 | latex: 128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 129 | @echo 130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 132 | "(use \`make latexpdf' here to do that automatically)." 133 | 134 | .PHONY: latexpdf 135 | latexpdf: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo "Running LaTeX files through pdflatex..." 138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 140 | 141 | .PHONY: latexpdfja 142 | latexpdfja: 143 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 144 | @echo "Running LaTeX files through platex and dvipdfmx..." 145 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 146 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 147 | 148 | .PHONY: text 149 | text: 150 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 151 | @echo 152 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 153 | 154 | .PHONY: man 155 | man: 156 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 157 | @echo 158 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 159 | 160 | .PHONY: texinfo 161 | texinfo: 162 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 163 | @echo 164 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 165 | @echo "Run \`make' in that directory to run these through makeinfo" \ 166 | "(use \`make info' here to do that automatically)." 
167 | 168 | .PHONY: info 169 | info: 170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 171 | @echo "Running Texinfo files through makeinfo..." 172 | make -C $(BUILDDIR)/texinfo info 173 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 174 | 175 | .PHONY: gettext 176 | gettext: 177 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 178 | @echo 179 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 180 | 181 | .PHONY: changes 182 | changes: 183 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 184 | @echo 185 | @echo "The overview file is in $(BUILDDIR)/changes." 186 | 187 | .PHONY: linkcheck 188 | linkcheck: 189 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 190 | @echo 191 | @echo "Link check complete; look for any errors in the above output " \ 192 | "or in $(BUILDDIR)/linkcheck/output.txt." 193 | 194 | .PHONY: doctest 195 | doctest: 196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 197 | @echo "Testing of doctests in the sources finished, look at the " \ 198 | "results in $(BUILDDIR)/doctest/output.txt." 199 | 200 | .PHONY: coverage 201 | coverage: 202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 203 | @echo "Testing of coverage in the sources finished, look at the " \ 204 | "results in $(BUILDDIR)/coverage/python.txt." 205 | 206 | .PHONY: xml 207 | xml: 208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 209 | @echo 210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 211 | 212 | .PHONY: pseudoxml 213 | pseudoxml: 214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 215 | @echo 216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
217 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # nara_wpe documentation build configuration file, created by 4 | # sphinx-quickstart on Thu May 24 11:59:41 2018. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | sys.path.insert(0, os.path.abspath('../nara_wpe')) 19 | 20 | # -- General configuration ------------------------------------------------ 21 | 22 | extensions = [ 23 | 'sphinx.ext.autodoc', 24 | 'sphinx.ext.doctest', 25 | 'sphinx.ext.todo', 26 | 'sphinx.ext.mathjax', 27 | 'sphinx.ext.napoleon', 28 | 'sphinx.ext.viewcode', 29 | ] 30 | 31 | # Add any paths that contain templates here, relative to this directory. 32 | templates_path = ['_templates'] 33 | 34 | # The suffix(es) of source filenames. 35 | source_suffix = '.rst' 36 | 37 | # The encoding of source files. 38 | #source_encoding = 'utf-8-sig' 39 | 40 | # The master toctree document. 41 | master_doc = 'index' 42 | 43 | # General information about the project. 44 | project = u'nara_wpe' 45 | copyright = u'2018, Lukas Drude' 46 | author = u'Lukas Drude' 47 | 48 | # The version info for the project you're documenting, acts as replacement for 49 | # |version| and |release|, also used in various other places throughout the 50 | # built documents. 51 | # 52 | # The short X.Y version. 53 | version = u'0.0.0' 54 | # The full version, including alpha/beta/rc tags. 55 | release = u'0.0.0' 56 | 57 | # The language for content autogenerated by Sphinx. Refer to documentation 58 | # for a list of supported languages. 
59 | # 60 | # This is also used if you do content translation via gettext catalogs. 61 | # Usually you set "language" from the command line for these cases. 62 | language = None 63 | 64 | # There are two options for replacing |today|: either, you set today to some 65 | # non-false value, then it is used: 66 | #today = '' 67 | # Else, today_fmt is used as the format for a strftime call. 68 | #today_fmt = '%B %d, %Y' 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | exclude_patterns = ['_build'] 73 | 74 | autodoc_mock_imports = ['pathlib', 'nara_wpe', 'nara_wpe.benchmark_online_wpe', 75 | 'nara_wpe.gradient_overrides', 'nara_wpe.test_utils', 76 | 'nara_wpe.tf_wpe', 'nara_wpe.utils', 'nara_wpe.wpe', 77 | 'numpy', 'pandas', 'tensorflow', ] 78 | 79 | 80 | # The reST default role (used for this markup: `text`) to use for all 81 | # documents. 82 | #default_role = None 83 | 84 | # If true, '()' will be appended to :func: etc. cross-reference text. 85 | #add_function_parentheses = True 86 | 87 | # If true, the current module name will be prepended to all description 88 | # unit titles (such as .. function::). 89 | #add_module_names = True 90 | 91 | # If true, sectionauthor and moduleauthor directives will be shown in the 92 | # output. They are ignored by default. 93 | #show_authors = False 94 | 95 | # The name of the Pygments (syntax highlighting) style to use. 96 | pygments_style = 'sphinx' 97 | 98 | # A list of ignored prefixes for module index sorting. 99 | #modindex_common_prefix = [] 100 | 101 | # If true, keep warnings as "system message" paragraphs in the built documents. 102 | #keep_warnings = False 103 | 104 | # If true, `todo` and `todoList` produce output, else they produce nothing. 105 | todo_include_todos = True 106 | 107 | 108 | # -- Options for HTML output ---------------------------------------------- 109 | 110 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 111 | # a list of builtin themes. 112 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 113 | 114 | if not on_rtd: # only import and set the theme if we're building docs locally 115 | import sphinx_rtd_theme 116 | html_theme = 'sphinx_rtd_theme' 117 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 118 | 119 | 120 | # Theme options are theme-specific and customize the look and feel of a theme 121 | # further. For a list of options available for each theme, see the 122 | # documentation. 123 | #html_theme_options = {} 124 | 125 | # Add any paths that contain custom themes here, relative to this directory. 126 | #html_theme_path = [] 127 | 128 | # The name for this set of Sphinx documents. If None, it defaults to 129 | # " v documentation". 130 | #html_title = None 131 | 132 | # A shorter title for the navigation bar. Default is the same as html_title. 133 | #html_short_title = None 134 | 135 | # The name of an image file (relative to this directory) to place at the top 136 | # of the sidebar. 137 | #html_logo = None 138 | 139 | # The name of an image file (relative to this directory) to use as a favicon of 140 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 141 | # pixels large. 142 | #html_favicon = None 143 | 144 | # Add any paths that contain custom static files (such as style sheets) here, 145 | # relative to this directory. They are copied after the builtin static files, 146 | # so a file named "default.css" will overwrite the builtin "default.css". 147 | html_static_path = ['_static'] 148 | 149 | # Add any extra paths that contain custom files (such as robots.txt or 150 | # .htaccess) here, relative to this directory. These files are copied 151 | # directly to the root of the documentation. 152 | #html_extra_path = [] 153 | 154 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 155 | # using the given strftime format. 
156 | #html_last_updated_fmt = '%b %d, %Y' 157 | 158 | # If true, SmartyPants will be used to convert quotes and dashes to 159 | # typographically correct entities. 160 | #html_use_smartypants = True 161 | 162 | # Custom sidebar templates, maps document names to template names. 163 | #html_sidebars = {} 164 | 165 | # Additional templates that should be rendered to pages, maps page names to 166 | # template names. 167 | #html_additional_pages = {} 168 | 169 | # If false, no module index is generated. 170 | #html_domain_indices = True 171 | 172 | # If false, no index is generated. 173 | #html_use_index = True 174 | 175 | # If true, the index is split into individual pages for each letter. 176 | #html_split_index = False 177 | 178 | # If true, links to the reST sources are added to the pages. 179 | #html_show_sourcelink = True 180 | 181 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 182 | #html_show_sphinx = True 183 | 184 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 185 | #html_show_copyright = True 186 | 187 | # If true, an OpenSearch description file will be output, and all pages will 188 | # contain a tag referring to it. The value of this option must be the 189 | # base URL from which the finished HTML is served. 190 | #html_use_opensearch = '' 191 | 192 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 193 | #html_file_suffix = None 194 | 195 | # Language to be used for generating the HTML full-text search index. 196 | # Sphinx supports the following languages: 197 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 198 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' 199 | #html_search_language = 'en' 200 | 201 | # A dictionary with options for the search language support, empty by default. 
202 | # Now only 'ja' uses this config value 203 | #html_search_options = {'type': 'default'} 204 | 205 | # The name of a javascript file (relative to the configuration directory) that 206 | # implements a search results scorer. If empty, the default will be used. 207 | #html_search_scorer = 'scorer.js' 208 | 209 | # Output file base name for HTML help builder. 210 | htmlhelp_basename = 'nara_wpedoc' 211 | 212 | # -- Options for LaTeX output --------------------------------------------- 213 | 214 | latex_elements = { 215 | # The paper size ('letterpaper' or 'a4paper'). 216 | #'papersize': 'letterpaper', 217 | 218 | # The font size ('10pt', '11pt' or '12pt'). 219 | #'pointsize': '10pt', 220 | 221 | # Additional stuff for the LaTeX preamble. 222 | #'preamble': '', 223 | 224 | # Latex figure (float) alignment 225 | #'figure_align': 'htbp', 226 | } 227 | 228 | # Grouping the document tree into LaTeX files. List of tuples 229 | # (source start file, target name, title, 230 | # author, documentclass [howto, manual, or own class]). 231 | latex_documents = [ 232 | (master_doc, 'nara_wpe.tex', u'nara\\_wpe Documentation', 233 | u'Lukas Drude', 'manual'), 234 | ] 235 | 236 | # The name of an image file (relative to this directory) to place at the top of 237 | # the title page. 238 | #latex_logo = None 239 | 240 | # For "manual" documents, if this is true, then toplevel headings are parts, 241 | # not chapters. 242 | #latex_use_parts = False 243 | 244 | # If true, show page references after internal links. 245 | #latex_show_pagerefs = False 246 | 247 | # If true, show URL addresses after external links. 248 | #latex_show_urls = False 249 | 250 | # Documents to append as an appendix to all manuals. 251 | #latex_appendices = [] 252 | 253 | # If false, no module index is generated. 254 | #latex_domain_indices = True 255 | 256 | 257 | # -- Options for manual page output --------------------------------------- 258 | 259 | # One entry per manual page. 
List of tuples 260 | # (source start file, name, description, authors, manual section). 261 | man_pages = [ 262 | (master_doc, 'nara_wpe', u'nara_wpe Documentation', 263 | [author], 1) 264 | ] 265 | 266 | # If true, show URL addresses after external links. 267 | #man_show_urls = False 268 | 269 | 270 | # -- Options for Texinfo output ------------------------------------------- 271 | 272 | # Grouping the document tree into Texinfo files. List of tuples 273 | # (source start file, target name, title, author, 274 | # dir menu entry, description, category) 275 | texinfo_documents = [ 276 | (master_doc, 'nara_wpe', u'nara_wpe Documentation', 277 | author, 'nara_wpe', 'One line description of project.', 278 | 'Miscellaneous'), 279 | ] 280 | 281 | # Documents to append as an appendix to all manuals. 282 | #texinfo_appendices = [] 283 | 284 | # If false, no module index is generated. 285 | #texinfo_domain_indices = True 286 | 287 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 288 | #texinfo_show_urls = 'footnote' 289 | 290 | # If true, do not generate a @detailmenu in the "Top" node's menu. 291 | #texinfo_no_detailmenu = False 292 | 293 | 294 | # Example configuration for intersphinx: refer to the Python standard library. 295 | intersphinx_mapping = {'https://docs.python.org/': None} 296 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. nara_wpe documentation master file, created by 2 | sphinx-quickstart on Thu May 24 11:59:41 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: ../README.rst 7 | 8 | Welcome to nara_wpe's documentation! 9 | ==================================== 10 | 11 | Table of contents: 12 | 13 | .. 
toctree:: 14 | :maxdepth: 2 15 | 16 | nara_wpe 17 | 18 | 19 | 20 | Indices and tables 21 | ================== 22 | 23 | * :ref:`genindex` 24 | * :ref:`modindex` 25 | * :ref:`search` 26 | 27 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | echo. 
coverage to run coverage check of the documentation if enabled 41 | goto end 42 | ) 43 | 44 | if "%1" == "clean" ( 45 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 46 | del /q /s %BUILDDIR%\* 47 | goto end 48 | ) 49 | 50 | 51 | REM Check if sphinx-build is available and fallback to Python version if any 52 | %SPHINXBUILD% 1>NUL 2>NUL 53 | if errorlevel 9009 goto sphinx_python 54 | goto sphinx_ok 55 | 56 | :sphinx_python 57 | 58 | set SPHINXBUILD=python -m sphinx.__init__ 59 | %SPHINXBUILD% 2> nul 60 | if errorlevel 9009 ( 61 | echo. 62 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 63 | echo.installed, then set the SPHINXBUILD environment variable to point 64 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 65 | echo.may add the Sphinx directory to PATH. 66 | echo. 67 | echo.If you don't have Sphinx installed, grab it from 68 | echo.http://sphinx-doc.org/ 69 | exit /b 1 70 | ) 71 | 72 | :sphinx_ok 73 | 74 | 75 | if "%1" == "html" ( 76 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 77 | if errorlevel 1 exit /b 1 78 | echo. 79 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 80 | goto end 81 | ) 82 | 83 | if "%1" == "dirhtml" ( 84 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 85 | if errorlevel 1 exit /b 1 86 | echo. 87 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 88 | goto end 89 | ) 90 | 91 | if "%1" == "singlehtml" ( 92 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 93 | if errorlevel 1 exit /b 1 94 | echo. 95 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 96 | goto end 97 | ) 98 | 99 | if "%1" == "pickle" ( 100 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 101 | if errorlevel 1 exit /b 1 102 | echo. 103 | echo.Build finished; now you can process the pickle files. 
104 | goto end 105 | ) 106 | 107 | if "%1" == "json" ( 108 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 109 | if errorlevel 1 exit /b 1 110 | echo. 111 | echo.Build finished; now you can process the JSON files. 112 | goto end 113 | ) 114 | 115 | if "%1" == "htmlhelp" ( 116 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 117 | if errorlevel 1 exit /b 1 118 | echo. 119 | echo.Build finished; now you can run HTML Help Workshop with the ^ 120 | .hhp project file in %BUILDDIR%/htmlhelp. 121 | goto end 122 | ) 123 | 124 | if "%1" == "qthelp" ( 125 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 129 | .qhcp project file in %BUILDDIR%/qthelp, like this: 130 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\nara_wpe.qhcp 131 | echo.To view the help file: 132 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\nara_wpe.ghc 133 | goto end 134 | ) 135 | 136 | if "%1" == "devhelp" ( 137 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 138 | if errorlevel 1 exit /b 1 139 | echo. 140 | echo.Build finished. 141 | goto end 142 | ) 143 | 144 | if "%1" == "epub" ( 145 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 146 | if errorlevel 1 exit /b 1 147 | echo. 148 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 149 | goto end 150 | ) 151 | 152 | if "%1" == "latex" ( 153 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 154 | if errorlevel 1 exit /b 1 155 | echo. 156 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 157 | goto end 158 | ) 159 | 160 | if "%1" == "latexpdf" ( 161 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 162 | cd %BUILDDIR%/latex 163 | make all-pdf 164 | cd %~dp0 165 | echo. 166 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 
167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdfja" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf-ja 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "text" ( 181 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 182 | if errorlevel 1 exit /b 1 183 | echo. 184 | echo.Build finished. The text files are in %BUILDDIR%/text. 185 | goto end 186 | ) 187 | 188 | if "%1" == "man" ( 189 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 190 | if errorlevel 1 exit /b 1 191 | echo. 192 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 193 | goto end 194 | ) 195 | 196 | if "%1" == "texinfo" ( 197 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 198 | if errorlevel 1 exit /b 1 199 | echo. 200 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 201 | goto end 202 | ) 203 | 204 | if "%1" == "gettext" ( 205 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 206 | if errorlevel 1 exit /b 1 207 | echo. 208 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 209 | goto end 210 | ) 211 | 212 | if "%1" == "changes" ( 213 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 214 | if errorlevel 1 exit /b 1 215 | echo. 216 | echo.The overview file is in %BUILDDIR%/changes. 217 | goto end 218 | ) 219 | 220 | if "%1" == "linkcheck" ( 221 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 222 | if errorlevel 1 exit /b 1 223 | echo. 224 | echo.Link check complete; look for any errors in the above output ^ 225 | or in %BUILDDIR%/linkcheck/output.txt. 226 | goto end 227 | ) 228 | 229 | if "%1" == "doctest" ( 230 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 231 | if errorlevel 1 exit /b 1 232 | echo. 233 | echo.Testing of doctests in the sources finished, look at the ^ 234 | results in %BUILDDIR%/doctest/output.txt. 
235 | goto end 236 | ) 237 | 238 | if "%1" == "coverage" ( 239 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 240 | if errorlevel 1 exit /b 1 241 | echo. 242 | echo.Testing of coverage in the sources finished, look at the ^ 243 | results in %BUILDDIR%/coverage/python.txt. 244 | goto end 245 | ) 246 | 247 | if "%1" == "xml" ( 248 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 249 | if errorlevel 1 exit /b 1 250 | echo. 251 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 252 | goto end 253 | ) 254 | 255 | if "%1" == "pseudoxml" ( 256 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 257 | if errorlevel 1 exit /b 1 258 | echo. 259 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 260 | goto end 261 | ) 262 | 263 | :end 264 | -------------------------------------------------------------------------------- /docs/make_apidoc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sphinx-apidoc --no-toc -e -f -o . ../nara_wpe -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :maxdepth: 4 3 | 4 | nara_wpe -------------------------------------------------------------------------------- /docs/nara_wpe.benchmark_online_wpe.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.benchmark\_online\_wpe module 2 | ======================================= 3 | 4 | .. automodule:: benchmark_online_wpe 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/nara_wpe.gradient_overrides.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.gradient\_overrides module 2 | ==================================== 3 | 4 | .. 
automodule:: gradient_overrides 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/nara_wpe.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe package 2 | ================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | 9 | nara_wpe.benchmark_online_wpe 10 | nara_wpe.gradient_overrides 11 | nara_wpe.test_utils 12 | nara_wpe.tf_wpe 13 | nara_wpe.utils 14 | nara_wpe.wpe 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: nara_wpe 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/nara_wpe.test_utils.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.test\_utils module 2 | ============================ 3 | 4 | .. automodule:: test_utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/nara_wpe.tf_wpe.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.tf\_wpe module 2 | ======================== 3 | 4 | .. automodule:: tf_wpe 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/nara_wpe.utils.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.utils module 2 | ====================== 3 | 4 | .. automodule:: utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/nara_wpe.wpe.rst: -------------------------------------------------------------------------------- 1 | nara\_wpe.wpe module 2 | ==================== 3 | 4 | .. 
automodule:: wpe 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /examples/NTT_wrapper_offline.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import IPython\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import soundfile as sf\n", 17 | "import time\n", 18 | "\n", 19 | "from nara_wpe.ntt_wpe import ntt_wrapper as wpe\n", 20 | "from nara_wpe import project_root\n", 21 | "from nara_wpe.utils import stft, istft" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "# Minimal example with random data" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "def aquire_audio_data():\n", 38 | " D, T = 4, 10000\n", 39 | " y = np.random.normal(size=(D, T))\n", 40 | " return y" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "y = aquire_audio_data()\n", 50 | "\n", 51 | "start = time.perf_counter()\n", 52 | "x = wpe(y)\n", 53 | "end = time.perf_counter()\n", 54 | "\n", 55 | "print(f\"Time: {end-start}\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "# Example with real audio recordings" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "WPE estimates a filter to predict the current reverberation tail frame from K time frames which lie 3 (delay) time frames in the past. 
This frame (reverberation tail) is then subtracted from the observed signal.\n", 70 | "\n", 71 | "### Setup" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "channels = 8\n", 81 | "sampling_rate = 16000\n", 82 | "delay = 3\n", 83 | "iterations = 5\n", 84 | "taps = 10" 85 | ] 86 | }, 87 | { 88 | "cell_type": "markdown", 89 | "metadata": {}, 90 | "source": [ 91 | "### Audio data\n", 92 | "Shape: (channels, frames)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n", 102 | "signal_list = [\n", 103 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n", 104 | " for d in range(channels)\n", 105 | "]\n", 106 | "y = np.stack(signal_list, axis=0)\n", 107 | "IPython.display.Audio(y[0], rate=sampling_rate)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "### iterative WPE\n", 115 | "The wpe function is fed with y. The STFT and ISTFT are included in the Matlab package of NTT. 
" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "x = wpe(y, iterations=iterations)\n", 125 | "IPython.display.Audio(x[0], rate=sampling_rate)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "## Power spectrum \n", 133 | "Before and after applying NTT WPE" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "stft_options = dict(\n", 143 | " size=512,\n", 144 | " shift=128,\n", 145 | " window_length=None,\n", 146 | " fading=True,\n", 147 | " pad=True,\n", 148 | " symmetric_window=False\n", 149 | ")\n", 150 | "Y = stft(y, **stft_options).transpose(2, 0, 1)\n", 151 | "X = stft(x, **stft_options).transpose(2, 0, 1)\n", 152 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))\n", 153 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[ :, 0, 200:400])), origin='lower')\n", 154 | "ax1.set_xlabel('frames')\n", 155 | "_ = ax1.set_title('reverberated')\n", 156 | "im2 = ax2.imshow(20 * np.log10(np.abs(X[ :, 0, 200:400])), origin='lower', vmin=-120, vmax=0)\n", 157 | "ax2.set_xlabel('frames')\n", 158 | "_ = ax2.set_title('dereverberated')\n", 159 | "cb = fig.colorbar(im2)" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "py36", 166 | "language": "python", 167 | "name": "py36" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.6.6" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 2 184 | } 185 | -------------------------------------------------------------------------------- /examples/WPE_Numpy_offline.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import IPython\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import soundfile as sf\n", 17 | "from tqdm import tqdm\n", 18 | "\n", 19 | "from nara_wpe.wpe import wpe\n", 20 | "from nara_wpe.wpe import get_power\n", 21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n", 22 | "from nara_wpe import project_root" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "stft_options = dict(size=512, shift=128)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Minimal example with random data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "def aquire_audio_data():\n", 48 | " D, T = 4, 10000\n", 49 | " y = np.random.normal(size=(D, T))\n", 50 | " return y" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "y = aquire_audio_data()\n", 60 | "Y = stft(y, **stft_options)\n", 61 | "Y = Y.transpose(2, 0, 1)\n", 62 | "\n", 63 | "Z = wpe(Y)\n", 64 | "z_np = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# Example with real audio recordings" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "WPE estimates a filter to predict the current reverberation tail frame from K time frames which lie 3 (delay) time frames in the 
past. This frame (reverberation tail) is then subtracted from the observed signal.\n", 79 | "\n", 80 | "### Setup" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "channels = 8\n", 90 | "sampling_rate = 16000\n", 91 | "delay = 3\n", 92 | "iterations = 5\n", 93 | "taps = 10\n", 94 | "alpha=0.9999" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "### Audio data\n", 102 | "Shape: (channels, frames)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n", 112 | "signal_list = [\n", 113 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n", 114 | " for d in range(channels)\n", 115 | "]\n", 116 | "y = np.stack(signal_list, axis=0)\n", 117 | "IPython.display.Audio(y[0], rate=sampling_rate)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### STFT\n", 125 | "A STFT is performed to obtain a Numpy array with shape (frequency bins, channels, frames)." 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "Y = stft(y, **stft_options).transpose(2, 0, 1)" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": {}, 140 | "source": [ 141 | "### Iterative WPE\n", 142 | "The wpe function is fed with Y. Finally, an inverse STFT is performed to obtain a dereverberated result in time domain. 
" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "Z = wpe(\n", 152 | " Y,\n", 153 | " taps=taps,\n", 154 | " delay=delay,\n", 155 | " iterations=iterations,\n", 156 | " statistics_mode='full'\n", 157 | ").transpose(1, 2, 0)\n", 158 | "z = istft(Z, size=stft_options['size'], shift=stft_options['shift'])\n", 159 | "IPython.display.Audio(z[0], rate=sampling_rate)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "## Power spectrum \n", 167 | "Before and after applying WPE" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))\n", 177 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[ :, 0, 200:400])), origin='lower')\n", 178 | "ax1.set_xlabel('frames')\n", 179 | "_ = ax1.set_title('reverberated')\n", 180 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z[0, 200:400, :])).T, origin='lower', vmin=-120, vmax=0)\n", 181 | "ax2.set_xlabel('frames')\n", 182 | "_ = ax2.set_title('dereverberated')\n", 183 | "cb = fig.colorbar(im2)" 184 | ] 185 | } 186 | ], 187 | "metadata": { 188 | "kernelspec": { 189 | "display_name": "Python 3", 190 | "language": "python", 191 | "name": "python3" 192 | }, 193 | "language_info": { 194 | "codemirror_mode": { 195 | "name": "ipython", 196 | "version": 3 197 | }, 198 | "file_extension": ".py", 199 | "mimetype": "text/x-python", 200 | "name": "python", 201 | "nbconvert_exporter": "python", 202 | "pygments_lexer": "ipython3", 203 | "version": "3.6.6" 204 | } 205 | }, 206 | "nbformat": 4, 207 | "nbformat_minor": 2 208 | } 209 | -------------------------------------------------------------------------------- /examples/WPE_Numpy_online.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | 
"cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import IPython\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import soundfile as sf\n", 17 | "import time\n", 18 | "from tqdm import tqdm\n", 19 | "\n", 20 | "from nara_wpe.wpe import online_wpe_step, get_power_online, OnlineWPE\n", 21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n", 22 | "from nara_wpe import project_root" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "stft_options = dict(size=512, shift=128)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Example with real audio recordings\n", 39 | "The iterations are dropped in contrast to the offline version. To use past observations the correlation matrix and the correlation vector are calculated recursively with a decaying window. $\\alpha$ is the decay factor." 
40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### Setup" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "channels = 8\n", 56 | "sampling_rate = 16000\n", 57 | "delay = 3\n", 58 | "alpha=0.9999\n", 59 | "taps = 10\n", 60 | "frequency_bins = stft_options['size'] // 2 + 1" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Audio data" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": null, 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n", 77 | "signal_list = [\n", 78 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n", 79 | " for d in range(channels)\n", 80 | "]\n", 81 | "y = np.stack(signal_list, axis=0)\n", 82 | "IPython.display.Audio(y[0], rate=sampling_rate)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Online buffer\n", 90 | "For simplicity the STFT is performed before providing the frames.\n", 91 | "\n", 92 | "Shape: (frames, frequency bins, channels)\n", 93 | "\n", 94 | "frames: K+delay+1" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "Y = stft(y, **stft_options).transpose(1, 2, 0)\n", 104 | "T, _, _ = Y.shape\n", 105 | "\n", 106 | "def aquire_framebuffer():\n", 107 | " buffer = list(Y[:taps+delay, :, :])\n", 108 | " for t in range(taps+delay+1, T):\n", 109 | " buffer.append(Y[t, :, :])\n", 110 | " yield np.array(buffer)\n", 111 | " buffer.pop(0)" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "### Non-iterative frame online approach\n", 119 | "A frame online example requires, that certain state variables are kept from frame to frame. 
These are the inverse correlation matrix $\\text{R}_{t, f}^{-1}$, which is stored in Q and initialized with an identity matrix, and the filter coefficient matrix, which is stored in G and initialized with zeros. \n", 120 | "\n", 121 | "Again, for simplicity, the ISTFT is applied afterwards." 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": null, 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "Z_list = []\n", 131 | "Q = np.stack([np.identity(channels * taps) for a in range(frequency_bins)])\n", 132 | "G = np.zeros((frequency_bins, channels * taps, channels))\n", 133 | "\n", 134 | "for Y_step in tqdm(aquire_framebuffer()):\n", 135 | " Z, Q, G = online_wpe_step(\n", 136 | " Y_step,\n", 137 | " get_power_online(Y_step.transpose(1, 2, 0)),\n", 138 | " Q,\n", 139 | " G,\n", 140 | " alpha=alpha,\n", 141 | " taps=taps,\n", 142 | " delay=delay\n", 143 | " )\n", 144 | " Z_list.append(Z)\n", 145 | "\n", 146 | "Z_stacked = np.stack(Z_list)\n", 147 | "z = istft(np.asarray(Z_stacked).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n", 148 | "\n", 149 | "IPython.display.Audio(z[0], rate=sampling_rate)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "## Frame online WPE in class fashion:" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "The OnlineWPE class holds the correlation matrix and the coefficient matrix. 
" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "Z_list = []\n", 173 | "online_wpe = OnlineWPE(\n", 174 | " taps=taps,\n", 175 | " delay=delay,\n", 176 | " alpha=alpha\n", 177 | ")\n", 178 | "for Y_step in tqdm(aquire_framebuffer()):\n", 179 | " Z_list.append(online_wpe.step_frame(Y_step))\n", 180 | "\n", 181 | "Z = np.stack(Z_list)\n", 182 | "z = istft(np.asarray(Z).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n", 183 | "\n", 184 | "IPython.display.Audio(z[0], rate=sampling_rate)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "# Power spectrum\n", 192 | "Before and after applying WPE." 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n", 202 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[200:400, :, 0])).T, origin='lower')\n", 203 | "ax1.set_xlabel('')\n", 204 | "_ = ax1.set_title('reverberated')\n", 205 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z_stacked[200:400, :, 0])).T, origin='lower')\n", 206 | "_ = ax2.set_title('dereverberated')\n", 207 | "cb = fig.colorbar(im1)" 208 | ] 209 | } 210 | ], 211 | "metadata": { 212 | "kernelspec": { 213 | "display_name": "Python 3", 214 | "language": "python", 215 | "name": "python3" 216 | }, 217 | "language_info": { 218 | "codemirror_mode": { 219 | "name": "ipython", 220 | "version": 3 221 | }, 222 | "file_extension": ".py", 223 | "mimetype": "text/x-python", 224 | "name": "python", 225 | "nbconvert_exporter": "python", 226 | "pygments_lexer": "ipython3", 227 | "version": "3.6.6" 228 | } 229 | }, 230 | "nbformat": 4, 231 | "nbformat_minor": 2 232 | } 233 | -------------------------------------------------------------------------------- /examples/WPE_Tensorflow_offline.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import IPython\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import soundfile as sf\n", 17 | "from tqdm import tqdm\n", 18 | "import tensorflow as tf\n", 19 | "\n", 20 | "from nara_wpe.tf_wpe import wpe\n", 21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n", 22 | "from nara_wpe import project_root" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "stft_options = dict(size=512, shift=128)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "# Minimal example with random data" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "def aquire_audio_data():\n", 48 | " D, T = 4, 10000\n", 49 | " y = np.random.normal(size=(D, T))\n", 50 | " return y" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "metadata": { 57 | "scrolled": false 58 | }, 59 | "outputs": [], 60 | "source": [ 61 | "y = aquire_audio_data()\n", 62 | "Y = stft(y, **stft_options).transpose(2, 0, 1)\n", 63 | "with tf.Session() as session:\n", 64 | " Y_tf = tf.placeholder(\n", 65 | " tf.complex128, shape=(None, None, None))\n", 66 | " Z_tf = wpe(Y_tf)\n", 67 | " Z = session.run(Z_tf, {Y_tf: Y})\n", 68 | "z_tf = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "# Example with real audio recordings\n", 76 | "WPE estimates a filter to predict the current 
reverberation tail frame from K time frames which lie 3 (delay) time frames in the past. This frame (reverberation tail) is then subtracted from the observed signal." 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### Setup" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "channels = 8\n", 93 | "sampling_rate = 16000\n", 94 | "delay = 3\n", 95 | "iterations = 5\n", 96 | "taps = 10" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Audio data\n", 104 | "Shape: (channels, frames)" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n", 114 | "signal_list = [\n", 115 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n", 116 | " for d in range(channels)\n", 117 | "]\n", 118 | "y = np.stack(signal_list, axis=0)\n", 119 | "IPython.display.Audio(y[0], rate=sampling_rate)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "### STFT\n", 127 | "For simplicity, we calculate the STFT in Numpy and provide the result via the Tensorflow feed dict." 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": null, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "Y = stft(y, **stft_options).transpose(2, 0, 1)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "metadata": {}, 142 | "source": [ 143 | "### iterative WPE\n", 144 | "A placeholder for Y is declared. The wpe function is fed with Y via the Tensorflow feed dict. Finally, an inverse STFT in Numpy is performed to obtain a dereverberated result in time domain."
145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": null, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "from nara_wpe.tf_wpe import get_power\n", 154 | "with tf.Session()as session:\n", 155 | " Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))\n", 156 | " Z_tf = wpe(Y_tf, taps=taps, iterations=iterations)\n", 157 | " Z = session.run(Z_tf, {Y_tf: Y})\n", 158 | "z = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])\n", 159 | "IPython.display.Audio(z[0], rate=sampling_rate)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "# Power spectrum\n", 167 | "Before and after applying WPE" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n", 177 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[:, 0, 200:400])), origin='lower')\n", 178 | "ax1.set_xlabel('')\n", 179 | "_ = ax1.set_title('reverberated')\n", 180 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z[:, 0, 200:400])), origin='lower')\n", 181 | "_ = ax2.set_title('dereverberated')\n", 182 | "cb = fig.colorbar(im1)" 183 | ] 184 | } 185 | ], 186 | "metadata": { 187 | "kernelspec": { 188 | "display_name": "Python 3", 189 | "language": "python", 190 | "name": "python3" 191 | }, 192 | "language_info": { 193 | "codemirror_mode": { 194 | "name": "ipython", 195 | "version": 3 196 | }, 197 | "file_extension": ".py", 198 | "mimetype": "text/x-python", 199 | "name": "python", 200 | "nbconvert_exporter": "python", 201 | "pygments_lexer": "ipython3", 202 | "version": "3.6.6" 203 | } 204 | }, 205 | "nbformat": 4, 206 | "nbformat_minor": 2 207 | } 208 | -------------------------------------------------------------------------------- /examples/WPE_Tensorflow_online.ipynb: -------------------------------------------------------------------------------- 
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2\n", 11 | "%matplotlib inline\n", 12 | "\n", 13 | "import IPython\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import numpy as np\n", 16 | "import soundfile as sf\n", 17 | "import time\n", 18 | "from tqdm import tqdm\n", 19 | "import tensorflow as tf\n", 20 | "\n", 21 | "from nara_wpe.tf_wpe import wpe\n", 22 | "from nara_wpe.tf_wpe import online_wpe_step, get_power_online\n", 23 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n", 24 | "from nara_wpe import project_root" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "stft_options = dict(\n", 34 | " size=512,\n", 35 | " shift=128,\n", 36 | " window_length=None,\n", 37 | " fading=True,\n", 38 | " pad=True,\n", 39 | " symmetric_window=False\n", 40 | ")" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "# Example with real audio recordings\n", 48 | "The iterations are dropped in contrast to the offline version. To use past observations the correlation matrix and the correlation vector are calculated recursively with a decaying window. $\\alpha$ is the decay factor." 
49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "### Setup" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "channels = 8\n", 65 | "sampling_rate = 16000\n", 66 | "delay = 3\n", 67 | "alpha=0.99\n", 68 | "taps = 10\n", 69 | "frequency_bins = stft_options['size'] // 2 + 1" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "metadata": {}, 75 | "source": [ 76 | "### Audio data" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n", 86 | "signal_list = [\n", 87 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n", 88 | " for d in range(channels)\n", 89 | "]\n", 90 | "y = np.stack(signal_list, axis=0)\n", 91 | "IPython.display.Audio(y[0], rate=sampling_rate)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "markdown", 96 | "metadata": {}, 97 | "source": [ 98 | "### Online buffer\n", 99 | "For simplicity the STFT is performed before providing the frames.\n", 100 | "\n", 101 | "Shape: (frames, frequency bins, channels)\n", 102 | "\n", 103 | "frames: K+delay+1" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "Y = stft(y, **stft_options).transpose(1, 2, 0)\n", 113 | "T, _, _ = Y.shape\n", 114 | "\n", 115 | "def aquire_framebuffer():\n", 116 | " buffer = list(Y[:taps+delay, :, :])\n", 117 | " for t in range(taps+delay+1, T):\n", 118 | " buffer.append(Y[t, :, :])\n", 119 | " yield np.array(buffer)\n", 120 | " buffer.pop(0)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "### Non-iterative frame online approach\n", 128 | "A frame online example requires, that certain state variables are kept from frame to frame. 
That is the inverse correlation matrix $\\text{R}_{t, f}^{-1}$ which is stored in Q and initialized with an identity matrix, as well as filter coefficient matrix that is stored in G and initialized with zeros. \n", 129 | "\n", 130 | "Again for simplicity the ISTFT is applied in Numpy afterwards." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "Z_list = []\n", 140 | "\n", 141 | "Q = np.stack([np.identity(channels * taps) for a in range(frequency_bins)])\n", 142 | "G = np.zeros((frequency_bins, channels * taps, channels))\n", 143 | "\n", 144 | "with tf.Session() as session:\n", 145 | " Y_tf = tf.placeholder(tf.complex128, shape=(taps + delay + 1, frequency_bins, channels))\n", 146 | " Q_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels * taps))\n", 147 | " G_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels))\n", 148 | " \n", 149 | " results = online_wpe_step(Y_tf, get_power_online(tf.transpose(Y_tf, (1, 0, 2))), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n", 150 | " for Y_step in tqdm(aquire_framebuffer()):\n", 151 | " feed_dict = {Y_tf: Y_step, Q_tf: Q, G_tf: G}\n", 152 | " Z, Q, G = session.run(results, feed_dict)\n", 153 | " Z_list.append(Z)\n", 154 | "\n", 155 | "Z_stacked = np.stack(Z_list)\n", 156 | "z = istft(np.asarray(Z_stacked).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n", 157 | "\n", 158 | "IPython.display.Audio(z[0], rate=sampling_rate)" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "# Power spectrum\n", 166 | "Before and after applying WPE." 
167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n", 176 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[200:400, :, 0])).T, origin='lower')\n", 177 | "ax1.set_xlabel('')\n", 178 | "_ = ax1.set_title('reverberated')\n", 179 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z_stacked[200:400, :, 0])).T, origin='lower')\n", 180 | "_ = ax2.set_title('dereverberated')\n", 181 | "cb = fig.colorbar(im1)" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "kernelspec": { 187 | "display_name": "Python 3", 188 | "language": "python", 189 | "name": "python3" 190 | }, 191 | "language_info": { 192 | "codemirror_mode": { 193 | "name": "ipython", 194 | "version": 3 195 | }, 196 | "file_extension": ".py", 197 | "mimetype": "text/x-python", 198 | "name": "python", 199 | "nbconvert_exporter": "python", 200 | "pygments_lexer": "ipython3", 201 | "version": "3.6.6" 202 | } 203 | }, 204 | "nbformat": 4, 205 | "nbformat_minor": 2 206 | } 207 | -------------------------------------------------------------------------------- /examples/examples.rst: -------------------------------------------------------------------------------- 1 | 2 | Examples 3 | ======== 4 | 5 | Small IPython notebooks with examples in Numpy and Tensorflow. The examples 6 | are taken over from the following paper:: 7 | 8 | @InProceedings{Drude2018NaraWPE, 9 | Title = {NARA-WPE: A Python package for weighted prediction error dereverberation in Numpy and Tensorflow for online and offline processing}, 10 | Author = {Drude, Lukas and Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold}, 11 | Booktitle = {13. 
ITG Fachtagung Sprachkommunikation (ITG 2018)}, 12 | Year = {2018}, 13 | Month = {Oct}, 14 | } 15 | -------------------------------------------------------------------------------- /maintenance.md: -------------------------------------------------------------------------------- 1 | 2 | # PyPi upload 3 | 4 | Package a Python Package/ version bump See: https://packaging.python.org/tutorials/packaging-projects/ 5 | 6 | 1. Update `setup.py` to new version number 7 | 2. Commit this change 8 | 3. Tag and upload 9 | 10 | ## Install dependencies: 11 | ```bash 12 | pip install --upgrade setuptools 13 | pip install --upgrade wheel 14 | pip install --upgrade twine 15 | # pip install --upgrade bleach html5lib # some versions do not work 16 | pip install --upgrade bump2version 17 | ``` 18 | 19 | `bump2version` takes care to increase the version number, create the commit and tag. 20 | 21 | ```bash 22 | # rm -rf dist/* 23 | bump2version --verbose --tag patch # major, minor or patch 24 | python setup.py sdist bdist_wheel 25 | twine upload --repository testpypi dist/* 26 | git push origin --tags 27 | git push 28 | twine upload dist/* 29 | ``` 30 | -------------------------------------------------------------------------------- /nara_wpe/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | import pathlib 3 | except ImportError: 4 | # Python 2.7 5 | import pathlib2 as pathlib 6 | 7 | import os 8 | 9 | project_root = pathlib.Path(os.path.abspath( 10 | os.path.join(os.path.dirname(__file__), os.pardir) 11 | )) 12 | 13 | name = "nara_wpe" -------------------------------------------------------------------------------- /nara_wpe/benchmark_online_wpe.py: -------------------------------------------------------------------------------- 1 | # ToDo: move this file to tests 2 | 3 | import sys 4 | from itertools import product 5 | 6 | import pandas as pd 7 | if sys.version_info < (3, 7): 8 | import tensorflow as tf 9 | 10 | from nara_wpe 
def config_iterator():
    """Enumerate every benchmark configuration.

    Returns:
        An ``itertools.product`` iterator yielding
        ``(repetition, K, num_mics, frame_size, dtype, device)`` tuples,
        i.e. the Cartesian product of the option lists defined below.
    """
    # NOTE(review): `tf` is only imported at module level for
    # Python < 3.7, so calling this on newer interpreters raises NameError.
    repetitions = range(1, 11)
    filter_orders = [5, 10]  # K
    microphone_counts = [2, 4, 6]  # num_mics; alternative: range(2, 11)
    frame_sizes = [512]  # frame_size; alternative: 1024
    dtypes = [tf.complex64]  # dtype; alternative: also tf.complex128
    devices = ['/cpu:0']  # device; alternative: '/gpu:0'
    return product(
        repetitions,
        filter_orders,
        microphone_counts,
        frame_sizes,
        dtypes,
        devices,
    )
class NTTWrapper:
    """Python front end for the NTT WPE MATLAB package.

    The WPE package has to be downloaded from
    http://www.kecl.ntt.co.jp/icl/signal/wpe/download.html. It is recommended
    to store it in the cache directory of Nara-WPE.
    """

    def __init__(self, path_to_pkg):
        """Store the package location and verify that it exists.

        Args:
            path_to_pkg: Directory of the unpacked NTT WPE package.

        Raises:
            OSError: If ``path_to_pkg`` does not exist.
        """
        self.path_to_pkg = Path(path_to_pkg)

        if not self.path_to_pkg.exists():
            # Adjacent literals are concatenated verbatim, so every fragment
            # carries its own trailing space.
            raise OSError(
                'NTT WPE package does not exist. It has to be downloaded '
                'from http://www.kecl.ntt.co.jp/icl/signal/wpe/download.html '
                'and stored in the cache directory of Nara-WPE, preferably.'
            )

    @cached_property
    def process(self):
        """Start a MATLAB session on first access and cache it."""
        mlab = Matlab()
        mlab.start()
        return mlab

    def cfg(self, channels, sampling_rate, iterations, taps,
            stft_size, stft_shift
            ):
        """Build the content of ``local.m`` for the requested settings.

        Reads ``<path_to_pkg>/settings/local.m`` and replaces every
        recognised setting line that does not already contain the
        requested value. All other lines are passed through unchanged.

        Args:
            channels: Number of microphones (``num_mic``).
            sampling_rate: Sampling rate (``fs``).
            iterations: Maximum number of iterations (``ssd_conf``).
            taps: Filter length written to ``channel_setup``.
            stft_size: Analysis window size (``win_size`` and ``hanning``).
            stft_shift: Analysis window shift (``shift_size``).

        Returns:
            list of str: The lines of the adjusted settings file. Every
            replaced line ends with a newline so that the lines can be
            written back one after another.
        """
        cfg = self.path_to_pkg / 'settings' / 'local.m'
        lines = []
        with cfg.open() as infile:
            for line in infile:
                if 'num_mic = ' in line and 'num_out' not in line:
                    if str(channels) not in line:
                        line = 'num_mic = ' + str(channels) + ";\n"
                elif 'fs' in line:
                    if str(sampling_rate) not in line:
                        line = 'fs =' + str(sampling_rate) + ";\n"
                elif 'channel_setup' in line and 'ssd_param' not in line:
                    if str(taps) not in line and '%' not in line:
                        line = "channel_setup = [" + str(taps) + "; ...\n"
                elif 'ssd_conf' in line:
                    if str(iterations) not in line:
                        line = "ssd_conf = struct('max_iter',"\
                               + str(iterations) + ", ...\n"
                # The key must match the text that is written back
                # ('analy_param'); the previous key 'analym_param' never
                # matched, so the window size was silently left untouched.
                elif 'analy_param' in line:
                    if str(stft_size) not in line:
                        line = "analy_param = struct('win_size',"\
                               + str(stft_size) + ", ...\n"
                elif 'shift_size' in line:
                    if str(stft_shift) not in line:
                        line = "    'shift_size',"\
                               + str(stft_shift) + ", ...\n"
                elif 'hanning' in line:
                    if str(stft_size) not in line:
                        line = "    'win' , hanning("\
                               + str(stft_size) + "));\n"
                lines.append(line)
        return lines

    def __call__(
            self,
            y,
            taps=10,
            delay=3,
            iterations=3,
            sampling_rate=16000,
            stft_size=512,
            stft_shift=128
    ):
        """Dereverberate ``y`` with the NTT WPE MATLAB implementation.

        Args:
            y: Observation with shape (channels, samples).
            taps: Filter length forwarded to the settings file.
            delay: Accepted for interface symmetry; it is not forwarded to
                the settings file by this wrapper.
            iterations: Number of WPE iterations.
            sampling_rate: Sampling rate of ``y``.
            stft_size: STFT window size.
            stft_shift: STFT window shift.

        Returns:
            Dereverberated observation with shape (channels, samples).

        Raises:
            AssertionError: If the MATLAB call reports a failure.
        """

        # MATLAB side works on (samples, channels).
        y = y.transpose(1, 0)
        channels = y.shape[1]
        cfg_lines = self.cfg(
            channels, sampling_rate, iterations, taps, stft_size, stft_shift
        )

        with tempfile.TemporaryDirectory() as tempdir:
            with (Path(tempdir) / 'local.m').open('w') as cfg_file:
                for line in cfg_lines:
                    cfg_file.write(line)

            self.process.set_variable("y", y)
            self.process.set_variable("cfg", cfg_file.name)

            # NOTE(review): addpath receives the path of the settings *file*,
            # not of its directory -- confirm this is what the package needs.
            self.process.run_code("addpath('" + str(cfg_file.name) + "');")
            self.process.run_code("addpath('" + str(self.path_to_pkg) + "');")

            msg = self.process.run_code("y = wpe(y, cfg);")
            assert msg['success'] is True, \
                f'WPE has failed. {msg["content"]["stdout"]}'

            y = self.process.get_variable("y")

            return y.transpose(1, 0)
def repeat_with_success_at_least(times, min_success):
    """Decorator that runs a test method several times.

    The decorated test is executed up to `times` times, each time on a
    fresh instance of the test case. It passes as soon as `min_success`
    trials have succeeded and fails as soon as that number can no longer
    be reached; remaining trials are skipped in both cases.

    Args:
        times(int): The number of trials.
        min_success(int): Number of successful trials required for the
            decorated test case to be regarded as passed.

    """

    assert times >= min_success

    def _repeat_with_success_at_least(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            assert len(args) > 0
            instance = args[0]
            assert isinstance(instance, unittest.TestCase)
            extra_args = args[1:]
            allowed_failures = times - min_success
            passed = 0
            failed_results = []

            def report_failure():
                # One result is stored per failed trial, so the length of
                # the list is the failure count.
                msg = '\nFail: {0}, Success: {1}'.format(
                    len(failed_results), passed)
                if len(failed_results) > 0:
                    first = failed_results[0]
                    errs = first.failures + first.errors
                    if len(errs) > 0:
                        err_msg = '\n'.join(entry[1] for entry in errs)
                        msg += '\n\nThe first error message:\n' + err_msg
                instance.fail(msg)

            for _ in six.moves.range(times):
                # Create new instance to call the setup and the teardown only
                # once.
                ins = type(instance)(instance._testMethodName)
                trial = unittest.FunctionTestCase(
                    lambda: f(ins, *extra_args, **kwargs),
                    setUp=ins.setUp,
                    tearDown=ins.tearDown)
                suite = unittest.TestSuite()
                suite.addTest(trial)

                result = QuietTestRunner().run(suite)
                if result.wasSuccessful():
                    passed += 1
                else:
                    failed_results.append(result)
                if passed >= min_success:
                    instance.assertTrue(True)
                    return
                if len(failed_results) > allowed_failures:
                    report_failure()
                    return
            report_failure()
        return wrapper
    return _repeat_with_success_at_least
def _batch_wrapper(inner_function, signals, num_frames, time_axis=-1):
    """Apply a function per batch entry while respecting signal lengths.

    Each batch entry is cut to its own number of frames, passed to
    `inner_function` and the result is zero padded back to the longest
    length in the batch.

    Args:
        inner_function (function): A function taking a list with the cut
            signals of a single batch entry and returning one tensor.
        signals (tuple): Signals needed for the function. Observation must be
            in the first place. All signals must have shape (batch, ..., time)
        num_frames (array): Number of frames for each batch
        time_axis (int): Time axis of the batched signals. Defaults to the
            last axis.

    Returns:
        tf.Tensor: Zero padded output of the function. Its static shape is
            set to the shape of the first signal.
    """

    max_frames = tf.reduce_max(num_frames)

    # If we remove the batch dimension the time axis shifts by -1 if positive
    if time_axis > 0:
        time_axis -= 1

    def _single_batch(inp):
        # `inp` mirrors the structure handed to tf.map_fn below:
        # [signals of this batch entry, number of frames of this entry].
        frames = inp[-1]
        inp = inp[0]
        with tf.name_scope('single_batch'):

            def _pad(x):
                # Append zeros along the time axis up to `max_frames`.
                padding = max_frames - \
                    tf.minimum(frames, tf.shape(x)[time_axis])
                zeros = tf.cast(tf.zeros(()), x.dtype)
                paddings = x.shape.ndims * [(0, 0), ]
                paddings[time_axis] = (0, padding)
                return tf.pad(
                    x,
                    paddings,
                    constant_values=zeros
                )

            def _slice(x):
                # Keep only the first `frames` entries along the time axis.
                slices = x.shape.ndims * [slice(None), ]
                slices[time_axis] = slice(frames)
                return x[slices]

            enhanced = inner_function(
                [_slice(i) for i in inp]
            )
            return _pad(enhanced)

    out = tf.map_fn(
        _single_batch, [signals, num_frames], dtype=signals[0].dtype
    )
    # Restore the static shape of the observation on the mapped output.
    out.set_shape(signals[0].shape)
    return out
def get_correlations(Y, inverse_power, taps, delay):
    """Calculates weighted correlations of a window of length taps

    Args:
        Y (tf.Tensor): Complex-valued STFT signal with shape (F, D, T)
        inverse_power (tf.Tensor): Weighting factor with shape (F, T)
        taps (int): Length of correlation window
        delay (int): Delay for the weighting factor

    Returns:
        tf.Tensor: Correlation matrix of shape (F, taps*D, taps*D)
        tf.Tensor: Correlation vector of shape (F, taps, D, D); it is
            returned without being reshaped.
    """
    dyn_shape = tf.shape(Y)
    F = dyn_shape[0]
    D = dyn_shape[1]
    T = dyn_shape[2]

    # Frame the time axis into windows of `taps` frames (hop size 1), keep the
    # first T - delay - taps + 1 windows and reverse the order inside a window.
    Psi = tf_signal.frame(Y, taps, 1, axis=-1)[..., :T - delay - taps + 1, ::-1]
    # Conjugated windows weighted with the inverse power of the frame that
    # lies `delay` frames after the end of the window.
    Psi_conj_norm = (
        tf.cast(inverse_power[:, None, delay + taps - 1:, None], Psi.dtype)
        * tf.conj(Psi)
    )

    # Sum over time; index order of the result is (f, tap, channel, tap, channel).
    correlation_matrix = tf.einsum('fdtk,fetl->fkdle', Psi_conj_norm, Psi)
    correlation_vector = tf.einsum(
        'fdtk,fet->fked', Psi_conj_norm, Y[..., delay + taps - 1:]
    )

    correlation_matrix = tf.reshape(correlation_matrix, (F, taps * D, taps * D))
    return correlation_matrix, correlation_vector
222 | 223 | Args: 224 | Y (tf.Tensor): Complex-valued STFT signal of shape (D, T) 225 | correlation_matrix (tf.Tensor): Correlation matrix (taps*D, taps*D) 226 | correlation_vector (tf.Tensor): Correlation vector (D, taps*D) 227 | K (int): Number of filter taps 228 | delay (int): Delay 229 | mode (str, optional): Specifies how R^-1@r is calculate: 230 | "inv" calculates the inverse of R directly and then uses matmul 231 | "solve" solves Rx=r for x 232 | 233 | Raises: 234 | ValueError: Unknown mode specified 235 | 236 | Returns: 237 | tf.Tensor: (Conjugate) filter Matrix 238 | """ 239 | 240 | D = tf.shape(Y)[0] 241 | 242 | correlation_vector = tf.reshape(correlation_vector, (D * D * taps, 1)) 243 | selector = \ 244 | tf.reshape( 245 | tf.transpose( 246 | tf.reshape(tf.range(D * D * taps), (D, taps, D)), (1, 0, 2)), (-1,)) 247 | inv_selector = \ 248 | tf.reshape( 249 | tf.transpose( 250 | tf.reshape(tf.range(D * D * taps), (taps, D, D)), (1, 0, 2)), (-1,)) 251 | 252 | correlation_vector = tf.gather(correlation_vector, inv_selector) 253 | 254 | if mode == 'inv': 255 | with tf.device('/cpu:0'): 256 | inv_correlation_matrix = tf.matrix_inverse(correlation_matrix) 257 | stacked_filter_conj = tf.einsum( 258 | 'ab,cb->ca', 259 | inv_correlation_matrix, tf.reshape(correlation_vector, (D, D * taps)) 260 | ) 261 | stacked_filter_conj = tf.reshape(stacked_filter_conj, (D * D * taps, 1)) 262 | elif mode == 'solve': 263 | with tf.device('/cpu:0'): 264 | stacked_filter_conj = tf.reshape( 265 | tf.matrix_solve( 266 | tf.tile(correlation_matrix[None, ...], [D, 1, 1]), 267 | tf.reshape(correlation_vector, (D, D * taps, 1)) 268 | ), 269 | (D * D * taps, 1) 270 | ) 271 | else: 272 | raise ValueError( 273 | 'Unknown mode {}. 
def single_frequency_wpe(Y, taps=10, delay=3, iterations=3, mode='inv'):
    """WPE for a single frequency.

    Args:
        Y: Complex valued STFT signal with shape (D, T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        iterations: Number of WPE iterations. With fewer than one iteration
            no filter is estimated and `Y` is returned unchanged.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x

    Returns:
        tuple: The dereverberated signal (same shape as `Y`) and the inverse
            power that was used as weighting in the last iteration.

    """

    if iterations < 1:
        # Without this guard `inverse_power` would never be assigned and the
        # return statement below would raise an UnboundLocalError.
        return Y, get_power_inverse(Y)

    enhanced = Y
    for _ in range(iterations):
        # The weights are re-estimated from the latest estimate, while the
        # filter itself is always estimated from and applied to `Y`.
        inverse_power = get_power_inverse(enhanced)
        correlation_matrix, correlation_vector = \
            get_correlations_for_single_frequency(Y, inverse_power, taps, delay)
        filter_matrix_conj = get_filter_matrix_conj(
            Y, correlation_matrix, correlation_vector,
            taps, delay, mode=mode
        )
        enhanced = perform_filter_operation(Y, filter_matrix_conj, taps, delay)
    return enhanced, inverse_power
377 | 378 | Args: 379 | Y (tf.Tensor): Observed signal with shape (B, F, D, T) 380 | num_frames (tf.Tensor): Number of frames for each signal in the batch 381 | taps (int, optional): Defaults to 10. Number of filter taps. 382 | delay (int, optional): Defaults to 3. 383 | iterations (int, optional): Defaults to 3. 384 | mode (str, optional): Specifies how R^-1@r is calculated: 385 | "inv" calculates the inverse of R directly and then uses matmul 386 | "solve" solves Rx=r for x 387 | 388 | Returns: 389 | tf.Tensor: Dereverberated signal of shape (B, F, D, T). 390 | """ 391 | 392 | def _inner_func(signals): 393 | out = wpe(signals[0], taps, delay, iterations, mode) 394 | return out 395 | 396 | return _batch_wrapper(_inner_func, [Y], num_frames) 397 | 398 | 399 | def wpe_step(Y, inverse_power, taps=10, delay=3, mode='inv', Y_stats=None): 400 | """Single step of 'wpe'. More suited for backpropagation. 401 | 402 | Args: 403 | Y (tf.Tensor): Complex valued STFT signal with shape (F, D, T) 404 | inverse_power (tf.Tensor): Power signal with shape (F, T) 405 | taps (int, optional): Filter order 406 | delay (int, optional): Delay as a guard interval, such that X does not become zero. 407 | mode (str, optional): Specifies how R^-1@r is calculated: 408 | "inv" calculates the inverse of R directly and then uses matmul 409 | "solve" solves Rx=r for x 410 | Y_stats (tf.Tensor or None, optional): Complex valued STFT signal 411 | with shape (F, D, T) used to calculate the signal statistics 412 | (i.e. correlation matrix/vector). 413 | If None, Y is used.
Otherwise it's usually a segment of Y 414 | 415 | Returns: 416 | Dereverberated signal of shape (F, D, T) 417 | """ 418 | with tf.name_scope('WPE'): 419 | with tf.name_scope('correlations'): 420 | if Y_stats is None: 421 | Y_stats = Y 422 | correlation_matrix, correlation_vector = get_correlations( 423 | Y_stats, inverse_power, taps, delay 424 | ) 425 | 426 | def step(inp): 427 | (Y_f, correlation_matrix_f, correlation_vector_f) = inp 428 | with tf.name_scope('filter_matrix'): 429 | filter_matrix_conj = get_filter_matrix_conj( 430 | Y_f, 431 | correlation_matrix_f, correlation_vector_f, 432 | taps, delay, mode=mode 433 | ) 434 | with tf.name_scope('apply_filter'): 435 | enhanced = perform_filter_operation( 436 | Y_f, filter_matrix_conj, taps, delay) 437 | return enhanced 438 | 439 | enhanced = tf.map_fn( 440 | step, 441 | (Y, correlation_matrix, correlation_vector), 442 | dtype=Y.dtype, 443 | parallel_iterations=100 444 | ) 445 | 446 | return enhanced 447 | 448 | 449 | def batched_wpe_step( 450 | Y, inverse_power, num_frames, taps=10, delay=3, mode='inv', Y_stats=None): 451 | """Batched single WPE step. More suited for backpropagation. 452 | 453 | Args: 454 | Y (tf.Tensor): Complex valued STFT signal with shape (B, F, D, T) 455 | inverse_power (tf.Tensor): Power signal with shape (B, F, T) 456 | num_frames (tf.Tensor): Number of frames for each signal in the batch 457 | taps (int, optional): Filter order 458 | delay (int, optional): Delay as a guard interval, such that X does not become zero. 459 | mode (str, optional): Specifies how R^-1@r is calculated: 460 | "inv" calculates the inverse of R directly and then uses matmul 461 | "solve" solves Rx=r for x 462 | Y_stats (tf.Tensor or None, optional): Complex valued STFT signal 463 | with shape (B, F, D, T) used to calculate the signal statistics 464 | (i.e. correlation matrix/vector). 465 | If None, Y is used.
Otherwise it's usually a segment of Y 466 | 467 | Returns: 468 | Dereverberated signal of shape (B, F, D, T) 469 | """ 470 | def _inner_func(signals): 471 | _Y, _inverse_power, _Y_stats = signals 472 | out = wpe_step(_Y, _inverse_power, taps, delay, mode, _Y_stats) 473 | return out 474 | 475 | if Y_stats is None: 476 | Y_stats = Y 477 | 478 | return _batch_wrapper(_inner_func, [Y, inverse_power, Y_stats], num_frames) 479 | 480 | 481 | def block_wpe_step( 482 | Y, inverse_power, taps=10, delay=3, mode='inv', 483 | block_length_in_seconds=2., forgetting_factor=0.7, 484 | fft_shift=256, sampling_rate=16000): 485 | """Applies wpe in a block-wise fashion. 486 | 487 | Args: 488 | Y (tf.Tensor): Complex valued STFT signal with shape (F, D, T) 489 | inverse_power (tf.Tensor): Power signal with shape (F, T) 490 | taps (int, optional): Defaults to 10. 491 | delay (int, optional): Defaults to 3. 492 | mode (str, optional): Specifies how R^-1@r is calculated: 493 | "inv" calculates the inverse of R directly and then uses matmul 494 | "solve" solves Rx=r for x 495 | block_length_in_seconds (float, optional): Length of each block in 496 | seconds 497 | forgetting_factor (float, optional): Forgetting factor for the signal 498 | statistics between the blocks 499 | fft_shift (int, optional): Shift used for the STFT. 500 | sampling_rate (int, optional): Sampling rate of the observed signal.
501 | """ 502 | frames_per_block = block_length_in_seconds * sampling_rate // fft_shift 503 | frames_per_block = tf.cast(frames_per_block, tf.int32) 504 | framed_Y = tf_signal.frame( 505 | Y, frames_per_block, frames_per_block, pad_end=True) 506 | framed_inverse_power = tf_signal.frame( 507 | inverse_power, frames_per_block, frames_per_block, pad_end=True) 508 | num_blocks = tf.shape(framed_Y)[-2] 509 | 510 | enhanced_arr = tf.TensorArray( 511 | framed_Y.dtype, size=num_blocks, clear_after_read=True) 512 | start_block = tf.constant(0) 513 | correlation_matrix, correlation_vector = get_correlations( 514 | framed_Y[..., start_block, :], framed_inverse_power[..., start_block, :], 515 | taps, delay 516 | ) 517 | num_bins = Y.shape[0] 518 | num_channels = Y.shape[1].value 519 | if num_channels is None: 520 | num_channels = tf.shape(Y)[1] 521 | num_frames = tf.shape(Y)[-1] 522 | 523 | def cond(k, *_): 524 | return k < num_blocks 525 | 526 | with tf.name_scope('block_WPE'): 527 | def block_step( 528 | k, enhanced, correlation_matrix_tm1, correlation_vector_tm1): 529 | 530 | def _init_step(): 531 | return correlation_matrix_tm1, correlation_vector_tm1 532 | 533 | def _update_step(): 534 | correlation_matrix, correlation_vector = get_correlations( 535 | framed_Y[..., k, :], framed_inverse_power[..., k, :], 536 | taps, delay 537 | ) 538 | return ( 539 | (1. - forgetting_factor) * correlation_matrix_tm1 540 | + forgetting_factor * correlation_matrix, 541 | (1. 
- forgetting_factor) * correlation_vector_tm1 542 | + forgetting_factor * correlation_vector 543 | ) 544 | 545 | correlation_matrix, correlation_vector = tf.case( 546 | ((tf.equal(k, 0), _init_step),), default=_update_step 547 | ) 548 | 549 | def step(inp): 550 | (Y_f, inverse_power_f, 551 | correlation_matrix_f, correlation_vector_f) = inp 552 | with tf.name_scope('filter_matrix'): 553 | filter_matrix_conj = get_filter_matrix_conj( 554 | Y_f, 555 | correlation_matrix_f, correlation_vector_f, 556 | taps, delay, mode=mode 557 | ) 558 | with tf.name_scope('apply_filter'): 559 | enhanced_f = perform_filter_operation( 560 | Y_f, filter_matrix_conj, taps, delay) 561 | return enhanced_f 562 | 563 | enhanced_block = tf.map_fn( 564 | step, 565 | (framed_Y[..., k, :], framed_inverse_power[..., k, :], 566 | correlation_matrix, correlation_vector), 567 | dtype=framed_Y.dtype, 568 | parallel_iterations=100 569 | ) 570 | 571 | enhanced = enhanced.write(k, enhanced_block) 572 | return k + 1, enhanced, correlation_matrix, correlation_vector 573 | 574 | _, enhanced_arr, _, _ = tf.while_loop( 575 | cond, block_step, 576 | (start_block, enhanced_arr, correlation_matrix, correlation_vector) 577 | ) 578 | 579 | enhanced = enhanced_arr.stack() 580 | enhanced = tf.transpose(enhanced, (1, 2, 0, 3)) 581 | enhanced = tf.reshape(enhanced, (num_bins, num_channels, -1)) 582 | 583 | return enhanced[..., :num_frames] 584 | 585 | 586 | def batched_block_wpe_step( 587 | Y, inverse_power, num_frames, taps=10, delay=3, mode='inv', 588 | block_length_in_seconds=2., forgetting_factor=0.7, 589 | fft_shift=256, sampling_rate=16000): 590 | """Batched single WPE step. More suited for backpropagation. 
591 | 592 | Args: 593 | Y (tf.Tensor): Complex valued STFT signal with shape (B, F, D, T) 594 | inverse_power (tf.Tensor): Power signal with shape (B, F, T) 595 | num_frames (tf.Tensor): Number of frames for each signal in the batch 596 | taps (int, optional): Filter order 597 | delay (int, optional): Delay as a guard interval, such that X does not become zero. 598 | mode (str, optional): Specifies how R^-1@r is calculated: 599 | "inv" calculates the inverse of R directly and then uses matmul 600 | "solve" solves Rx=r for x 601 | block_length_in_seconds (float, optional): Length of each block in 602 | seconds 603 | forgetting_factor (float, optional): Forgetting factor for the signal 604 | statistics between the blocks 605 | fft_shift (int, optional): Shift used for the STFT. 606 | sampling_rate (int, optional): Sampling rate of the observed signal. 607 | 608 | Returns: 609 | Dereverberated signal of shape (B, F, D, T) 610 | """ 611 | def _inner_func(signals): 612 | _Y, _inverse_power = signals 613 | out = block_wpe_step( 614 | _Y, _inverse_power, taps, delay, 615 | mode, block_length_in_seconds, forgetting_factor, 616 | fft_shift, sampling_rate) 617 | return out 618 | 619 | return _batch_wrapper(_inner_func, [Y, inverse_power], num_frames) 620 | 621 | 622 | def online_wpe_step( 623 | input_buffer, power_estimate, inv_cov, filter_taps, 624 | alpha, taps, delay 625 | ): 626 | """ 627 | One step of online dereverberation 628 | 629 | Args: 630 | input_buffer (tf.Tensor): Buffer of shape (taps+delay+1, F, D) 631 | power_estimate (tf.Tensor): Estimate for the current PSD 632 | inv_cov (tf.Tensor): Current estimate of R^-1 633 | filter_taps (tf.Tensor): Current estimate of filter taps (F, taps*D, D) 634 | alpha (float): Smoothing factor 635 | taps (int): Number of filter taps 636 | delay (int): Delay in frames 637 | 638 | Returns: 639 | tf.Tensor: Dereverberated frame of shape (F, D) 640 | tf.Tensor: Updated estimate of R^-1 641 | tf.Tensor: Updated estimate of the
filter taps 642 | """ 643 | F = input_buffer.shape[-2] 644 | D = tf.shape(input_buffer)[-1] 645 | window = input_buffer[:-delay - 1][::-1] 646 | window = tf.reshape( 647 | tf.transpose(window, (1, 2, 0)), (F, taps * D) 648 | ) 649 | window_conj = tf.conj(window) 650 | pred = ( 651 | input_buffer[-1] - 652 | tf.einsum('fid,fi->fd', tf.conj(filter_taps), window) 653 | ) 654 | 655 | nominator = tf.einsum('fij,fj->fi', inv_cov, window) 656 | denominator = tf.cast(alpha * power_estimate, window.dtype) 657 | denominator += tf.einsum('fi,fi->f', window_conj, nominator) 658 | kalman_gain = nominator / denominator[:, None] 659 | 660 | inv_cov_k = inv_cov - tf.einsum('fj,fjm,fi->fim', window_conj, inv_cov, kalman_gain) 661 | inv_cov_k /= alpha 662 | 663 | filter_taps_k = ( 664 | filter_taps + 665 | tf.einsum('fi,fm->fim', kalman_gain, tf.conj(pred)) 666 | ) 667 | return pred, inv_cov_k, filter_taps_k 668 | 669 | 670 | def recursive_wpe( 671 | Y, power_estimate, alpha, taps=10, delay=2, 672 | only_use_final_filters=False): 673 | """Applies WPE in a framewise recursive fashion. 674 | 675 | Args: 676 | Y (tf.Tensor): Observed signal of shape (T, F, D) 677 | power_estimate (tf.Tensor): Estimate for the clean signal PSD of shape (T, F) 678 | alpha (float): Smoothing factor for the recursion 679 | taps (int, optional): Number of filter taps. 680 | delay (int, optional): Delay 681 | only_use_final_filters (bool, optional): Applies only the final 682 | estimated filter coefficients to the whole signal. This is for 683 | debugging purposes only and makes this method a offline one. 
684 | 685 | Returns: 686 | tf.Tensor: Enhanced signal 687 | """ 688 | 689 | num_frames = tf.shape(Y)[0] 690 | num_bins = Y.shape[1] 691 | num_ch = tf.shape(Y)[-1] 692 | dtype = Y.dtype 693 | k = delay + taps 694 | 695 | inv_cov_tm1 = tf.eye(num_ch * taps, batch_shape=[num_bins], dtype=dtype) 696 | filter_taps_tm1 = tf.zeros((num_bins, num_ch * taps, num_ch), dtype=dtype) 697 | enhanced_arr = tf.TensorArray(dtype, size=num_frames, name='dereverberated') 698 | Y = tf.pad(Y, ((delay + taps, 0), (0, 0), (0, 0))) 699 | 700 | def dereverb_step(k_, inv_cov_tm1, filter_taps_tm1, enhanced): 701 | pos = k_ - delay - taps 702 | input_buffer = Y[pos:k_ + 1] 703 | pred, inv_cov_k, filter_taps_k = online_wpe_step( 704 | input_buffer, power_estimate[pos], 705 | inv_cov_tm1, filter_taps_tm1, alpha, taps, delay 706 | ) 707 | enhanced_k = enhanced.write(pos, pred) 708 | return k_ + 1, inv_cov_k, filter_taps_k, enhanced_k 709 | 710 | def cond(k, *_): 711 | return tf.less(k, num_frames + delay + taps) 712 | 713 | _, _, final_filter_taps, enhanced = tf.while_loop( 714 | cond, dereverb_step, (k, inv_cov_tm1, filter_taps_tm1, enhanced_arr)) 715 | 716 | # Only for testing / oracle purposes 717 | def dereverb_with_filters(k_, filter_taps, enhanced): 718 | window = Y[k_ - delay - taps:k_ - delay][::-1] 719 | window = tf.reshape( 720 | tf.transpose(window, (1, 2, 0)), (-1, taps * num_ch) 721 | ) 722 | pred = ( 723 | Y[k_] - 724 | tf.einsum('lim,li->lm', tf.conj(filter_taps), window) 725 | ) 726 | enhanced_k = enhanced.write(k_ - delay - taps, pred) 727 | return k_ + 1, filter_taps, enhanced_k 728 | 729 | if only_use_final_filters: 730 | k = tf.constant(0) + delay + taps 731 | enhanced_arr = tf.TensorArray(dtype, size=num_frames) 732 | _, _, enhanced = tf.while_loop( 733 | cond, dereverb_with_filters, (k, final_filter_taps, enhanced_arr)) 734 | 735 | return enhanced.stack() 736 | 737 | 738 | def batched_recursive_wpe( 739 | Y, power_estimate, alpha, num_frames, taps=10, delay=2, 740 | 
only_use_final_filters=False): 741 | """Batched version of the framewise recursive WPE (see recursive_wpe). 742 | 743 | Args: 744 | Y (tf.Tensor): Observed signal of shape (B, T, F, D) 745 | power_estimate (tf.Tensor): Estimate for the clean signal PSD of shape (B, T, F) 746 | alpha (float): Smoothing factor for the recursion 747 | num_frames (tf.Tensor): Number of frames for each signal in the batch 748 | taps (int, optional): Number of filter taps. 749 | delay (int, optional): Delay 750 | only_use_final_filters (bool, optional): Applies only the final 751 | estimated filter coefficients to the whole signal. This is for 752 | debugging purposes only and makes this method an offline one. 753 | 754 | Returns: 755 | Dereverberated signal of shape (B, T, F, D) 756 | """ 757 | def _inner_func(signals): 758 | _Y, _power_estimate = signals 759 | out = recursive_wpe( 760 | _Y, _power_estimate, alpha, taps, delay, only_use_final_filters) 761 | return out 762 | 763 | return _batch_wrapper( 764 | _inner_func, [Y, power_estimate], num_frames, time_axis=1) 765 | -------------------------------------------------------------------------------- /nara_wpe/torch_wpe.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional 6 | from nara_wpe.wpe import segment_axis 7 | 8 | 9 | def torch_moveaxis(x: torch.tensor, source, destination): 10 | """ 11 | 12 | >>> torch_moveaxis(torch.ones(2, 25), 1, 0).shape 13 | torch.Size([25, 2]) 14 | >>> torch_moveaxis(torch.ones(2, 25), -1, -2).shape 15 | torch.Size([25, 2]) 16 | >>> torch_moveaxis(torch.ones(2, 25), 0, 1).shape 17 | torch.Size([25, 2]) 18 | >>> torch_moveaxis(torch.ones(2, 25), -2, -1).shape 19 | torch.Size([25, 2]) 20 | >>> torch_moveaxis(torch.ones(2, 25) + 1j, -2, -1).shape 21 | torch.Size([25, 2]) 22 | """ 23 | ndim = len(x.shape) 24 | permutation = list(range(ndim)) 25 | source = permutation.pop(source) 26 |
permutation.insert(destination % ndim, source) 27 | return x.permute(*permutation) 28 | 29 | 30 | def build_y_tilde(Y, taps, delay): 31 | """ 32 | 33 | Note: The returned y_tilde consumes a similar amount of memory as Y, because 34 | of tricks with strides. Usually the memory consumption is taps times 35 | smaller than the memory consumption of a contiguous array. 36 | 37 | >>> T, D = 20, 2 38 | >>> Y = torch.arange(start=1, end=T * D + 1).reshape([T, D]).t() 39 | >>> # Y = torch.arange(start=1, end=T * D + 1).to(dtype=torch.complex128).reshape([T, D]).t() 40 | >>> print(Y.numpy()) 41 | [[ 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39] 42 | [ 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40]] 43 | >>> taps, delay = 4, 2 44 | >>> Y_tilde = build_y_tilde(Y, taps, delay) 45 | >>> print(Y_tilde.shape, (taps*D, T)) 46 | torch.Size([8, 20]) (8, 20) 47 | >>> print(Y_tilde.numpy()) 48 | [[ 0 0 0 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29] 49 | [ 0 0 0 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30] 50 | [ 0 0 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31] 51 | [ 0 0 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32] 52 | [ 0 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33] 53 | [ 0 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34] 54 | [ 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35] 55 | [ 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36]] 56 | >>> Y_tilde = build_y_tilde(Y, taps, 0) 57 | >>> print(Y_tilde.shape, (taps*D, T), Y_tilde.stride()) 58 | torch.Size([8, 20]) (8, 20) (1, 2) 59 | >>> print('Pseudo size:', np.prod(Y_tilde.size()) * Y_tilde.element_size()) 60 | Pseudo size: 1280 61 | >>> print('Real size:', Y_tilde.storage().size() * Y_tilde.storage().element_size()) 62 | Real size: 368 63 | >>> print(Y_tilde.numpy()) 64 | [[ 0 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33] 65 | [ 0 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34] 66 | [ 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35] 67 | [ 0 0 2 4 6 8 10 12 14 16 18 20 22 24
26 28 30 32 34 36] 68 | [ 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37] 69 | [ 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38] 70 | [ 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39] 71 | [ 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40]] 72 | 73 | >>> print(Y_tilde.shape, Y_tilde.stride()) 74 | torch.Size([8, 20]) (1, 2) 75 | >>> print(Y_tilde[::3].shape, Y_tilde[::3].stride()) 76 | torch.Size([3, 20]) (3, 2) 77 | >>> print(Y_tilde[::3].shape, Y_tilde[::3].contiguous().stride()) 78 | torch.Size([3, 20]) (20, 1) 79 | >>> print(Y_tilde[::3].numpy()) 80 | [[ 0 0 0 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33] 81 | [ 0 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30 32 34 36] 82 | [ 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39]] 83 | 84 | The first columns are zero because of the delay. 85 | 86 | """ 87 | S = Y.shape[:-2] 88 | D = Y.shape[-2] 89 | T = Y.shape[-1] 90 | 91 | def pad(x, axis=-1, pad_width=taps + delay - 1): 92 | npad = np.zeros([x.ndimension(), 2], dtype=int) 93 | npad[axis, 0] = pad_width 94 | # x_np = (np.pad(x.numpy(), 95 | # pad_width=npad, 96 | # mode='constant', 97 | # constant_values=0)) 98 | x = torch.nn.functional.pad( 99 | x, 100 | pad=npad[::-1].ravel().tolist(), 101 | mode='constant', 102 | value=0, 103 | ) 104 | # assert x_np.shape == x.shape, (x_np.shape, x.shape) 105 | return x 106 | 107 | # Y_ = segment_axis(pad(Y), K, 1, axis=-1) 108 | # Y_ = np.flip(Y_, axis=-1) 109 | # if delay > 0: 110 | # Y_ = Y_[..., :-delay, :] 111 | # # Y_: ... x D x T x K 112 | # Y_ = np.moveaxis(Y_, -1, -3) 113 | # # Y_: ... x K x D x T 114 | # Y_ = np.reshape(Y_, [*S, K * D, T]) 115 | # # Y_: ... 
x KD x T 116 | 117 | # ToDo: write the shape 118 | Y_ = pad(Y) 119 | Y_ = torch_moveaxis(Y_, -1, -2) 120 | Y_ = torch.flip(Y_, dims=[-1 % Y_.ndimension()]) 121 | Y_ = Y_.contiguous() # Y_ = np.ascontiguousarray(Y_) 122 | Y_ = torch.flip(Y_, dims=[-1 % Y_.ndimension()]) 123 | Y_ = segment_axis(Y_, taps, 1, axis=-2) 124 | 125 | # Pytorch does not support negative strides. 126 | # Without this flip, the output of this function does not match the 127 | # analytical form, but the output of WPE will be equal. 128 | # Y_ = torch.flip(Y_, dims=[-2 % Y_.ndimension()]) 129 | 130 | if delay > 0: 131 | Y_ = Y_[..., :-delay, :, :] 132 | Y_ = torch.reshape(Y_, list(S) + [T, taps * D]) 133 | Y_ = torch_moveaxis(Y_, -2, -1) 134 | 135 | return Y_ 136 | 137 | 138 | def get_power_inverse(signal, psd_context=0): 139 | """ 140 | Assumes single frequency bin with shape (D, T). 141 | 142 | >>> s = 1 / torch.tensor([np.arange(1, 6).astype(np.complex128)]*3) 143 | >>> get_power_inverse(s).numpy() 144 | array([ 1., 4., 9., 16., 25.]) 145 | 146 | # >>> get_power_inverse(s * 0 + 1, 1).numpy() 147 | # array([1., 1., 1., 1., 1.]) 148 | # >>> get_power_inverse(s, 1).numpy() 149 | # array([ 1.6 , 2.20408163, 7.08196721, 14.04421326, 19.51219512]) 150 | # >>> get_power_inverse(s, np.inf).numpy() 151 | # array([3.41620801, 3.41620801, 3.41620801, 3.41620801, 3.41620801]) 152 | """ 153 | power = torch.mean(torch.abs(signal)**2, dim=-2) 154 | 155 | if np.isposinf(psd_context): 156 | raise NotImplementedError(psd_context) 157 | # power = torch.broadcast_to(torch.mean(power, dim=-1, keepdims=True), power.shape) 158 | elif psd_context > 0: 159 | raise NotImplementedError(psd_context) 160 | # assert int(psd_context) == psd_context, psd_context 161 | # psd_context = int(psd_context) 162 | # # import bottleneck as bn 163 | # # Handle the corner case correctly (i.e. 
sum() / count) 164 | # # Use bottleneck when only left context is requested 165 | # # power = bn.move_mean(power, psd_context*2+1, min_count=1) 166 | # power = window_mean(power, (psd_context, psd_context)) 167 | elif psd_context == 0: 168 | pass 169 | else: 170 | raise ValueError(psd_context) 171 | eps = 1e-10 * torch.max(power) 172 | inverse_power = 1 / torch.max(power, eps) 173 | return inverse_power 174 | 175 | 176 | def transpose(x): 177 | return x.transpose(-2, -1) 178 | 179 | 180 | def hermite(x): 181 | return x.transpose(-2, -1).conj() 182 | 183 | 184 | def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'): 185 | """ 186 | Short of wpe_v7 with no extern references. 187 | Applicable in for-loops. 188 | 189 | >>> T = np.random.randint(100, 120) 190 | >>> D = np.random.randint(2, 6) 191 | >>> K = np.random.randint(3, 5) 192 | >>> delay = np.random.randint(1, 3) 193 | 194 | # Real test: 195 | >>> Y = np.random.normal(size=(D, T)) 196 | >>> from nara_wpe import wpe as np_wpe 197 | >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full') 198 | >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy() 199 | >>> np.testing.assert_allclose(actual, desired, atol=1e-6) 200 | 201 | # Complex test: 202 | >>> Y = np.random.normal(size=(D, T)) + 1j * np.random.normal(size=(D, T)) 203 | >>> from nara_wpe import wpe as np_wpe 204 | >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full') 205 | >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy() 206 | >>> np.testing.assert_allclose(actual, desired, atol=1e-6) 207 | """ 208 | 209 | if statistics_mode == 'full': 210 | s = Ellipsis 211 | elif statistics_mode == 'valid': 212 | s = (Ellipsis, slice(delay + taps - 1, None)) 213 | else: 214 | raise ValueError(statistics_mode) 215 | 216 | X = torch.clone(Y) 217 | Y_tilde = build_y_tilde(Y, taps, delay) 218 | for iteration in range(iterations): 219 | inverse_power = get_power_inverse(X, 
psd_context=psd_context) 220 | Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :] 221 | R = torch.matmul(Y_tilde_inverse_power[s], hermite(Y_tilde[s])) 222 | P = torch.matmul(Y_tilde_inverse_power[s], hermite(Y[s])) 223 | # G = _stable_solve(R, P) 224 | G = torch.linalg.solve(R, P) 225 | X = Y - torch.matmul(hermite(G), Y_tilde) 226 | 227 | return X 228 | 229 | 230 | def wpe_v8( 231 | Y, 232 | taps=10, 233 | delay=3, 234 | iterations=3, 235 | psd_context=0, 236 | statistics_mode='full', 237 | inplace=False 238 | ): 239 | """ 240 | Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2] implementation 241 | 242 | For numpy it is often the fastest Numpy implementation, torch has to be 243 | profiled. 244 | It loops over the independent axes. This reduces the memory footprint. 245 | 246 | Args: 247 | Y: Complex valued STFT signal with shape (..., D, T). 248 | taps: Filter order 249 | delay: Delay as a guard interval, such that X does not become zero. 250 | iterations: 251 | psd_context: Defines the number of elements in the time window 252 | to improve the power estimation. Total number of elements will 253 | be (psd_context + 1 + psd_context). 254 | statistics_mode: Either 'full' or 'valid'. 255 | 'full': Pad the observation with zeros on the left for the 256 | estimation of the correlation matrix and vector. 257 | 'valid': Only calculate correlation matrix and vector on valid 258 | slices of the observation. 259 | inplace: Whether to change Y inplace. Has only advantages, when Y has 260 | independent axes, because the core WPE algorithm does not support 261 | an inplace modification of the observation. 262 | This option may be relevant, when Y is so large, that you do not 263 | want to double the memory consumption (i.e. save Y and the 264 | dereverberated signal in the memory). 
265 | 266 | Returns: 267 | Estimated signal with the same shape as Y 268 | 269 | [1] "Generalization of multi-channel linear prediction methods for blind MIMO 270 | impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012 271 | [2] NARA-WPE: A Python package for weighted prediction error dereverberation in 272 | Numpy and Tensorflow for online and offline processing, Drude, Lukas and 273 | Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018 274 | 275 | """ 276 | ndim = Y.ndim 277 | if ndim == 2: 278 | out = wpe_v6( 279 | Y, 280 | taps=taps, 281 | delay=delay, 282 | iterations=iterations, 283 | psd_context=psd_context, 284 | statistics_mode=statistics_mode 285 | ) 286 | if inplace: 287 | Y[...] = out 288 | return out 289 | elif ndim >= 3: 290 | if inplace: 291 | out = Y 292 | else: 293 | out = torch.empty_like(Y) 294 | 295 | for index in np.ndindex(Y.shape[:-2]): 296 | out[index] = wpe_v6( 297 | Y=Y[index], 298 | taps=taps, 299 | delay=delay, 300 | iterations=iterations, 301 | psd_context=psd_context, 302 | statistics_mode=statistics_mode, 303 | ) 304 | return out 305 | else: 306 | raise NotImplementedError( 307 | 'Input shape has to be (..., D, T) and not {}.'.format(Y.shape) 308 | ) 309 | 310 | -------------------------------------------------------------------------------- /nara_wpe/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the STFT function and related helper functions. 3 | """ 4 | import numpy as np 5 | from math import ceil 6 | import scipy 7 | 8 | from scipy import signal 9 | from numpy.fft import rfft, irfft 10 | 11 | import string 12 | 13 | from nara_wpe.wpe import segment_axis as segment_axis_v2 14 | 15 | 16 | # http://stackoverflow.com/a/3153267 17 | def roll_zeropad(a, shift, axis=None): 18 | """ 19 | Roll array elements along a given axis. 20 | 21 | Elements off the end of the array are treated as zeros. 
22 | 23 | Args: 24 | a: array_like 25 | Input array. 26 | shift: int 27 | The number of places by which elements are shifted. 28 | axis (int): optional, 29 | The axis along which elements are shifted. By default, the array 30 | is flattened before shifting, after which the original 31 | shape is restored. 32 | 33 | Returns: 34 | ndarray: Output array, with the same shape as `a`. 35 | 36 | Note: 37 | roll : Elements that roll off one end come back on the other. 38 | rollaxis : Roll the specified axis backwards, until it lies in a 39 | given position. 40 | 41 | Examples: 42 | >>> x = np.arange(10) 43 | >>> roll_zeropad(x, 2) 44 | array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7]) 45 | >>> roll_zeropad(x, -2) 46 | array([2, 3, 4, 5, 6, 7, 8, 9, 0, 0]) 47 | 48 | >>> x2 = np.reshape(x, (2,5)) 49 | >>> x2 50 | array([[0, 1, 2, 3, 4], 51 | [5, 6, 7, 8, 9]]) 52 | >>> roll_zeropad(x2, 1) 53 | array([[0, 0, 1, 2, 3], 54 | [4, 5, 6, 7, 8]]) 55 | >>> roll_zeropad(x2, -2) 56 | array([[2, 3, 4, 5, 6], 57 | [7, 8, 9, 0, 0]]) 58 | >>> roll_zeropad(x2, 1, axis=0) 59 | array([[0, 0, 0, 0, 0], 60 | [0, 1, 2, 3, 4]]) 61 | >>> roll_zeropad(x2, -1, axis=0) 62 | array([[5, 6, 7, 8, 9], 63 | [0, 0, 0, 0, 0]]) 64 | >>> roll_zeropad(x2, 1, axis=1) 65 | array([[0, 0, 1, 2, 3], 66 | [0, 5, 6, 7, 8]]) 67 | >>> roll_zeropad(x2, -2, axis=1) 68 | array([[2, 3, 4, 0, 0], 69 | [7, 8, 9, 0, 0]]) 70 | 71 | >>> roll_zeropad(x2, 50) 72 | array([[0, 0, 0, 0, 0], 73 | [0, 0, 0, 0, 0]]) 74 | >>> roll_zeropad(x2, -50) 75 | array([[0, 0, 0, 0, 0], 76 | [0, 0, 0, 0, 0]]) 77 | >>> roll_zeropad(x2, 0) 78 | array([[0, 1, 2, 3, 4], 79 | [5, 6, 7, 8, 9]]) 80 | 81 | """ 82 | a = np.asanyarray(a) 83 | if shift == 0: 84 | return a 85 | if axis is None: 86 | n = a.size 87 | reshape = True 88 | else: 89 | n = a.shape[axis] 90 | reshape = False 91 | if np.abs(shift) > n: 92 | res = np.zeros_like(a) 93 | elif shift < 0: 94 | shift += n 95 | zeros = np.zeros_like(a.take(np.arange(n - shift), axis)) 96 | res = 
np.concatenate((a.take(np.arange(n - shift, n), axis), zeros), 97 | axis) 98 | else: 99 | zeros = np.zeros_like(a.take(np.arange(n - shift, n), axis)) 100 | res = np.concatenate((zeros, a.take(np.arange(n - shift), axis)), axis) 101 | if reshape: 102 | return res.reshape(a.shape) 103 | else: 104 | return res 105 | 106 | 107 | def stft( 108 | time_signal, 109 | size, 110 | shift, 111 | axis=-1, 112 | window=signal.windows.blackman, 113 | window_length=None, 114 | fading=True, 115 | pad=True, 116 | symmetric_window=False, 117 | ): 118 | """ 119 | ToDo: Open points: 120 | - sym_window need literature 121 | - fading why it is better? 122 | - should pad have more degrees of freedom? 123 | 124 | Calculates the short time Fourier transformation of a multi channel multi 125 | speaker time signal. It is able to add additional zeros for fade-in and 126 | fade out and should yield an STFT signal which allows perfect 127 | reconstruction. 128 | 129 | Args: 130 | time_signal: Multi channel time signal with dimensions 131 | AA x ... x AZ x T x BA x ... x BZ. 132 | size: Scalar FFT-size. 133 | shift: Scalar FFT-shift, the step between successive frames in 134 | samples. Typically shift is a fraction of size. 135 | axis: Scalar axis of time. 136 | Default: -1, i.e. the last axis. 137 | window: Window function handle. Default is blackman window. 138 | fading: Pads the signal with zeros for better reconstruction. 139 | window_length: Sometimes one desires to use a shorter window than 140 | the fft size. In that case, the window is padded with zeros. 141 | The default is to use the fft-size as a window size. 142 | pad: If true zero pad the signal to match the shape, else cut 143 | symmetric_window: Use a symmetric instead of a periodic window. By default 144 | the window is assumed to be periodic. Because the windows in scipy.signal 145 | have a curious behaviour for odd window_length, the periodic window is 146 | built as window(len+1)[:-1]. This is equal to the behaviour of MATLAB.
147 | 148 | Returns: 149 | Single channel complex STFT signal with dimensions 150 | AA x ... x AZ x T' times size/2+1 times BA x ... x BZ. 151 | """ 152 | time_signal = np.array(time_signal) 153 | 154 | axis = axis % time_signal.ndim 155 | 156 | if window_length is None: 157 | window_length = size 158 | 159 | # Pad with zeros to have enough samples for the window function to fade. 160 | if fading: 161 | pad_width = np.zeros((time_signal.ndim, 2), dtype=int) 162 | pad_width[axis, :] = window_length - shift 163 | time_signal = np.pad(time_signal, pad_width, mode='constant') 164 | 165 | if isinstance(window, str): 166 | window = getattr(signal.windows, window) 167 | 168 | if symmetric_window: 169 | window = window(window_length) 170 | else: 171 | # https://github.com/scipy/scipy/issues/4551 172 | window = window(window_length + 1)[:-1] 173 | 174 | time_signal_seg = segment_axis_v2( 175 | time_signal, 176 | window_length, 177 | shift=shift, 178 | axis=axis, 179 | end='pad' if pad else 'cut' 180 | ) 181 | 182 | letters = string.ascii_lowercase[:time_signal_seg.ndim] 183 | mapping = letters + ',' + letters[axis + 1] + '->' + letters 184 | 185 | try: 186 | # ToDo: Implement this more memory efficient 187 | return rfft( 188 | np.einsum(mapping, time_signal_seg, window), 189 | n=size, 190 | axis=axis + 1 191 | ) 192 | except ValueError as e: 193 | raise ValueError( 194 | 'Could not calculate the stft, something does not match.\n' + 195 | 'mapping: {}, '.format(mapping) + 196 | 'time_signal_seg.shape: {}, '.format(time_signal_seg.shape) + 197 | 'window.shape: {}, '.format(window.shape) + 198 | 'size: {}'.format(size) + 199 | 'axis+1: {axis+1}' 200 | ) 201 | 202 | 203 | def _samples_to_stft_frames( 204 | samples, 205 | size, 206 | shift, 207 | *, 208 | pad=True, 209 | fading=False, 210 | ): 211 | """ 212 | Calculates number of STFT frames from number of samples in time domain. 213 | 214 | Args: 215 | samples: Number of samples in time domain. 216 | size: FFT size. 
217 | window_length often equal to FFT size. The name size should be 218 | marked as deprecated and replaced with window_length. 219 | shift: Hop in samples. 220 | pad: See stft. 221 | fading: See stft. Note to keep old behavior, default value is False. 222 | 223 | Returns: 224 | Number of STFT frames. 225 | 226 | >>> _samples_to_stft_frames(19, 16, 4) 227 | 2 228 | >>> _samples_to_stft_frames(20, 16, 4) 229 | 2 230 | >>> _samples_to_stft_frames(21, 16, 4) 231 | 3 232 | 233 | >>> stft(np.zeros(19), 16, 4, fading=False).shape 234 | (2, 9) 235 | >>> stft(np.zeros(20), 16, 4, fading=False).shape 236 | (2, 9) 237 | >>> stft(np.zeros(21), 16, 4, fading=False).shape 238 | (3, 9) 239 | 240 | >>> _samples_to_stft_frames(19, 16, 4, fading=True) 241 | 8 242 | >>> _samples_to_stft_frames(20, 16, 4, fading=True) 243 | 8 244 | >>> _samples_to_stft_frames(21, 16, 4, fading=True) 245 | 9 246 | 247 | >>> stft(np.zeros(19), 16, 4).shape 248 | (8, 9) 249 | >>> stft(np.zeros(20), 16, 4).shape 250 | (8, 9) 251 | >>> stft(np.zeros(21), 16, 4).shape 252 | (9, 9) 253 | 254 | >>> _samples_to_stft_frames(21, 16, 3, fading=True) 255 | 12 256 | >>> stft(np.zeros(21), 16, 3).shape 257 | (12, 9) 258 | >>> _samples_to_stft_frames(21, 16, 3) 259 | 3 260 | >>> stft(np.zeros(21), 16, 3, fading=False).shape 261 | (3, 9) 262 | """ 263 | if fading: 264 | samples = samples + 2 * (size - shift) 265 | 266 | # I changed this from np.ceil to math.ceil, to yield an integer result. 267 | frames = (samples - size + shift) / shift 268 | if pad: 269 | return ceil(frames) 270 | return int(frames) 271 | 272 | 273 | def _stft_frames_to_samples(frames, size, shift): 274 | """ 275 | Calculates samples in time domain from STFT frames 276 | 277 | Args: 278 | frames: Number of STFT frames. 279 | size: FFT size. 280 | shift: Hop in samples. 281 | 282 | Returns: 283 | Number of samples in time domain. 
def istft(
        stft_signal,
        size=1024,
        shift=256,
        window=signal.windows.blackman,
        fading=True,
        window_length=None,
        symmetric_window=False,
):
    """
    Calculate the inverse short time Fourier transform by overlap-add.

    The analysis window is rebuilt from ``window`` / ``window_length`` /
    ``symmetric_window`` and converted to a synthesis window with
    ``_biorthogonal_window_fastest``; each frame is transformed back with
    ``irfft``, weighted with that window and added into the output.

    Notes:
        Be careful if you make modifications in the frequency domain (e.g.
        beamforming) because the synthesis window is calculated according to
        the unmodified! analysis window.

    Args:
        stft_signal: Complex STFT signal
            with dimensions (..., frames, size/2+1).
        size: Scalar FFT-size.
        shift: Scalar FFT-shift. Typically shift is a fraction of size.
        window: Window function handle, or the name of a function in
            ``scipy.signal.windows``.
        fading: If true, remove ``window_length - shift`` samples at both
            ends, i.e. the padding added by an STFT with fading.
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the window is padded with zeros.
            The default is to use the fft-size as a window size.
        symmetric_window: If true, use ``window(window_length)`` as is
            (symmetric window). If false (default), a periodic window is
            built as ``window(window_length + 1)[:-1]``.

    Returns:
        Real time signal with dimensions (..., samples), where samples is
        ``frames * shift + window_length - shift`` before the optional
        fading compensation.

    Raises:
        AssertionError: If the last axis of ``stft_signal`` does not have
            ``size // 2 + 1`` entries.
    """
    # Note: frame_axis and frequency_axis would make this function much more
    # complicated
    stft_signal = np.array(stft_signal)

    assert stft_signal.shape[-1] == size // 2 + 1, str(stft_signal.shape)

    if window_length is None:
        window_length = size

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if symmetric_window:
        window = window(window_length)
    else:
        # Periodic window: one extra sample is generated and dropped.
        window = window(window_length + 1)[:-1]

    # Turn the analysis window into the matching synthesis window.
    window = _biorthogonal_window_fastest(window, shift)

    # window = _biorthogonal_window_fastest(
    #     window, shift, use_amplitude_for_biorthogonal_window)
    # if disable_sythesis_window:
    #     window = np.ones_like(window)

    # Output buffer: leading dimensions are kept, frames and frequency bins
    # collapse into one time axis.
    time_signal = np.zeros(
        list(stft_signal.shape[:-2]) +
        [stft_signal.shape[-2] * shift + window_length - shift]
    )

    # Get the correct view to time_signal
    time_signal_seg = segment_axis_v2(
        time_signal, window_length, shift, end=None
    )

    # Unbuffered inplace add
    np.add.at(
        time_signal_seg,
        Ellipsis,
        window * np.real(irfft(stft_signal))[..., :window_length]
    )
    # The [..., :window_length] is the inverse of the window padding in rfft.

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[
            ..., window_length - shift:time_signal.shape[-1] - (window_length - shift)]

    return time_signal
def stft_to_spectrogram(stft_signal):
    """
    Return the power spectrum (spectrogram) of an STFT signal.

    The power is the squared real part plus the squared imaginary part of
    every entry, so the output is guaranteed to be real.

    Args:
        stft_signal: Complex STFT signal with dimensions
            #time_frames times #frequency_bins.

    Returns:
        Real spectrogram with same dimensions as input.
    """
    real_part = stft_signal.real
    imag_part = stft_signal.imag
    return real_part ** 2 + imag_part ** 2
def get_stft_center_frequencies(size=1024, sample_rate=16000):
    """
    Return the center frequency represented by each STFT frequency bin.

    It is often necessary to know which center frequency is represented by
    each frequency bin index.

    Args:
        size: Scalar FFT-size.
        sample_rate: Scalar sample frequency in Hertz.

    Returns:
        Array with ``size // 2 + 1`` center frequencies in Hertz, one per
        bin of the one-sided spectrum.
    """
    # Use integer division: the one-sided spectrum has size // 2 + 1 bins.
    # The former ``size / 2 + 1`` produced one bin too many for odd sizes.
    frequency_index = np.arange(size // 2 + 1)
    return frequency_index * sample_rate / size
import setuptools as st
from codecs import open
from os import path


here = path.abspath(path.dirname(__file__))

# Get the long description from the relevant file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
    long_description = f.read()

st.setup(
    name='nara_wpe',

    # Versions should comply with PEP440.  For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version='0.0.11',

    description='Weighted Prediction Error for speech dereverberation',
    long_description=long_description,

    # The project's main homepage.
    url='https://github.com/fgnt/nara_wpe',

    # Author details
    author='Department of Communications Engineering, Paderborn University',
    author_email='sek@nt.upb.de',

    # Choose your license
    license='MIT',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Topic :: Software Development :: Build Tools',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: MIT License',

        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
    ],

    # What does your project relate to?
    keywords='speech',

    # You can just specify the packages manually here if your project is
    # simple. Or you can use find_packages().
    packages=st.find_packages(exclude=['contrib', 'docs', 'tests*']),

    # List run-time dependencies here.  These will be installed by pip when
    # your project is installed. For an analysis of "install_requires" vs pip's
    # requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires=[
        'pathlib2;python_version<"3.0"',
        'numpy',
        # nara_wpe.utils uses scipy.signal window functions as default
        # arguments of stft/istft, so scipy is needed at import time and
        # not only for the tests.
        'scipy',
        'tqdm',
        'soundfile',
        'bottleneck',
        'click'
    ],

    # List additional groups of dependencies here (e.g. development
    # dependencies). You can install these using the following syntax,
    # for example:
    # $ pip install -e .[dev,test]
    extras_require={
        'dev': ['check-manifest'],
        'test': [
            'pytest',
            'coverage',
            'jupyter',
            'matplotlib',
            'scipy',
            'tensorflow==1.12.0;python_version<"3.7"',  # Python 3.7 has no tensorflow==1.12.0
            'pytest-cov',
            'codecov',
            'pandas',
            'torch',
            'cached_property',
            'zmq',
            'pyzmq',  # Required to install pymatbridge
            'pymatbridge',
        ],
    },

    # If there are data files included in your packages that need to be
    # installed, specify them here.  If using Python 2.6 or less, then these
    # have to be included in MANIFEST.in as well.
    # package_data={
    #     'nt_feature_extraction': ['package_data.dat'],
    # },

    # Although 'package_data' is the preferred approach, in some case you may
    # need to place data files outside of your packages. See:
    # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa
    # In this case, 'data_file' will be installed into '<sys.prefix>/my_data'
    # data_files=[('my_data', ['data/data_file'])],

    # https://stackoverflow.com/questions/2379898/make-distutils-look-for-numpy-header-files-in-the-correct-place
    # include_dirs = [np.get_include()],

    # To provide executable scripts, use entry points in preference to the
    # "scripts" keyword. Entry points provide cross-platform support and allow
    # pip to create the appropriate form of executable for the target platform.
    # entry_points={
    #     'console_scripts': [
    #         'nt_feature_extraction=nt_feature_extraction:main',
    #     ],
    # },
)
# The NumPy notebooks are meant to run on Python 3.6 and newer, so they are
# skipped only on older interpreters. They are plain pytest functions and
# therefore take no ``self`` argument.
@pytest.mark.skipif(sys.version_info < (3, 6, 0), reason='Only with Python 3.6+')
def test_wpe_numpy_offline():
    """Execute the offline NumPy WPE notebook and expect no error outputs."""
    _, errors = _notebook_run(root / 'WPE_Numpy_offline.ipynb')
    assert errors == []


@pytest.mark.skipif(sys.version_info < (3, 6, 0), reason='Only with Python 3.6+')
def test_wpe_numpy_online():
    """Execute the online NumPy WPE notebook and expect no error outputs."""
    _, errors = _notebook_run(root / 'WPE_Numpy_online.ipynb')
    assert errors == []
test file with i.e. 3 | python nara_wpe/tests/test_tf_wpe.py 4 | """ 5 | import sys 6 | import pytest 7 | import numpy as np 8 | from nara_wpe import wpe 9 | if sys.version_info < (3, 7): 10 | from nara_wpe import tf_wpe 11 | import tensorflow as tf 12 | else: 13 | # Dummy tf 14 | class tf: 15 | class test: 16 | class TestCase: 17 | pass 18 | from nara_wpe.test_utils import retry 19 | 20 | 21 | @pytest.mark.skipif(sys.version_info >= (3, 7), reason="Python 3.7+ has no tensorflow==1.12. Help wanted for tensorflow 2 support.") 22 | class TestWPE(tf.test.TestCase): 23 | def setUp(self): 24 | np.random.seed(0) 25 | self.T = np.random.randint(100, 120) 26 | self.D = np.random.randint(2, 8) 27 | self.K = np.random.randint(3, 5) 28 | self.delay = np.random.randint(0, 2) 29 | self.Y = np.random.normal(size=(self.D, self.T)) \ 30 | + 1j * np.random.normal(size=(self.D, self.T)) 31 | 32 | def test_inverse_power(self): 33 | np_inv_power = wpe.get_power_inverse(self.Y) 34 | 35 | with self.test_session() as sess: 36 | tf_signal = tf.placeholder(tf.complex128, shape=[None, None]) 37 | tf_res = tf_wpe.get_power_inverse(tf_signal) 38 | tf_inv_power = sess.run(tf_res, {tf_signal: self.Y}) 39 | 40 | np.testing.assert_allclose(np_inv_power, tf_inv_power) 41 | 42 | def test_correlations(self): 43 | np_inv_power = wpe.get_power_inverse(self.Y) 44 | np_corr = wpe.get_correlations_narrow( 45 | self.Y, np_inv_power, self.K, self.delay 46 | ) 47 | 48 | with tf.Graph().as_default(), tf.Session() as sess: 49 | tf_signal = tf.placeholder(tf.complex128, shape=[None, None]) 50 | tf_inverse_power = tf_wpe.get_power_inverse(tf_signal) 51 | tf_res = tf_wpe.get_correlations_for_single_frequency( 52 | tf_signal, tf_inverse_power, self.K, self.delay 53 | ) 54 | tf_corr = sess.run(tf_res, {tf_signal: self.Y}) 55 | 56 | np.testing.assert_allclose(np_corr[0], tf_corr[0]) 57 | np.testing.assert_allclose(np_corr[1], tf_corr[1]) 58 | 59 | @retry(5) 60 | def test_filter_matrix(self): 61 | np_inv_power = 
wpe.get_power_inverse(self.Y) 62 | np_filter_matrix = wpe.get_filter_matrix_conj_v5( 63 | self.Y, np_inv_power, self.K, self.delay 64 | ) 65 | 66 | with tf.Graph().as_default(), tf.Session() as sess: 67 | tf_signal = tf.placeholder(tf.complex128, shape=[None, None]) 68 | tf_inverse_power = tf_wpe.get_power_inverse(tf_signal) 69 | tf_matrix, tf_vector = tf_wpe.get_correlations_for_single_frequency( 70 | tf_signal, tf_inverse_power, self.K, self.delay 71 | ) 72 | tf_filter = tf_wpe.get_filter_matrix_conj( 73 | tf_signal, tf_matrix, tf_vector, 74 | self.K, self.delay 75 | ) 76 | tf_filter_matrix, tf_inv_power_2 = sess.run( 77 | [tf_filter, tf_inverse_power], {tf_signal: self.Y} 78 | ) 79 | 80 | np.testing.assert_allclose(np_inv_power, tf_inv_power_2) 81 | np.testing.assert_allclose(np_filter_matrix, tf_filter_matrix) 82 | 83 | @retry(5) 84 | def test_filter_operation(self): 85 | np_inv_power = wpe.get_power_inverse(self.Y) 86 | np_filter_matrix = wpe.get_filter_matrix_conj_v5( 87 | self.Y, np_inv_power, self.K, self.delay 88 | ) 89 | np_filter_op = wpe.perform_filter_operation_v4( 90 | self.Y, np_filter_matrix, self.K, self.delay 91 | ) 92 | 93 | with tf.Graph().as_default(), tf.Session() as sess: 94 | tf_signal = tf.placeholder(tf.complex128, shape=[None, None]) 95 | tf_inverse_power = tf_wpe.get_power_inverse(tf_signal) 96 | tf_matrix, tf_vector = tf_wpe.get_correlations_for_single_frequency( 97 | tf_signal, tf_inverse_power, self.K, self.delay 98 | ) 99 | tf_filter = tf_wpe.get_filter_matrix_conj( 100 | tf_signal, tf_matrix, tf_vector, 101 | self.K, self.delay 102 | ) 103 | tf_filter_op = tf_wpe.perform_filter_operation( 104 | tf_signal, tf_filter, self.K, self.delay 105 | ) 106 | tf_filter_op = sess.run(tf_filter_op, {tf_signal: self.Y}) 107 | 108 | np.testing.assert_allclose(np_filter_op, tf_filter_op) 109 | 110 | def test_wpe_step(self): 111 | with self.test_session() as sess: 112 | Y = tf.convert_to_tensor(self.Y[None]) 113 | enhanced, inv_power = 
tf_wpe.single_frequency_wpe( 114 | Y[0], iterations=3 115 | ) 116 | step_enhanced = tf_wpe.wpe_step(Y, inv_power[None]) 117 | enhanced, step_enhanced = sess.run( 118 | [enhanced, step_enhanced] 119 | ) 120 | np.testing.assert_allclose(enhanced, step_enhanced[0]) 121 | 122 | def _get_batch_data(self): 123 | Y = tf.convert_to_tensor(self.Y[None]) 124 | inv_power = tf_wpe.get_power_inverse(Y[0])[None] 125 | Y_short = Y[..., :self.T-20] 126 | inv_power_short = inv_power[..., :self.T-20] 127 | Y_batch = tf.stack( 128 | [Y, tf.pad(Y_short, ((0, 0), (0, 0), (0, 20)))] 129 | ) 130 | inv_power_batch = tf.stack( 131 | [inv_power, tf.pad(inv_power_short, ((0, 0), (0, 20)))] 132 | ) 133 | return Y_batch, inv_power_batch 134 | 135 | def test_batched_wpe_step(self): 136 | with self.test_session() as sess: 137 | Y_batch, inv_power_batch = self._get_batch_data() 138 | enhanced_ref_1 = tf_wpe.wpe_step( 139 | Y_batch[0], inv_power_batch[0] 140 | ) 141 | enhanced_ref_2 = tf_wpe.wpe_step( 142 | Y_batch[0, ...,:self.T-20], inv_power_batch[0, ...,:self.T-20] 143 | ) 144 | step_enhanced = tf_wpe.batched_wpe_step( 145 | Y_batch, inv_power_batch, 146 | num_frames=tf.convert_to_tensor([self.T, self.T-20]) 147 | ) 148 | enhanced, ref1, ref2 = sess.run( 149 | [step_enhanced, enhanced_ref_1, enhanced_ref_2] 150 | ) 151 | np.testing.assert_allclose(enhanced[0], ref1) 152 | np.testing.assert_allclose(enhanced[1, ..., :-20], ref2) 153 | 154 | def test_wpe(self): 155 | with self.test_session() as sess: 156 | Y = tf.convert_to_tensor(self.Y) 157 | enhanced, inv_power = tf_wpe.single_frequency_wpe( 158 | Y, iterations=1 159 | ) 160 | enhanced = sess.run(enhanced) 161 | ref = wpe.wpe_v7(self.Y, iterations=1, statistics_mode='valid') 162 | np.testing.assert_allclose(enhanced, ref) 163 | 164 | def test_batched_wpe(self): 165 | with self.test_session() as sess: 166 | Y_batch, _ = self._get_batch_data() 167 | enhanced_ref_1 = tf_wpe.wpe(Y_batch[0]) 168 | enhanced_ref_2 = tf_wpe.wpe(Y_batch[0, ..., 
:self.T-20]) 169 | step_enhanced = tf_wpe.batched_wpe( 170 | Y_batch, 171 | num_frames=tf.convert_to_tensor([self.T, self.T-20]) 172 | ) 173 | enhanced, ref1, ref2 = sess.run( 174 | [step_enhanced, enhanced_ref_1, enhanced_ref_2] 175 | ) 176 | np.testing.assert_allclose(enhanced[0], ref1) 177 | np.testing.assert_allclose(enhanced[1, ..., :-20], ref2) 178 | 179 | def test_batched_block_wpe_step(self): 180 | with self.test_session() as sess: 181 | Y_batch, inv_power_batch = self._get_batch_data() 182 | enhanced_ref_1 = tf_wpe.block_wpe_step( 183 | Y_batch[0], inv_power_batch[0] 184 | ) 185 | enhanced_ref_2 = tf_wpe.block_wpe_step( 186 | Y_batch[0, ..., :self.T-20], inv_power_batch[0, ..., :self.T-20] 187 | ) 188 | step_enhanced = tf_wpe.batched_block_wpe_step( 189 | Y_batch, inv_power_batch, 190 | num_frames=tf.convert_to_tensor([self.T, self.T-20]) 191 | ) 192 | enhanced, ref1, ref2 = sess.run( 193 | [step_enhanced, enhanced_ref_1, enhanced_ref_2] 194 | ) 195 | np.testing.assert_allclose(enhanced[0], ref1) 196 | np.testing.assert_allclose(enhanced[1, ..., :-20], ref2) 197 | 198 | @retry(5) 199 | def test_recursive_wpe(self): 200 | with self.test_session() as sess: 201 | T = 5000 202 | D = 2 203 | K = 1 204 | delay = 3 205 | Y = np.random.normal(size=(D, T)) \ 206 | + 1j * np.random.normal(size=(D, T)) 207 | Y = tf.convert_to_tensor(Y[None]) 208 | power = tf.reduce_mean(tf.real(Y) ** 2 + tf.imag(Y) ** 2, axis=1) 209 | inv_power = tf.reciprocal(power) 210 | step_enhanced = tf_wpe.wpe_step( 211 | Y, inv_power, taps=K, delay=D) 212 | recursive_enhanced = tf_wpe.recursive_wpe( 213 | tf.transpose(Y, (2, 0, 1)), 214 | tf.transpose(power), 215 | 1., 216 | taps=K, 217 | delay=D, 218 | only_use_final_filters=True 219 | ) 220 | recursive_enhanced = tf.transpose(recursive_enhanced, (1, 2, 0)) 221 | recursive_enhanced, step_enhanced = sess.run( 222 | [recursive_enhanced, step_enhanced] 223 | ) 224 | np.testing.assert_allclose( 225 | recursive_enhanced[..., -200:], 226 | 
step_enhanced[..., -200:], 227 | atol=0.01, rtol=0.2 228 | ) 229 | 230 | def test_batched_recursive_wpe(self): 231 | with self.test_session() as sess: 232 | Y_batch, inv_power_batch = self._get_batch_data() 233 | Y_batch = tf.transpose(Y_batch, (0, 3, 1, 2)) 234 | inv_power_batch = tf.transpose(inv_power_batch, (0, 2, 1)) 235 | enhanced_ref_1 = tf_wpe.recursive_wpe( 236 | Y_batch[0], inv_power_batch[0], 0.999 237 | ) 238 | enhanced_ref_2 = tf_wpe.recursive_wpe( 239 | Y_batch[0, :self.T-20], inv_power_batch[0, :self.T-20], 240 | 0.999 241 | ) 242 | step_enhanced = tf_wpe.batched_recursive_wpe( 243 | Y_batch, inv_power_batch, 0.999, 244 | num_frames=tf.convert_to_tensor([self.T, self.T-20]) 245 | ) 246 | enhanced, ref1, ref2 = sess.run( 247 | [step_enhanced, enhanced_ref_1, enhanced_ref_2] 248 | ) 249 | np.testing.assert_allclose(enhanced[0], ref1) 250 | np.testing.assert_allclose(enhanced[1, :-20], ref2) 251 | 252 | 253 | if __name__ == '__main__': 254 | if sys.version_info < (3, 7): 255 | tf.test.main() 256 | -------------------------------------------------------------------------------- /tests/test_wpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run all tests with: 3 | nosetests -w tests/ 4 | """ 5 | 6 | import unittest 7 | import numpy.testing as tc 8 | import numpy as np 9 | from nara_wpe import wpe 10 | from nara_wpe.test_utils import retry 11 | 12 | 13 | class TestWPE(unittest.TestCase): 14 | def setUp(self): 15 | self.T = np.random.randint(100, 120) 16 | self.D = np.random.randint(2, 6) 17 | self.K = np.random.randint(3, 5) 18 | self.delay = np.random.randint(1, 3) 19 | self.Y = np.random.normal(size=(self.D, self.T)) \ 20 | + 1j * np.random.normal(size=(self.D, self.T)) 21 | 22 | def test_correlations_v1_vs_v2_toy_example(self): 23 | K = 3 24 | delay = 1 25 | Y = np.asarray( 26 | [ 27 | [11, 12, 13, 14], 28 | [41, 22, 23, 24] 29 | ], dtype=np.float32 30 | ) 31 | inverse_power = wpe.get_power_inverse(Y) 32 
| R_desired, r_desired = wpe.get_correlations(Y, inverse_power, K, delay) 33 | R_actual, r_actual = wpe.get_correlations_v2(Y, inverse_power, K, delay) 34 | tc.assert_allclose(R_actual, R_desired) 35 | tc.assert_allclose(r_actual, r_desired) 36 | 37 | def test_correlations_v1_vs_v2(self): 38 | inverse_power = wpe.get_power_inverse(self.Y) 39 | R_desired, r_desired = wpe.get_correlations( 40 | self.Y, inverse_power, self.K, self.delay 41 | ) 42 | R_actual, r_actual = wpe.get_correlations_v2( 43 | self.Y, inverse_power, self.K, self.delay 44 | ) 45 | tc.assert_allclose(R_actual, R_desired) 46 | tc.assert_allclose(r_actual, r_desired) 47 | 48 | @retry(5) 49 | def test_filter_operation_v1_vs_v4(self): 50 | filter_matrix_conj = np.random.normal(size=(self.K, self.D, self.D)) \ 51 | + 1j * np.random.normal(size=(self.K, self.D, self.D)) 52 | 53 | desired = wpe.perform_filter_operation( 54 | self.Y, filter_matrix_conj, self.K, self.delay 55 | ) 56 | actual = wpe.perform_filter_operation_v4( 57 | self.Y, filter_matrix_conj, self.K, self.delay 58 | ) 59 | tc.assert_allclose(desired, actual) 60 | 61 | def test_correlations_narrow_v1_vs_v5(self): 62 | inverse_power = wpe.get_power_inverse(self.Y) 63 | R_desired, r_desired = wpe.get_correlations_narrow( 64 | self.Y, inverse_power, self.K, self.delay 65 | ) 66 | R_actual, r_actual = wpe.get_correlations_narrow_v5( 67 | self.Y, inverse_power, self.K, self.delay 68 | ) 69 | tc.assert_allclose(R_actual, R_desired) 70 | tc.assert_allclose(r_actual, r_desired) 71 | 72 | def test_correlations_narrow_v1_vs_v6(self): 73 | inverse_power = wpe.get_power_inverse(self.Y) 74 | R_desired, r_desired = wpe.get_correlations_narrow( 75 | self.Y, inverse_power, self.K, self.delay 76 | ) 77 | 78 | s = (Ellipsis, slice(self.delay + self.K - 1, None)) 79 | Y_tilde = wpe.build_y_tilde(self.Y, self.K, self.delay) 80 | R_actual, r_actual = wpe.get_correlations_v6( 81 | self.Y[s], Y_tilde[s], inverse_power[s] 82 | ) 83 | 
@retry(5)
def test_filter_matrix_conj_v1_vs_v7(self):
    """Compare the v1 filter matrix against ``get_filter_matrix_v7``.

    The v1 result is brought into the v7 layout by swapping axes 1 and 2
    and flattening the leading axes; the v7 result is conjugated before
    the comparison.
    """
    # ToDo: Fix 'wpe.get_correlations' to consider delay correctly at the
    # begin and end of the utterance.
    # delay = self.delay
    delay = 0

    inverse_power = wpe.get_power_inverse(self.Y)

    correlation_matrix, correlation_vector = wpe.get_correlations(
        self.Y, inverse_power, self.K, delay
    )
    desired = wpe.get_filter_matrix_conj(
        correlation_matrix, correlation_vector, self.K, self.D
    )

    # The full-length signals are passed on purpose; the former slice
    # object ``s`` was built but never applied, so it has been removed.
    Y_tilde = wpe.build_y_tilde(self.Y, self.K, delay)
    actual = wpe.get_filter_matrix_v7(
        self.Y, Y_tilde=Y_tilde, inverse_power=inverse_power,
    )
    tc.assert_allclose(
        actual.conj(),
        np.swapaxes(desired, 1, 2).reshape(-1, desired.shape[-1]),
        atol=1e-6
    )
139 | tc.assert_allclose( 140 | X_hat[:, delay + self.K - 1:], 141 | np.zeros_like(X_hat[:, delay + self.K - 1:]), 142 | atol=1e-6 143 | ) 144 | 145 | @retry(5) 146 | def test_wpe_v0_vs_v7(self): 147 | # ToDo: Fix 'wpe.wpe_v0' to consider delay correctly at the 148 | # begin and end of the utterance. 149 | # delay = self.delay 150 | delay = 0 151 | 152 | desired = wpe.wpe_v0(self.Y, self.K, delay, statistics_mode='full') 153 | actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='full') 154 | tc.assert_allclose(actual, desired, atol=1e-6) 155 | 156 | desired = wpe.wpe_v0(self.Y, self.K, delay, statistics_mode='valid') 157 | actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='valid') 158 | tc.assert_allclose(actual, desired, atol=1e-6) 159 | 160 | desired = wpe.wpe_v6(self.Y, self.K, delay, statistics_mode='valid') 161 | actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='full') 162 | tc.assert_raises(AssertionError, tc.assert_array_equal, desired, actual) 163 | 164 | @retry(5) 165 | def test_wpe_v6_vs_v7(self): 166 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full') 167 | actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full') 168 | tc.assert_allclose(actual, desired, atol=1e-6) 169 | 170 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid') 171 | actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='valid') 172 | tc.assert_allclose(actual, desired, atol=1e-6) 173 | 174 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid') 175 | actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full') 176 | tc.assert_raises(AssertionError, tc.assert_array_equal, desired, actual) 177 | 178 | @retry(5) 179 | def test_wpe_v8(self): 180 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid') 181 | actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid') 182 | tc.assert_allclose(actual, desired, atol=1e-6) 183 | 184 | 
desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='valid') 185 | actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid') 186 | tc.assert_allclose(actual, desired, atol=1e-6) 187 | 188 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full') 189 | actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full') 190 | tc.assert_allclose(actual, desired, atol=1e-6) 191 | 192 | desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full') 193 | actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full') 194 | tc.assert_allclose(actual, desired, atol=1e-6) 195 | 196 | @retry(5) 197 | def test_wpe_multi_freq(self): 198 | desired = wpe.wpe_v0(self.Y, self.K, self.delay, statistics_mode='full') 199 | desired = [desired, desired] 200 | actual = wpe.wpe_v0(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full') 201 | tc.assert_allclose(actual, desired, atol=1e-6) 202 | 203 | desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full') 204 | desired = [desired, desired] 205 | actual = wpe.wpe_v7(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full') 206 | tc.assert_allclose(actual, desired, atol=1e-6) 207 | 208 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full') 209 | desired = [desired, desired] 210 | actual = wpe.wpe_v6(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full') 211 | tc.assert_allclose(actual, desired, atol=1e-6) 212 | 213 | desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full') 214 | desired = [desired, desired] 215 | actual = wpe.wpe_v8(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full') 216 | tc.assert_allclose(actual, desired, atol=1e-6) 217 | 218 | @retry(5) 219 | def test_wpe_batched_multi_freq(self): 220 | def to_batched_multi_freq(x): 221 | return np.array([ 222 | [x, x*2], 223 | [x*3, x*4], 224 | [x*5, x*6], 225 | ]) 226 | Y_batched_multi_freq = 
to_batched_multi_freq(self.Y) 227 | 228 | tc.assert_raises(NotImplementedError, wpe.wpe_v0, Y_batched_multi_freq, self.K, self.delay, statistics_mode='full') 229 | 230 | desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full') 231 | desired = to_batched_multi_freq(desired) 232 | actual = wpe.wpe_v7(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full') 233 | assert desired.shape == (3, 2, self.D, self.T) 234 | assert actual.shape == (3, 2, self.D, self.T) 235 | tc.assert_allclose(actual, desired, atol=1e-6) 236 | 237 | desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full') 238 | desired = to_batched_multi_freq(desired) 239 | actual = wpe.wpe_v6(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full') 240 | assert desired.shape == (3, 2, self.D, self.T) 241 | assert actual.shape == (3, 2, self.D, self.T) 242 | tc.assert_allclose(actual, desired, atol=1e-6) 243 | 244 | desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full') 245 | desired = to_batched_multi_freq(desired) 246 | actual = wpe.wpe_v8(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full') 247 | assert desired.shape == (3, 2, self.D, self.T) 248 | assert actual.shape == (3, 2, self.D, self.T) 249 | tc.assert_allclose(actual, desired, atol=1e-6) 250 | --------------------------------------------------------------------------------