├── .bumpversion.cfg
├── .github
└── workflows
│ └── tests.yml
├── .gitignore
├── .pylintrc
├── .travis.yml
├── LICENSE
├── README.rst
├── data
├── AMI_WSJ20-Array1-1_T10c0201.wav
├── AMI_WSJ20-Array1-2_T10c0201.wav
├── AMI_WSJ20-Array1-3_T10c0201.wav
├── AMI_WSJ20-Array1-4_T10c0201.wav
├── AMI_WSJ20-Array1-5_T10c0201.wav
├── AMI_WSJ20-Array1-6_T10c0201.wav
├── AMI_WSJ20-Array1-7_T10c0201.wav
└── AMI_WSJ20-Array1-8_T10c0201.wav
├── docs
├── Makefile
├── conf.py
├── index.rst
├── make.bat
├── make_apidoc.sh
├── modules.rst
├── nara_wpe.benchmark_online_wpe.rst
├── nara_wpe.gradient_overrides.rst
├── nara_wpe.rst
├── nara_wpe.test_utils.rst
├── nara_wpe.tf_wpe.rst
├── nara_wpe.utils.rst
└── nara_wpe.wpe.rst
├── examples
├── NTT_wrapper_offline.ipynb
├── WPE_Numpy_offline.ipynb
├── WPE_Numpy_online.ipynb
├── WPE_Tensorflow_offline.ipynb
├── WPE_Tensorflow_online.ipynb
└── examples.rst
├── maintenance.md
├── nara_wpe
├── __init__.py
├── benchmark_online_wpe.py
├── ntt_wpe.py
├── test_utils.py
├── tf_wpe.py
├── torch_wpe.py
├── utils.py
└── wpe.py
├── pyproject.toml
├── pytest.ini
├── setup.py
└── tests
├── test_notebooks.py
├── test_tf_wpe.py
└── test_wpe.py
/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.0.11
3 | commit = True
4 | tag = True
5 |
6 | [bumpversion:file:setup.py]
7 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | [push, pull_request]
5 |
6 | jobs:
7 | build:
8 |
9 | runs-on: ${{ matrix.os }}
10 | strategy:
11 | matrix:
12 | python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
13 | os: [ubuntu-latest]
14 | include:
15 | - os: ubuntu-22.04
16 | python-version: '3.7'
17 | - os: macos-latest
18 | python-version: '3.12'
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 |
23 | - name: Set up Python ${{ matrix.python-version }}
24 | uses: actions/setup-python@v2
25 | with:
26 | python-version: ${{ matrix.python-version }}
27 |
28 | - name: Install linux dependencies
29 | run: |
30 | sudo apt-get update
31 | sudo apt-get install libsndfile1 sox
32 | sudo apt-get install libzmq3-dev
33 | pip install numpy pyzmq # pymatbridge needs numpy and pyzmq preinstalled (i.e. does not work to list in setup.py)
34 | # Install of current version of pymatbridge on pypi does not work, using the git-repository instead
35 | pip install git+https://github.com/arokem/python-matlab-bridge.git@master
36 | if: matrix.os != 'macos-latest'
37 | - name: Install macos dependencies
38 | run: |
39 | brew install libsndfile
40 | pip install git+https://github.com/arokem/python-matlab-bridge.git@master
41 | if: matrix.os == 'macos-latest'
42 | - name: Install nara_wpe
43 | run: |
44 | pip install -e .[test]
45 |
46 | - name: Test with pytest
47 | run: |
48 | pytest "tests/" "nara_wpe/"
49 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | *.egg-info
3 | *.pyc
4 | .idea/
5 | docs/_build
6 | docs/_templates
7 | .ipynb_checkpoints/
8 | cache
9 | .pytest_cache
10 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MESSAGES CONTROL]
2 |
3 | # Enable the message, report, category or checker with the given id(s). You can
4 | # either give multiple identifier separated by comma (,) or put this option
5 | # multiple time.
6 | #enable=
7 |
8 | # Disable the message, report, category or checker with the given id(s). You
9 | # can either give multiple identifier separated by comma (,) or put this option
10 | # multiple time (only on the command line, not in the configuration file where
11 | # it should appear only once).
12 | # C0103:Invalid variable name
13 | disable=C0103,
14 |
15 | extension-pkg-whitelist=numpy,scipy
16 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - 2.7
4 | - 3.5
5 | - 3.6
6 | - 3.7
7 | - 3.8
8 | - 3.9
9 |
10 | cache: pip
11 |
12 | install:
13 | - pip install -e .
14 | - pip install tensorflow==1.12.0
15 | - pip install coverage
16 | - pip install jupyter
17 | - pip install matplotlib
18 | - pip install scipy
19 |
20 | script:
21 | - pytest tests/
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Communications Engineering Group, Paderborn University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | ========
2 | nara_wpe
3 | ========
4 |
5 | .. image:: https://readthedocs.org/projects/nara-wpe/badge/?version=latest
6 | :target: http://nara-wpe.readthedocs.io/en/latest/
7 | :alt: Documentation Status
8 |
9 | .. image:: https://github.com/fgnt/nara_wpe/actions/workflows/tests.yml/badge.svg?branch=master
10 | :target: https://github.com/fgnt/nara_wpe/actions/workflows/tests.yml
11 | :alt: Tests
12 |
13 | .. image:: https://img.shields.io/pypi/v/nara-wpe.svg
14 | :target: https://pypi.org/project/nara-wpe/
15 | :alt: PyPI
16 |
17 | .. image:: https://img.shields.io/pypi/dm/nara-wpe.svg
18 | :target: https://pypi.org/project/nara-wpe/
19 | :alt: PyPI
20 |
21 | .. image:: https://img.shields.io/badge/license-MIT-blue.svg
22 | :target: https://raw.githubusercontent.com/fgnt/nara_wpe/master/LICENSE
23 | :alt: MIT License
24 |
25 | Weighted Prediction Error for speech dereverberation
26 | ====================================================
27 |
28 | Background noise and signal reverberation due to reflections in an enclosure are the two main impairments in acoustic
29 | signal processing and far-field speech recognition. This work addresses signal dereverberation techniques based on WPE for speech recognition and other far-field applications.
30 | WPE is a compelling algorithm to blindly dereverberate acoustic signals based on long-term linear prediction.
31 |
32 | The main algorithm is based on the following paper:
33 | Yoshioka, Takuya, and Tomohiro Nakatani. "Generalization of multi-channel linear prediction methods for blind MIMO impulse response shortening." IEEE Transactions on Audio, Speech, and Language Processing 20.10 (2012): 2707-2720.
34 |
35 |
36 | Content
37 | =======
38 |
39 | - Iterative offline WPE/ block-online WPE/ recursive frame-online WPE
40 | - All algorithms implemented both in Numpy and in TensorFlow (works with version `1.12.0`).
41 | - Continuously tested with Python 3.7, 3.8, 3.9, 3.10, 3.11 and 3.12.
42 | - Automatically built documentation: `nara-wpe.readthedocs.io <https://nara-wpe.readthedocs.io/>`_
43 | - Modular design to facilitate changes for further research
44 |
45 | Installation
46 | ============
47 |
48 | Install it directly with Pip, if you just want to use it:
49 |
50 | .. code-block:: bash
51 |
52 | pip install nara_wpe
53 |
54 | If you want to make changes or want the most recent version: Clone the repository and install it as follows:
55 |
56 | .. code-block:: bash
57 |
58 | git clone https://github.com/fgnt/nara_wpe.git
59 | cd nara_wpe
60 | pip install --editable .
61 |
62 | Check the `example notebook <https://github.com/fgnt/nara_wpe/tree/master/examples>`_ for further details.
63 | If you download the example notebook, you can listen to the input audio examples and to the dereverberated output too.
64 |
65 |
66 | Citation
67 | ========
68 |
69 | To cite this implementation, you can cite the following paper::
70 |
71 | @InProceedings{Drude2018NaraWPE,
72 | Title = {{NARA-WPE}: A Python package for weighted prediction error dereverberation in {Numpy} and {Tensorflow} for online and offline processing},
73 | Author = {Drude, Lukas and Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold},
74 | Booktitle = {13. ITG Fachtagung Sprachkommunikation (ITG 2018)},
75 | Year = {2018},
76 | Month = {Oct},
77 | }
78 |
79 |
80 | To view the paper see
81 | `IEEE Xplore `__ (`PDF `__)
82 | or for a preview see `Paderborn University RIS `__ (`PDF `__).
83 |
84 |
85 |
86 | Comparison with the NTT WPE implementation
87 | ===========================================
88 |
89 | The fairly recent Johns Hopkins University paper (Manohar, Vimal: `Acoustic Modeling for Overlapping Speech Recognition: JHU CHiME-5 Challenge System `_, ICASSP 2019) reporting on their CHiME 5 challenge results dedicates an entire table to the comparison of the Nara-WPE implementation and the NTT WPE implementation.
90 | Their result is that the Nara-WPE implementation is at least as good as the NTT WPE implementation in all their reported conditions.
91 |
92 |
93 | Development history
94 | ====================
95 |
96 | Since 2017-09-05 a TensorFlow implementation has been added to `nara_wpe`.
97 | It has been tested with a few test cases against the Numpy implementation.
98 |
99 | The first version of the Numpy implementation was written in June 2017 while
100 | Lukas Drude and Kateřina Žmolíková resided in Nara, Japan. The aim was to have
101 | a publicly available implementation of Takuya Yoshioka's 2012 paper.
102 |
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-1_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-1_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-2_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-2_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-3_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-3_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-4_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-4_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-5_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-5_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-6_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-6_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-7_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-7_T10c0201.wav
--------------------------------------------------------------------------------
/data/AMI_WSJ20-Array1-8_T10c0201.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fgnt/nara_wpe/5fbfc3639274e5a70dffcf08e20615a38c372ebb/data/AMI_WSJ20-Array1-8_T10c0201.wav
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 |
15 | # Internal variables.
16 | PAPEROPT_a4 = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 |
22 | .PHONY: help
23 | help:
24 | 	@echo "Please use \`make <target>' where <target> is one of"
25 | @echo " html to make standalone HTML files"
26 | @echo " dirhtml to make HTML files named index.html in directories"
27 | @echo " singlehtml to make a single large HTML file"
28 | @echo " pickle to make pickle files"
29 | @echo " json to make JSON files"
30 | @echo " htmlhelp to make HTML files and a HTML help project"
31 | @echo " qthelp to make HTML files and a qthelp project"
32 | @echo " applehelp to make an Apple Help Book"
33 | @echo " devhelp to make HTML files and a Devhelp project"
34 | @echo " epub to make an epub"
35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
36 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
38 | @echo " text to make text files"
39 | @echo " man to make manual pages"
40 | @echo " texinfo to make Texinfo files"
41 | @echo " info to make Texinfo files and run them through makeinfo"
42 | @echo " gettext to make PO message catalogs"
43 | @echo " changes to make an overview of all changed/added/deprecated items"
44 | @echo " xml to make Docutils-native XML files"
45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
46 | @echo " linkcheck to check all external links for integrity"
47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
48 | @echo " coverage to run coverage check of the documentation (if enabled)"
49 |
50 | .PHONY: clean
51 | clean:
52 | rm -rf $(BUILDDIR)/*
53 |
54 | .PHONY: html
55 | html:
56 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
57 | @echo
58 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
59 |
60 | .PHONY: dirhtml
61 | dirhtml:
62 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
63 | @echo
64 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
65 |
66 | .PHONY: singlehtml
67 | singlehtml:
68 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
69 | @echo
70 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
71 |
72 | .PHONY: pickle
73 | pickle:
74 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
75 | @echo
76 | @echo "Build finished; now you can process the pickle files."
77 |
78 | .PHONY: json
79 | json:
80 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
81 | @echo
82 | @echo "Build finished; now you can process the JSON files."
83 |
84 | .PHONY: htmlhelp
85 | htmlhelp:
86 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
87 | @echo
88 | @echo "Build finished; now you can run HTML Help Workshop with the" \
89 | ".hhp project file in $(BUILDDIR)/htmlhelp."
90 |
91 | .PHONY: qthelp
92 | qthelp:
93 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
94 | @echo
95 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
96 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
97 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/nara_wpe.qhcp"
98 | @echo "To view the help file:"
99 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/nara_wpe.qhc"
100 |
101 | .PHONY: applehelp
102 | applehelp:
103 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
104 | @echo
105 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
106 | @echo "N.B. You won't be able to view it unless you put it in" \
107 | "~/Library/Documentation/Help or install it in your application" \
108 | "bundle."
109 |
110 | .PHONY: devhelp
111 | devhelp:
112 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
113 | @echo
114 | @echo "Build finished."
115 | @echo "To view the help file:"
116 | @echo "# mkdir -p $$HOME/.local/share/devhelp/nara_wpe"
117 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/nara_wpe"
118 | @echo "# devhelp"
119 |
120 | .PHONY: epub
121 | epub:
122 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
123 | @echo
124 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
125 |
126 | .PHONY: latex
127 | latex:
128 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
129 | @echo
130 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
131 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
132 | "(use \`make latexpdf' here to do that automatically)."
133 |
134 | .PHONY: latexpdf
135 | latexpdf:
136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
137 | @echo "Running LaTeX files through pdflatex..."
138 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
139 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
140 |
141 | .PHONY: latexpdfja
142 | latexpdfja:
143 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
144 | @echo "Running LaTeX files through platex and dvipdfmx..."
145 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
146 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
147 |
148 | .PHONY: text
149 | text:
150 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
151 | @echo
152 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
153 |
154 | .PHONY: man
155 | man:
156 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
157 | @echo
158 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
159 |
160 | .PHONY: texinfo
161 | texinfo:
162 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
163 | @echo
164 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
165 | @echo "Run \`make' in that directory to run these through makeinfo" \
166 | "(use \`make info' here to do that automatically)."
167 |
168 | .PHONY: info
169 | info:
170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
171 | @echo "Running Texinfo files through makeinfo..."
172 | make -C $(BUILDDIR)/texinfo info
173 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
174 |
175 | .PHONY: gettext
176 | gettext:
177 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
178 | @echo
179 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
180 |
181 | .PHONY: changes
182 | changes:
183 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
184 | @echo
185 | @echo "The overview file is in $(BUILDDIR)/changes."
186 |
187 | .PHONY: linkcheck
188 | linkcheck:
189 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
190 | @echo
191 | @echo "Link check complete; look for any errors in the above output " \
192 | "or in $(BUILDDIR)/linkcheck/output.txt."
193 |
194 | .PHONY: doctest
195 | doctest:
196 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
197 | @echo "Testing of doctests in the sources finished, look at the " \
198 | "results in $(BUILDDIR)/doctest/output.txt."
199 |
200 | .PHONY: coverage
201 | coverage:
202 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
203 | @echo "Testing of coverage in the sources finished, look at the " \
204 | "results in $(BUILDDIR)/coverage/python.txt."
205 |
206 | .PHONY: xml
207 | xml:
208 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
209 | @echo
210 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
211 |
212 | .PHONY: pseudoxml
213 | pseudoxml:
214 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
215 | @echo
216 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
217 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # nara_wpe documentation build configuration file, created by
4 | # sphinx-quickstart on Thu May 24 11:59:41 2018.
5 | #
6 | # This file is execfile()d with the current directory set to its
7 | # containing dir.
8 | #
9 | # Note that not all possible configuration values are present in this
10 | # autogenerated file.
11 | #
12 | # All configuration values have a default; values that are commented out
13 | # serve to show the default.
14 |
15 | import sys
16 | import os
17 |
18 | sys.path.insert(0, os.path.abspath('../nara_wpe'))
19 |
20 | # -- General configuration ------------------------------------------------
21 |
22 | extensions = [
23 | 'sphinx.ext.autodoc',
24 | 'sphinx.ext.doctest',
25 | 'sphinx.ext.todo',
26 | 'sphinx.ext.mathjax',
27 | 'sphinx.ext.napoleon',
28 | 'sphinx.ext.viewcode',
29 | ]
30 |
31 | # Add any paths that contain templates here, relative to this directory.
32 | templates_path = ['_templates']
33 |
34 | # The suffix(es) of source filenames.
35 | source_suffix = '.rst'
36 |
37 | # The encoding of source files.
38 | #source_encoding = 'utf-8-sig'
39 |
40 | # The master toctree document.
41 | master_doc = 'index'
42 |
43 | # General information about the project.
44 | project = u'nara_wpe'
45 | copyright = u'2018, Lukas Drude'
46 | author = u'Lukas Drude'
47 |
48 | # The version info for the project you're documenting, acts as replacement for
49 | # |version| and |release|, also used in various other places throughout the
50 | # built documents.
51 | #
52 | # The short X.Y version.
53 | version = u'0.0.0'
54 | # The full version, including alpha/beta/rc tags.
55 | release = u'0.0.0'
56 |
57 | # The language for content autogenerated by Sphinx. Refer to documentation
58 | # for a list of supported languages.
59 | #
60 | # This is also used if you do content translation via gettext catalogs.
61 | # Usually you set "language" from the command line for these cases.
62 | language = None
63 |
64 | # There are two options for replacing |today|: either, you set today to some
65 | # non-false value, then it is used:
66 | #today = ''
67 | # Else, today_fmt is used as the format for a strftime call.
68 | #today_fmt = '%B %d, %Y'
69 |
70 | # List of patterns, relative to source directory, that match files and
71 | # directories to ignore when looking for source files.
72 | exclude_patterns = ['_build']
73 |
74 | autodoc_mock_imports = ['pathlib', 'nara_wpe', 'nara_wpe.benchmark_online_wpe',
75 | 'nara_wpe.gradient_overrides', 'nara_wpe.test_utils',
76 | 'nara_wpe.tf_wpe', 'nara_wpe.utils', 'nara_wpe.wpe',
77 | 'numpy', 'pandas', 'tensorflow', ]
78 |
79 |
80 | # The reST default role (used for this markup: `text`) to use for all
81 | # documents.
82 | #default_role = None
83 |
84 | # If true, '()' will be appended to :func: etc. cross-reference text.
85 | #add_function_parentheses = True
86 |
87 | # If true, the current module name will be prepended to all description
88 | # unit titles (such as .. function::).
89 | #add_module_names = True
90 |
91 | # If true, sectionauthor and moduleauthor directives will be shown in the
92 | # output. They are ignored by default.
93 | #show_authors = False
94 |
95 | # The name of the Pygments (syntax highlighting) style to use.
96 | pygments_style = 'sphinx'
97 |
98 | # A list of ignored prefixes for module index sorting.
99 | #modindex_common_prefix = []
100 |
101 | # If true, keep warnings as "system message" paragraphs in the built documents.
102 | #keep_warnings = False
103 |
104 | # If true, `todo` and `todoList` produce output, else they produce nothing.
105 | todo_include_todos = True
106 |
107 |
108 | # -- Options for HTML output ----------------------------------------------
109 |
110 | # The theme to use for HTML and HTML Help pages. See the documentation for
111 | # a list of builtin themes.
112 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
113 |
114 | if not on_rtd: # only import and set the theme if we're building docs locally
115 | import sphinx_rtd_theme
116 | html_theme = 'sphinx_rtd_theme'
117 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
118 |
119 |
120 | # Theme options are theme-specific and customize the look and feel of a theme
121 | # further. For a list of options available for each theme, see the
122 | # documentation.
123 | #html_theme_options = {}
124 |
125 | # Add any paths that contain custom themes here, relative to this directory.
126 | #html_theme_path = []
127 |
128 | # The name for this set of Sphinx documents. If None, it defaults to
129 | # "<project> v<release> documentation".
130 | #html_title = None
131 |
132 | # A shorter title for the navigation bar. Default is the same as html_title.
133 | #html_short_title = None
134 |
135 | # The name of an image file (relative to this directory) to place at the top
136 | # of the sidebar.
137 | #html_logo = None
138 |
139 | # The name of an image file (relative to this directory) to use as a favicon of
140 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
141 | # pixels large.
142 | #html_favicon = None
143 |
144 | # Add any paths that contain custom static files (such as style sheets) here,
145 | # relative to this directory. They are copied after the builtin static files,
146 | # so a file named "default.css" will overwrite the builtin "default.css".
147 | html_static_path = ['_static']
148 |
149 | # Add any extra paths that contain custom files (such as robots.txt or
150 | # .htaccess) here, relative to this directory. These files are copied
151 | # directly to the root of the documentation.
152 | #html_extra_path = []
153 |
154 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
155 | # using the given strftime format.
156 | #html_last_updated_fmt = '%b %d, %Y'
157 |
158 | # If true, SmartyPants will be used to convert quotes and dashes to
159 | # typographically correct entities.
160 | #html_use_smartypants = True
161 |
162 | # Custom sidebar templates, maps document names to template names.
163 | #html_sidebars = {}
164 |
165 | # Additional templates that should be rendered to pages, maps page names to
166 | # template names.
167 | #html_additional_pages = {}
168 |
169 | # If false, no module index is generated.
170 | #html_domain_indices = True
171 |
172 | # If false, no index is generated.
173 | #html_use_index = True
174 |
175 | # If true, the index is split into individual pages for each letter.
176 | #html_split_index = False
177 |
178 | # If true, links to the reST sources are added to the pages.
179 | #html_show_sourcelink = True
180 |
181 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
182 | #html_show_sphinx = True
183 |
184 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
185 | #html_show_copyright = True
186 |
187 | # If true, an OpenSearch description file will be output, and all pages will
188 | # contain a <link> tag referring to it. The value of this option must be the
189 | # base URL from which the finished HTML is served.
190 | #html_use_opensearch = ''
191 |
192 | # This is the file name suffix for HTML files (e.g. ".xhtml").
193 | #html_file_suffix = None
194 |
195 | # Language to be used for generating the HTML full-text search index.
196 | # Sphinx supports the following languages:
197 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
198 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr'
199 | #html_search_language = 'en'
200 |
201 | # A dictionary with options for the search language support, empty by default.
202 | # Now only 'ja' uses this config value
203 | #html_search_options = {'type': 'default'}
204 |
205 | # The name of a javascript file (relative to the configuration directory) that
206 | # implements a search results scorer. If empty, the default will be used.
207 | #html_search_scorer = 'scorer.js'
208 |
209 | # Output file base name for HTML help builder.
210 | htmlhelp_basename = 'nara_wpedoc'
211 |
# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #'preamble': '',

    # Latex figure (float) alignment
    #'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'nara_wpe.tex', u'nara\\_wpe Documentation',
     u'Lukas Drude', 'manual'),
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False

# If true, show page references after internal links.
#latex_show_pagerefs = False

# If true, show URL addresses after external links.
#latex_show_urls = False

# Documents to append as an appendix to all manuals.
#latex_appendices = []

# If false, no module index is generated.
#latex_domain_indices = True


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'nara_wpe', u'nara_wpe Documentation',
     [author], 1)
]

# If true, show URL addresses after external links.
#man_show_urls = False


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'nara_wpe', u'nara_wpe Documentation',
     author, 'nara_wpe', 'One line description of project.',
     'Miscellaneous'),
]

# Documents to append as an appendix to all manuals.
#texinfo_appendices = []

# If false, no module index is generated.
#texinfo_domain_indices = True

# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'

# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False


# Example configuration for intersphinx: refer to the Python standard library.
# NOTE: Sphinx requires named intersphinx targets; the legacy
# ``{'<url>': None}`` form is deprecated and rejected by Sphinx >= 6.
intersphinx_mapping = {'python': ('https://docs.python.org/3/', None)}
296 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. nara_wpe documentation master file, created by
2 | sphinx-quickstart on Thu May 24 11:59:41 2018.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | .. include:: ../README.rst
7 |
8 | Welcome to nara_wpe's documentation!
9 | ====================================
10 |
11 | Table of contents:
12 |
13 | .. toctree::
14 | :maxdepth: 2
15 |
16 | nara_wpe
17 |
18 |
19 |
20 | Indices and tables
21 | ==================
22 |
23 | * :ref:`genindex`
24 | * :ref:`modindex`
25 | * :ref:`search`
26 |
27 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Command file for Sphinx documentation

REM Allow the caller to override the sphinx-build executable via SPHINXBUILD.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set BUILDDIR=_build
set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
set I18NSPHINXOPTS=%SPHINXOPTS% .
REM Propagate a PAPER=a4 / PAPER=letter choice to the LaTeX builders.
if NOT "%PAPER%" == "" (
	set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
	set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
)

if "%1" == "" goto help

if "%1" == "help" (
	:help
	echo.Please use `make ^<target^>` where ^<target^> is one of
	echo.  html       to make standalone HTML files
	echo.  dirhtml    to make HTML files named index.html in directories
	echo.  singlehtml to make a single large HTML file
	echo.  pickle     to make pickle files
	echo.  json       to make JSON files
	echo.  htmlhelp   to make HTML files and a HTML help project
	echo.  qthelp     to make HTML files and a qthelp project
	echo.  devhelp    to make HTML files and a Devhelp project
	echo.  epub       to make an epub
	echo.  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter
	echo.  text       to make text files
	echo.  man        to make manual pages
	echo.  texinfo    to make Texinfo files
	echo.  gettext    to make PO message catalogs
	echo.  changes    to make an overview over all changed/added/deprecated items
	echo.  xml        to make Docutils-native XML files
	echo.  pseudoxml  to make pseudoxml-XML files for display purposes
	echo.  linkcheck  to check all external links for integrity
	echo.  doctest    to run all doctests embedded in the documentation if enabled
	echo.  coverage   to run coverage check of the documentation if enabled
	goto end
)

REM Remove all previously built output.
if "%1" == "clean" (
	for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
	del /q /s %BUILDDIR%\*
	goto end
)


REM Check if sphinx-build is available and fallback to Python version if any
%SPHINXBUILD% 1>NUL 2>NUL
if errorlevel 9009 goto sphinx_python
goto sphinx_ok

:sphinx_python

set SPHINXBUILD=python -m sphinx.__init__
%SPHINXBUILD% 2> nul
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

:sphinx_ok


REM One block per builder target below: each invokes sphinx-build with the
REM matching -b builder and reports where the output landed.
if "%1" == "html" (
	%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The HTML pages are in %BUILDDIR%/html.
	goto end
)

if "%1" == "dirhtml" (
	%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
	goto end
)

if "%1" == "singlehtml" (
	%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
	goto end
)

if "%1" == "pickle" (
	%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished; now you can process the pickle files.
	goto end
)

if "%1" == "json" (
	%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished; now you can process the JSON files.
	goto end
)

if "%1" == "htmlhelp" (
	%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished; now you can run HTML Help Workshop with the ^
.hhp project file in %BUILDDIR%/htmlhelp.
	goto end
)

if "%1" == "qthelp" (
	%SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished; now you can run "qcollectiongenerator" with the ^
.qhcp project file in %BUILDDIR%/qthelp, like this:
	echo.^> qcollectiongenerator %BUILDDIR%\qthelp\nara_wpe.qhcp
	echo.To view the help file:
	echo.^> assistant -collectionFile %BUILDDIR%\qthelp\nara_wpe.qhc
	goto end
)

if "%1" == "devhelp" (
	%SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished.
	goto end
)

if "%1" == "epub" (
	%SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The epub file is in %BUILDDIR%/epub.
	goto end
)

if "%1" == "latex" (
	%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
	goto end
)

if "%1" == "latexpdf" (
	%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
	cd %BUILDDIR%/latex
	make all-pdf
	cd %~dp0
	echo.
	echo.Build finished; the PDF files are in %BUILDDIR%/latex.
	goto end
)

if "%1" == "latexpdfja" (
	%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
	cd %BUILDDIR%/latex
	make all-pdf-ja
	cd %~dp0
	echo.
	echo.Build finished; the PDF files are in %BUILDDIR%/latex.
	goto end
)

if "%1" == "text" (
	%SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The text files are in %BUILDDIR%/text.
	goto end
)

if "%1" == "man" (
	%SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The manual pages are in %BUILDDIR%/man.
	goto end
)

if "%1" == "texinfo" (
	%SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
	goto end
)

if "%1" == "gettext" (
	%SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
	goto end
)

if "%1" == "changes" (
	%SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
	if errorlevel 1 exit /b 1
	echo.
	echo.The overview file is in %BUILDDIR%/changes.
	goto end
)

if "%1" == "linkcheck" (
	%SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
	if errorlevel 1 exit /b 1
	echo.
	echo.Link check complete; look for any errors in the above output ^
or in %BUILDDIR%/linkcheck/output.txt.
	goto end
)

if "%1" == "doctest" (
	%SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
	if errorlevel 1 exit /b 1
	echo.
	echo.Testing of doctests in the sources finished, look at the ^
results in %BUILDDIR%/doctest/output.txt.
	goto end
)

if "%1" == "coverage" (
	%SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage
	if errorlevel 1 exit /b 1
	echo.
	echo.Testing of coverage in the sources finished, look at the ^
results in %BUILDDIR%/coverage/python.txt.
	goto end
)

if "%1" == "xml" (
	%SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The XML files are in %BUILDDIR%/xml.
	goto end
)

if "%1" == "pseudoxml" (
	%SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
	if errorlevel 1 exit /b 1
	echo.
	echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
	goto end
)

:end
--------------------------------------------------------------------------------
/docs/make_apidoc.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Regenerate the per-module API documentation stubs for the nara_wpe package
# into the current directory (docs/).
#   -e        : put the documentation of each module on its own page
#   -f        : force-overwrite existing .rst stub files
#   --no-toc  : do not generate the modules.rst table-of-contents file
sphinx-apidoc -e -f --no-toc -o . ../nara_wpe
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | .. toctree::
2 | :maxdepth: 4
3 |
4 | nara_wpe
--------------------------------------------------------------------------------
/docs/nara_wpe.benchmark_online_wpe.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.benchmark\_online\_wpe module
2 | =======================================
3 |
.. automodule:: nara_wpe.benchmark_online_wpe
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/nara_wpe.gradient_overrides.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.gradient\_overrides module
2 | ====================================
3 |
.. automodule:: nara_wpe.gradient_overrides
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/nara_wpe.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe package
2 | =================
3 |
4 | Submodules
5 | ----------
6 |
7 | .. toctree::
8 |
9 | nara_wpe.benchmark_online_wpe
10 | nara_wpe.gradient_overrides
11 | nara_wpe.test_utils
12 | nara_wpe.tf_wpe
13 | nara_wpe.utils
14 | nara_wpe.wpe
15 |
16 | Module contents
17 | ---------------
18 |
19 | .. automodule:: nara_wpe
20 | :members:
21 | :undoc-members:
22 | :show-inheritance:
23 |
--------------------------------------------------------------------------------
/docs/nara_wpe.test_utils.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.test\_utils module
2 | ============================
3 |
.. automodule:: nara_wpe.test_utils
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/nara_wpe.tf_wpe.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.tf\_wpe module
2 | ========================
3 |
.. automodule:: nara_wpe.tf_wpe
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/nara_wpe.utils.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.utils module
2 | ======================
3 |
.. automodule:: nara_wpe.utils
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/nara_wpe.wpe.rst:
--------------------------------------------------------------------------------
1 | nara\_wpe.wpe module
2 | ====================
3 |
.. automodule:: nara_wpe.wpe
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/examples/NTT_wrapper_offline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import IPython\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import numpy as np\n",
16 | "import soundfile as sf\n",
17 | "import time\n",
18 | "\n",
19 | "from nara_wpe.ntt_wpe import ntt_wrapper as wpe\n",
20 | "from nara_wpe import project_root\n",
21 | "from nara_wpe.utils import stft, istft"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "# Minimal example with random data"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "def aquire_audio_data():\n",
38 | " D, T = 4, 10000\n",
39 | " y = np.random.normal(size=(D, T))\n",
40 | " return y"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": null,
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "y = aquire_audio_data()\n",
50 | "\n",
51 | "start = time.perf_counter()\n",
52 | "x = wpe(y)\n",
53 | "end = time.perf_counter()\n",
54 | "\n",
55 | "print(f\"Time: {end-start}\")"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "# Example with real audio recordings"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "WPE estimates a filter to predict the current reverberation tail frame from K time frames which lie 3 (delay) time frames in the past. This frame (reverberation tail) is then subtracted from the observed signal.\n",
70 | "\n",
71 | "### Setup"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "channels = 8\n",
81 | "sampling_rate = 16000\n",
82 | "delay = 3\n",
83 | "iterations = 5\n",
84 | "taps = 10"
85 | ]
86 | },
87 | {
88 | "cell_type": "markdown",
89 | "metadata": {},
90 | "source": [
91 | "### Audio data\n",
"Shape: (channels, frames)"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": null,
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
102 | "signal_list = [\n",
103 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
104 | " for d in range(channels)\n",
105 | "]\n",
106 | "y = np.stack(signal_list, axis=0)\n",
107 | "IPython.display.Audio(y[0], rate=sampling_rate)"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "metadata": {},
113 | "source": [
114 | "### iterative WPE\n",
"The wpe function is fed with y. The STFT and ISTFT are included in the Matlab package of NTT. "
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "x = wpe(y, iterations=iterations)\n",
125 | "IPython.display.Audio(x[0], rate=sampling_rate)"
126 | ]
127 | },
128 | {
129 | "cell_type": "markdown",
130 | "metadata": {},
131 | "source": [
132 | "## Power spectrum \n",
133 | "Before and after applying NTT WPE"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": null,
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "stft_options = dict(\n",
143 | " size=512,\n",
144 | " shift=128,\n",
145 | " window_length=None,\n",
146 | " fading=True,\n",
147 | " pad=True,\n",
148 | " symmetric_window=False\n",
149 | ")\n",
150 | "Y = stft(y, **stft_options).transpose(2, 0, 1)\n",
151 | "X = stft(x, **stft_options).transpose(2, 0, 1)\n",
152 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))\n",
153 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[ :, 0, 200:400])), origin='lower')\n",
154 | "ax1.set_xlabel('frames')\n",
155 | "_ = ax1.set_title('reverberated')\n",
156 | "im2 = ax2.imshow(20 * np.log10(np.abs(X[ :, 0, 200:400])), origin='lower', vmin=-120, vmax=0)\n",
157 | "ax2.set_xlabel('frames')\n",
158 | "_ = ax2.set_title('dereverberated')\n",
159 | "cb = fig.colorbar(im2)"
160 | ]
161 | }
162 | ],
163 | "metadata": {
164 | "kernelspec": {
165 | "display_name": "py36",
166 | "language": "python",
167 | "name": "py36"
168 | },
169 | "language_info": {
170 | "codemirror_mode": {
171 | "name": "ipython",
172 | "version": 3
173 | },
174 | "file_extension": ".py",
175 | "mimetype": "text/x-python",
176 | "name": "python",
177 | "nbconvert_exporter": "python",
178 | "pygments_lexer": "ipython3",
179 | "version": "3.6.6"
180 | }
181 | },
182 | "nbformat": 4,
183 | "nbformat_minor": 2
184 | }
185 |
--------------------------------------------------------------------------------
/examples/WPE_Numpy_offline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import IPython\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import numpy as np\n",
16 | "import soundfile as sf\n",
17 | "from tqdm import tqdm\n",
18 | "\n",
19 | "from nara_wpe.wpe import wpe\n",
20 | "from nara_wpe.wpe import get_power\n",
21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n",
22 | "from nara_wpe import project_root"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "stft_options = dict(size=512, shift=128)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "# Minimal example with random data"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "def aquire_audio_data():\n",
48 | " D, T = 4, 10000\n",
49 | " y = np.random.normal(size=(D, T))\n",
50 | " return y"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "y = aquire_audio_data()\n",
60 | "Y = stft(y, **stft_options)\n",
61 | "Y = Y.transpose(2, 0, 1)\n",
62 | "\n",
63 | "Z = wpe(Y)\n",
64 | "z_np = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "# Example with real audio recordings"
72 | ]
73 | },
74 | {
75 | "cell_type": "markdown",
76 | "metadata": {},
77 | "source": [
78 | "WPE estimates a filter to predict the current reverberation tail frame from K time frames which lie 3 (delay) time frames in the past. This frame (reverberation tail) is then subtracted from the observed signal.\n",
79 | "\n",
80 | "### Setup"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "channels = 8\n",
90 | "sampling_rate = 16000\n",
91 | "delay = 3\n",
92 | "iterations = 5\n",
93 | "taps = 10\n",
94 | "alpha=0.9999"
95 | ]
96 | },
97 | {
98 | "cell_type": "markdown",
99 | "metadata": {},
100 | "source": [
101 | "### Audio data\n",
102 | "Shape: (channels, frames)"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
112 | "signal_list = [\n",
113 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
114 | " for d in range(channels)\n",
115 | "]\n",
116 | "y = np.stack(signal_list, axis=0)\n",
117 | "IPython.display.Audio(y[0], rate=sampling_rate)"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "### STFT\n",
125 | "A STFT is performed to obtain a Numpy array with shape (frequency bins, channels, frames)."
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "Y = stft(y, **stft_options).transpose(2, 0, 1)"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "metadata": {},
140 | "source": [
141 | "### Iterative WPE\n",
142 | "The wpe function is fed with Y. Finally, an inverse STFT is performed to obtain a dereverberated result in time domain. "
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "metadata": {},
149 | "outputs": [],
150 | "source": [
151 | "Z = wpe(\n",
152 | " Y,\n",
153 | " taps=taps,\n",
154 | " delay=delay,\n",
155 | " iterations=iterations,\n",
156 | " statistics_mode='full'\n",
157 | ").transpose(1, 2, 0)\n",
158 | "z = istft(Z, size=stft_options['size'], shift=stft_options['shift'])\n",
159 | "IPython.display.Audio(z[0], rate=sampling_rate)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "## Power spectrum \n",
167 | "Before and after applying WPE"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 10))\n",
177 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[ :, 0, 200:400])), origin='lower')\n",
178 | "ax1.set_xlabel('frames')\n",
179 | "_ = ax1.set_title('reverberated')\n",
180 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z[0, 200:400, :])).T, origin='lower', vmin=-120, vmax=0)\n",
181 | "ax2.set_xlabel('frames')\n",
182 | "_ = ax2.set_title('dereverberated')\n",
183 | "cb = fig.colorbar(im2)"
184 | ]
185 | }
186 | ],
187 | "metadata": {
188 | "kernelspec": {
189 | "display_name": "Python 3",
190 | "language": "python",
191 | "name": "python3"
192 | },
193 | "language_info": {
194 | "codemirror_mode": {
195 | "name": "ipython",
196 | "version": 3
197 | },
198 | "file_extension": ".py",
199 | "mimetype": "text/x-python",
200 | "name": "python",
201 | "nbconvert_exporter": "python",
202 | "pygments_lexer": "ipython3",
203 | "version": "3.6.6"
204 | }
205 | },
206 | "nbformat": 4,
207 | "nbformat_minor": 2
208 | }
209 |
--------------------------------------------------------------------------------
/examples/WPE_Numpy_online.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import IPython\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import numpy as np\n",
16 | "import soundfile as sf\n",
17 | "import time\n",
18 | "from tqdm import tqdm\n",
19 | "\n",
20 | "from nara_wpe.wpe import online_wpe_step, get_power_online, OnlineWPE\n",
21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n",
22 | "from nara_wpe import project_root"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "stft_options = dict(size=512, shift=128)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "# Example with real audio recordings\n",
39 | "The iterations are dropped in contrast to the offline version. To use past observations the correlation matrix and the correlation vector are calculated recursively with a decaying window. $\\alpha$ is the decay factor."
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "### Setup"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "channels = 8\n",
56 | "sampling_rate = 16000\n",
57 | "delay = 3\n",
58 | "alpha=0.9999\n",
59 | "taps = 10\n",
60 | "frequency_bins = stft_options['size'] // 2 + 1"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "### Audio data"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
77 | "signal_list = [\n",
78 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
79 | " for d in range(channels)\n",
80 | "]\n",
81 | "y = np.stack(signal_list, axis=0)\n",
82 | "IPython.display.Audio(y[0], rate=sampling_rate)"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "### Online buffer\n",
90 | "For simplicity the STFT is performed before providing the frames.\n",
91 | "\n",
92 | "Shape: (frames, frequency bins, channels)\n",
93 | "\n",
"frames: taps+delay+1"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "Y = stft(y, **stft_options).transpose(1, 2, 0)\n",
104 | "T, _, _ = Y.shape\n",
105 | "\n",
106 | "def aquire_framebuffer():\n",
107 | " buffer = list(Y[:taps+delay, :, :])\n",
108 | " for t in range(taps+delay+1, T):\n",
109 | " buffer.append(Y[t, :, :])\n",
110 | " yield np.array(buffer)\n",
111 | " buffer.pop(0)"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "### Non-iterative frame online approach\n",
119 | "A frame online example requires, that certain state variables are kept from frame to frame. That is the inverse correlation matrix $\\text{R}_{t, f}^{-1}$ which is stored in Q and initialized with an identity matrix, as well as filter coefficient matrix that is stored in G and initialized with zeros. \n",
120 | "\n",
121 | "Again for simplicity the ISTFT is applied afterwards."
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": null,
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "Z_list = []\n",
131 | "Q = np.stack([np.identity(channels * taps) for a in range(frequency_bins)])\n",
132 | "G = np.zeros((frequency_bins, channels * taps, channels))\n",
133 | "\n",
134 | "for Y_step in tqdm(aquire_framebuffer()):\n",
135 | " Z, Q, G = online_wpe_step(\n",
136 | " Y_step,\n",
137 | " get_power_online(Y_step.transpose(1, 2, 0)),\n",
138 | " Q,\n",
139 | " G,\n",
140 | " alpha=alpha,\n",
141 | " taps=taps,\n",
142 | " delay=delay\n",
143 | " )\n",
144 | " Z_list.append(Z)\n",
145 | "\n",
146 | "Z_stacked = np.stack(Z_list)\n",
147 | "z = istft(np.asarray(Z_stacked).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n",
148 | "\n",
149 | "IPython.display.Audio(z[0], rate=sampling_rate)"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "metadata": {},
155 | "source": [
156 | "## Frame online WPE in class fashion:"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "metadata": {},
162 | "source": [
163 | "Online WPE class holds the correlation Matrix and the coefficient matrix. "
164 | ]
165 | },
166 | {
167 | "cell_type": "code",
168 | "execution_count": null,
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "Z_list = []\n",
173 | "online_wpe = OnlineWPE(\n",
174 | " taps=taps,\n",
175 | " delay=delay,\n",
176 | " alpha=alpha\n",
177 | ")\n",
178 | "for Y_step in tqdm(aquire_framebuffer()):\n",
179 | " Z_list.append(online_wpe.step_frame(Y_step))\n",
180 | "\n",
181 | "Z = np.stack(Z_list)\n",
182 | "z = istft(np.asarray(Z).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n",
183 | "\n",
184 | "IPython.display.Audio(z[0], rate=sampling_rate)"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "metadata": {},
190 | "source": [
191 | "# Power spectrum\n",
192 | "Before and after applying WPE."
193 | ]
194 | },
195 | {
196 | "cell_type": "code",
197 | "execution_count": null,
198 | "metadata": {},
199 | "outputs": [],
200 | "source": [
201 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n",
202 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[200:400, :, 0])).T, origin='lower')\n",
203 | "ax1.set_xlabel('')\n",
204 | "_ = ax1.set_title('reverberated')\n",
205 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z_stacked[200:400, :, 0])).T, origin='lower')\n",
206 | "_ = ax2.set_title('dereverberated')\n",
207 | "cb = fig.colorbar(im1)"
208 | ]
209 | }
210 | ],
211 | "metadata": {
212 | "kernelspec": {
213 | "display_name": "Python 3",
214 | "language": "python",
215 | "name": "python3"
216 | },
217 | "language_info": {
218 | "codemirror_mode": {
219 | "name": "ipython",
220 | "version": 3
221 | },
222 | "file_extension": ".py",
223 | "mimetype": "text/x-python",
224 | "name": "python",
225 | "nbconvert_exporter": "python",
226 | "pygments_lexer": "ipython3",
227 | "version": "3.6.6"
228 | }
229 | },
230 | "nbformat": 4,
231 | "nbformat_minor": 2
232 | }
233 |
--------------------------------------------------------------------------------
/examples/WPE_Tensorflow_offline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import IPython\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import numpy as np\n",
16 | "import soundfile as sf\n",
17 | "from tqdm import tqdm\n",
18 | "import tensorflow as tf\n",
19 | "\n",
20 | "from nara_wpe.tf_wpe import wpe\n",
21 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n",
22 | "from nara_wpe import project_root"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": null,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "stft_options = dict(size=512, shift=128)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "# Minimal example with random data"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "def aquire_audio_data():\n",
48 | " D, T = 4, 10000\n",
49 | " y = np.random.normal(size=(D, T))\n",
50 | " return y"
51 | ]
52 | },
53 | {
54 | "cell_type": "code",
55 | "execution_count": null,
56 | "metadata": {
57 | "scrolled": false
58 | },
59 | "outputs": [],
60 | "source": [
61 | "y = aquire_audio_data()\n",
62 | "Y = stft(y, **stft_options).transpose(2, 0, 1)\n",
63 | "with tf.Session() as session:\n",
64 | " Y_tf = tf.placeholder(\n",
65 | " tf.complex128, shape=(None, None, None))\n",
66 | " Z_tf = wpe(Y_tf)\n",
67 | " Z = session.run(Z_tf, {Y_tf: Y})\n",
68 | "z_tf = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])"
69 | ]
70 | },
71 | {
72 | "cell_type": "markdown",
73 | "metadata": {},
74 | "source": [
75 | "# Example with real audio recordings\n",
76 | "WPE estimates a filter to predict the current reverberation tail frame from K time frames which lie 3 (delay) time frames in the past. This frame (reverberation tail) is then subtracted from the observed signal."
77 | ]
78 | },
79 | {
80 | "cell_type": "markdown",
81 | "metadata": {},
82 | "source": [
83 | "### Setup"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": null,
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "channels = 8\n",
93 | "sampling_rate = 16000\n",
94 | "delay = 3\n",
95 | "iterations = 5\n",
96 | "taps = 10"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "### Audio data\n",
104 | "Shape: (frames, channels)"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
114 | "signal_list = [\n",
115 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
116 | " for d in range(channels)\n",
117 | "]\n",
118 | "y = np.stack(signal_list, axis=0)\n",
119 | "IPython.display.Audio(y[0], rate=sampling_rate)"
120 | ]
121 | },
122 | {
123 | "cell_type": "markdown",
124 | "metadata": {},
125 | "source": [
126 | "### STFT\n",
127 | "For simplicity reasons we calculate the STFT in Numpy and provide the result in form of the Tensorflow feed dict."
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": null,
133 | "metadata": {},
134 | "outputs": [],
135 | "source": [
136 | "Y = stft(y, **stft_options).transpose(2, 0, 1)"
137 | ]
138 | },
139 | {
140 | "cell_type": "markdown",
141 | "metadata": {},
142 | "source": [
143 | "### iterative WPE\n",
144 | "A placeholder for Y is declared. The wpe function is fed with Y via the Tensorflow feed dict. Finally, an inverse STFT in Numpy is performed to obtain a dereverberated result in time domain."
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": null,
150 | "metadata": {},
151 | "outputs": [],
152 | "source": [
153 | "from nara_wpe.tf_wpe import get_power\n",
154 | "with tf.Session() as session:\n",
155 | " Y_tf = tf.placeholder(tf.complex128, shape=(None, None, None))\n",
156 | " Z_tf = wpe(Y_tf, taps=taps, iterations=iterations)\n",
157 | " Z = session.run(Z_tf, {Y_tf: Y})\n",
158 | "z = istft(Z.transpose(1, 2, 0), size=stft_options['size'], shift=stft_options['shift'])\n",
159 | "IPython.display.Audio(z[0], rate=sampling_rate)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "metadata": {},
165 | "source": [
166 | "# Power spectrum\n",
167 | "Before and after applying WPE"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": [
176 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n",
177 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[:, 0, 200:400])), origin='lower')\n",
178 | "ax1.set_xlabel('')\n",
179 | "_ = ax1.set_title('reverberated')\n",
180 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z[:, 0, 200:400])), origin='lower')\n",
181 | "_ = ax2.set_title('dereverberated')\n",
182 | "cb = fig.colorbar(im1)"
183 | ]
184 | }
185 | ],
186 | "metadata": {
187 | "kernelspec": {
188 | "display_name": "Python 3",
189 | "language": "python",
190 | "name": "python3"
191 | },
192 | "language_info": {
193 | "codemirror_mode": {
194 | "name": "ipython",
195 | "version": 3
196 | },
197 | "file_extension": ".py",
198 | "mimetype": "text/x-python",
199 | "name": "python",
200 | "nbconvert_exporter": "python",
201 | "pygments_lexer": "ipython3",
202 | "version": "3.6.6"
203 | }
204 | },
205 | "nbformat": 4,
206 | "nbformat_minor": 2
207 | }
208 |
--------------------------------------------------------------------------------
/examples/WPE_Tensorflow_online.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%reload_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "%matplotlib inline\n",
12 | "\n",
13 | "import IPython\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import numpy as np\n",
16 | "import soundfile as sf\n",
17 | "import time\n",
18 | "from tqdm import tqdm\n",
19 | "import tensorflow as tf\n",
20 | "\n",
21 | "from nara_wpe.tf_wpe import wpe\n",
22 | "from nara_wpe.tf_wpe import online_wpe_step, get_power_online\n",
23 | "from nara_wpe.utils import stft, istft, get_stft_center_frequencies\n",
24 | "from nara_wpe import project_root"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "stft_options = dict(\n",
34 | " size=512,\n",
35 | " shift=128,\n",
36 | " window_length=None,\n",
37 | " fading=True,\n",
38 | " pad=True,\n",
39 | " symmetric_window=False\n",
40 | ")"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "# Example with real audio recordings\n",
48 | "The iterations are dropped in contrast to the offline version. To use past observations the correlation matrix and the correlation vector are calculated recursively with a decaying window. $\\alpha$ is the decay factor."
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "### Setup"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "channels = 8\n",
65 | "sampling_rate = 16000\n",
66 | "delay = 3\n",
67 | "alpha=0.99\n",
68 | "taps = 10\n",
69 | "frequency_bins = stft_options['size'] // 2 + 1"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "### Audio data"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "file_template = 'AMI_WSJ20-Array1-{}_T10c0201.wav'\n",
86 | "signal_list = [\n",
87 | " sf.read(str(project_root / 'data' / file_template.format(d + 1)))[0]\n",
88 | " for d in range(channels)\n",
89 | "]\n",
90 | "y = np.stack(signal_list, axis=0)\n",
91 | "IPython.display.Audio(y[0], rate=sampling_rate)"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "metadata": {},
97 | "source": [
98 | "### Online buffer\n",
99 | "For simplicity the STFT is performed before providing the frames.\n",
100 | "\n",
101 | "Shape: (frames, frequency bins, channels)\n",
102 | "\n",
103 | "frames: K+delay+1"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "Y = stft(y, **stft_options).transpose(1, 2, 0)\n",
113 | "T, _, _ = Y.shape\n",
114 | "\n",
115 | "def aquire_framebuffer():\n",
116 | " buffer = list(Y[:taps+delay, :, :])\n",
117 | " for t in range(taps+delay+1, T):\n",
118 | " buffer.append(Y[t, :, :])\n",
119 | " yield np.array(buffer)\n",
120 | " buffer.pop(0)"
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "metadata": {},
126 | "source": [
127 | "### Non-iterative frame online approach\n",
128 | "A frame online example requires, that certain state variables are kept from frame to frame. That is the inverse correlation matrix $\\text{R}_{t, f}^{-1}$ which is stored in Q and initialized with an identity matrix, as well as filter coefficient matrix that is stored in G and initialized with zeros. \n",
129 | "\n",
130 | "Again for simplicity the ISTFT is applied in Numpy afterwards."
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": null,
136 | "metadata": {},
137 | "outputs": [],
138 | "source": [
139 | "Z_list = []\n",
140 | "\n",
141 | "Q = np.stack([np.identity(channels * taps) for a in range(frequency_bins)])\n",
142 | "G = np.zeros((frequency_bins, channels * taps, channels))\n",
143 | "\n",
144 | "with tf.Session() as session:\n",
145 | " Y_tf = tf.placeholder(tf.complex128, shape=(taps + delay + 1, frequency_bins, channels))\n",
146 | " Q_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels * taps))\n",
147 | " G_tf = tf.placeholder(tf.complex128, shape=(frequency_bins, channels * taps, channels))\n",
148 | " \n",
149 | " results = online_wpe_step(Y_tf, get_power_online(tf.transpose(Y_tf, (1, 0, 2))), Q_tf, G_tf, alpha=alpha, taps=taps, delay=delay)\n",
150 | " for Y_step in tqdm(aquire_framebuffer()):\n",
151 | " feed_dict = {Y_tf: Y_step, Q_tf: Q, G_tf: G}\n",
152 | " Z, Q, G = session.run(results, feed_dict)\n",
153 | " Z_list.append(Z)\n",
154 | "\n",
155 | "Z_stacked = np.stack(Z_list)\n",
156 | "z = istft(np.asarray(Z_stacked).transpose(2, 0, 1), size=stft_options['size'], shift=stft_options['shift'])\n",
157 | "\n",
158 | "IPython.display.Audio(z[0], rate=sampling_rate)"
159 | ]
160 | },
161 | {
162 | "cell_type": "markdown",
163 | "metadata": {},
164 | "source": [
165 | "# Power spectrum\n",
166 | "Before and after applying WPE."
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(20, 8))\n",
176 | "im1 = ax1.imshow(20 * np.log10(np.abs(Y[200:400, :, 0])).T, origin='lower')\n",
177 | "ax1.set_xlabel('')\n",
178 | "_ = ax1.set_title('reverberated')\n",
179 | "im2 = ax2.imshow(20 * np.log10(np.abs(Z_stacked[200:400, :, 0])).T, origin='lower')\n",
180 | "_ = ax2.set_title('dereverberated')\n",
181 | "cb = fig.colorbar(im1)"
182 | ]
183 | }
184 | ],
185 | "metadata": {
186 | "kernelspec": {
187 | "display_name": "Python 3",
188 | "language": "python",
189 | "name": "python3"
190 | },
191 | "language_info": {
192 | "codemirror_mode": {
193 | "name": "ipython",
194 | "version": 3
195 | },
196 | "file_extension": ".py",
197 | "mimetype": "text/x-python",
198 | "name": "python",
199 | "nbconvert_exporter": "python",
200 | "pygments_lexer": "ipython3",
201 | "version": "3.6.6"
202 | }
203 | },
204 | "nbformat": 4,
205 | "nbformat_minor": 2
206 | }
207 |
--------------------------------------------------------------------------------
/examples/examples.rst:
--------------------------------------------------------------------------------
1 |
2 | Examples
3 | ========
4 |
5 | Small IPython notebooks with examples in Numpy and Tensorflow. The examples
6 | are taken from the following paper::
7 |
8 | @InProceedings{Drude2018NaraWPE,
9 | Title = {NARA-WPE: A Python package for weighted prediction error dereverberation in Numpy and Tensorflow for online and offline processing},
10 | Author = {Drude, Lukas and Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold},
11 | Booktitle = {13. ITG Fachtagung Sprachkommunikation (ITG 2018)},
12 | Year = {2018},
13 | Month = {Oct},
14 | }
15 |
--------------------------------------------------------------------------------
/maintenance.md:
--------------------------------------------------------------------------------
1 |
2 | # PyPi upload
3 |
4 | How to package a Python package and bump its version. See: https://packaging.python.org/tutorials/packaging-projects/
5 |
6 | 1. Update `setup.py` to new version number
7 | 2. Commit this change
8 | 3. Tag and upload
9 |
10 | ## Install dependencies:
11 | ```bash
12 | pip install --upgrade setuptools
13 | pip install --upgrade wheel
14 | pip install --upgrade twine
15 | # pip install --upgrade bleach html5lib # some versions do not work
16 | pip install --upgrade bump2version
17 | ```
18 |
19 | `bump2version` takes care to increase the version number, create the commit and tag.
20 |
21 | ```bash
22 | # rm -rf dist/*
23 | bump2version --verbose --tag patch # major, minor or patch
24 | python setup.py sdist bdist_wheel
25 | twine upload --repository testpypi dist/*
26 | git push origin --tags
27 | git push
28 | twine upload dist/*
29 | ```
30 |
--------------------------------------------------------------------------------
/nara_wpe/__init__.py:
--------------------------------------------------------------------------------
try:
    import pathlib
except ImportError:
    # Python 2.7 fallback: use the backported pathlib2 package.
    import pathlib2 as pathlib

import os

# Absolute path to the repository root, i.e. the parent directory of this
# package directory (used e.g. to locate the bundled example audio data).
_package_dir = os.path.dirname(__file__)
project_root = pathlib.Path(os.path.abspath(os.path.join(_package_dir, os.pardir)))

name = "nara_wpe"
/nara_wpe/benchmark_online_wpe.py:
--------------------------------------------------------------------------------
1 | # ToDo: move this file to tests
2 |
3 | import sys
4 | from itertools import product
5 |
6 | import pandas as pd
7 | if sys.version_info < (3, 7):
8 | import tensorflow as tf
9 |
10 | from nara_wpe import tf_wpe
11 |
if sys.version_info < (3, 7):
    # Module-level benchmark state, only created on Python < 3.7 because the
    # tensorflow import above is guarded the same way.
    benchmark = tf.test.Benchmark()  # tf benchmarking helper used in __main__
    configs = []  # benchmark configuration dicts, filled in __main__
    delay = 1  # prediction delay (frames) used for every configuration
16 |
17 |
def config_iterator():
    """Yield every benchmark configuration as a tuple.

    Tuple order: (repetition, K, num_mics, frame_size, dtype, device).
    """
    repetitions = range(1, 11)
    filter_taps = [5, 10]  # K
    microphone_counts = [2, 4, 6]  # could be extended to range(2, 11)
    frame_sizes = [512]  # e.g. also 1024
    dtypes = [tf.complex64]  # optionally also tf.complex128
    devices = ['/cpu:0']  # optionally '/gpu:0'
    return product(
        repetitions,
        filter_taps,
        microphone_counts,
        frame_sizes,
        dtypes,
        devices,
    )
27 |
28 |
if __name__ == '__main__' and sys.version_info < (3, 7):
    # Build one online-WPE op per configuration. All state tensors are
    # initialized to neutral values (identity inverse covariance, zero
    # filter taps, zero input buffer, unit power estimate).
    print('Generating configs...')
    for repetition, K, num_mics, frame_size, dtype, device in config_iterator():
        inv_cov_tm1 = tf.eye(
            num_mics * K, batch_shape=[frame_size // 2 + 1], dtype=tf.complex64)
        filter_taps_tm1 = tf.zeros(
            (frame_size // 2 + 1, num_mics * K, num_mics), dtype=tf.complex64)
        input_buffer = tf.zeros(
            (K + delay + 1, frame_size // 2 + 1, num_mics), dtype=tf.complex64)
        power_estimate = tf.ones((frame_size // 2 + 1,), dtype=tf.complex64)
        with tf.device(device):
            configs.append(dict(
                repetition=repetition,
                K=K,
                num_mics=num_mics,
                frame_size=frame_size,
                dtype=dtype,
                device=device,
                op=tf_wpe.online_wpe_step(
                    input_buffer, power_estimate, inv_cov_tm1, filter_taps_tm1,
                    0.9999, K, delay)
            ))

    # Run each op with tf's benchmark helper and collect timing statistics.
    print('Benchmarking...')
    results = []
    with tf.Session() as sess:
        for cfg in configs:
            print(cfg)
            result = benchmark.run_op_benchmark(
                sess,
                cfg['op'],
                min_iters=100
            )
            # Attach the configuration to the timing result for grouping.
            result['repetition'] = cfg['repetition']
            result['K'] = cfg['K']
            result['num_mics'] = cfg['num_mics']
            result['frame_size'] = cfg['frame_size']
            result['device'] = cfg['device']
            result['dtype'] = cfg['dtype'].name
            result['fps'] = 1 / result['wall_time']
            # NOTE(review): the factor 4 presumably reflects the STFT
            # overlap (shift = frame_size / 4) -- TODO confirm.
            result['real_time_factor'] = (
                (16000 / result['frame_size']) * 4 / result['fps']
            )
            results.append(result)

    # Aggregate over repetitions and persist raw results as JSON.
    res = pd.DataFrame(results)
    print(res.groupby(['K', 'num_mics', 'frame_size', 'device', 'dtype']).mean())

    with open('online_wpe_results.json', 'w') as fid:
        res.to_json(fid)
79 |
--------------------------------------------------------------------------------
/nara_wpe/ntt_wpe.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from cached_property import cached_property
3 |
4 | import tempfile
5 | import numpy as np
6 | import soundfile as sf
7 | import click
8 | from pymatbridge import Matlab
9 |
10 | from nara_wpe import project_root
11 |
12 |
def ntt_wrapper(
        y,
        taps=10,
        delay=3,
        iterations=3,
        sampling_rate=16000,
        path_to_package=project_root / 'cache' / 'wpe_v1.33',
        stft_size=512,
        stft_shift=128
):
    """Convenience function: construct an ``NTTWrapper`` and apply it to ``y``."""
    dereverberator = NTTWrapper(path_to_package)
    return dereverberator(
        y=y,
        taps=taps,
        delay=delay,
        iterations=iterations,
        sampling_rate=sampling_rate,
        stft_size=stft_size,
        stft_shift=stft_shift,
    )
33 |
34 |
class NTTWrapper:
    """Wrapper around the original NTT MATLAB WPE implementation.

    The WPE package has to be downloaded from
    http://www.kecl.ntt.co.jp/icl/signal/wpe/download.html. It is recommended
    to store it in the cache directory of Nara-WPE.

    The wrapper talks to MATLAB via ``pymatbridge`` and rewrites the
    package's ``settings/local.m`` so it matches the requested parameters.
    """
    def __init__(self, path_to_pkg):
        # Directory that contains the extracted NTT WPE package.
        self.path_to_pkg = Path(path_to_pkg)

        if not self.path_to_pkg.exists():
            raise OSError(
                'NTT WPE package does not exist. It has to be downloaded'
                'from http://www.kecl.ntt.co.jp/icl/signal/wpe/download.html'
                'and stored in the cache directory of Nara-WPE, preferably.'
            )

    @cached_property
    def process(self):
        # Lazily started MATLAB session; cached_property ensures repeated
        # calls reuse the same interpreter process.
        mlab = Matlab()
        mlab.start()
        return mlab

    def cfg(self, channels, sampling_rate, iterations, taps,
            stft_size, stft_shift
    ):
        """
        Check settings and set local.m accordingly.

        Reads the package's ``settings/local.m`` line by line and rewrites
        the lines configuring microphone count, sampling rate, filter taps,
        iterations and STFT parameters whenever they do not already contain
        the requested values.

        NOTE(review): the matching is purely substring based (e.g.
        ``'fs' in line``) and depends on the exact layout of the shipped
        local.m; 'analym_param' looks like a typo for 'analy_param' --
        verify against the downloaded package.

        Returns:
            list: the rewritten lines of local.m.
        """
        cfg = self.path_to_pkg / 'settings' / 'local.m'
        lines = []
        with cfg.open() as infile:
            for line in infile:
                if 'num_mic = ' in line and 'num_out' not in line:
                    if not str(channels) in line:
                        line = 'num_mic = ' + str(channels) + ";\n"
                elif 'fs' in line:
                    if not str(sampling_rate) in line:
                        line = 'fs =' + str(sampling_rate) + ";\n"
                elif 'channel_setup' in line and 'ssd_param' not in line:
                    if not str(taps) in line and '%' not in line:
                        line = "channel_setup = [" + str(taps) + "; ..." + "\n"
                elif 'ssd_conf' in line:
                    if not str(iterations) in line:
                        line = "ssd_conf = struct('max_iter',"\
                            + str(iterations) + ", ...\n"
                elif 'analym_param' in line:
                    if not str(stft_size) in line:
                        line = "analy_param = struct('win_size',"\
                            + str(stft_size) + ", ..."
                elif 'shift_size' in line:
                    if not str(stft_shift) in line:
                        line = "              'shift_size',"\
                            + str(stft_shift) + ", ..."
                elif 'hanning' in line:
                    if not str(stft_size) in line:
                        line = "              'win' , hanning("\
                            + str(stft_size) + "));"
                lines.append(line)
        return lines

    def __call__(
        self,
        y,
        taps=10,
        delay=3,
        iterations=3,
        sampling_rate=16000,
        stft_size=512,
        stft_shift=128
    ):
        """Run NTT WPE dereverberation on a multi-channel signal.

        Args:
            y: observation (channels, samples)
            taps: number of filter taps
            delay: prediction delay in frames
            iterations: number of WPE iterations
            sampling_rate: sampling rate in Hz, written into local.m
            stft_size: STFT window size, written into local.m
            stft_shift: STFT shift, written into local.m

        Returns: dereverberated observation (channels, samples)

        """

        # MATLAB side expects (samples, channels).
        y = y.transpose(1, 0)
        channels = y.shape[1]
        cfg_lines = self.cfg(
            channels, sampling_rate, iterations, taps, stft_size, stft_shift
        )

        with tempfile.TemporaryDirectory() as tempdir:
            # Write the patched local.m into a temporary directory.
            with (Path(tempdir) / 'local.m').open('w') as cfg_file:
                for line in cfg_lines:
                    cfg_file.write(line)

            self.process.set_variable("y", y)
            self.process.set_variable("cfg", cfg_file.name)

            # NOTE(review): cfg_file.name is the *file* path of local.m;
            # MATLAB's addpath expects a directory -- verify this works
            # with the used pymatbridge/MATLAB versions.
            self.process.run_code("addpath('" + str(cfg_file.name) + "');")
            self.process.run_code("addpath('" + str(self.path_to_pkg) + "');")

            msg = self.process.run_code("y = wpe(y, cfg);")
            assert msg['success'] is True, \
                f'WPE has failed. {msg["content"]["stdout"]}'

            y = self.process.get_variable("y")

        # Back to (channels, samples).
        return y.transpose(1, 0)
143 |
144 |
@click.command()
@click.argument(
    'files', nargs=-1,
    type=click.Path(exists=True),
)
@click.option(
    '--path_to_pkg',
    default=str(project_root / 'cache' / 'wpe_v1.33'),
    help='It is recommended to save the '
         'NTT-WPE package in the cache directory.'
)
@click.option(
    '--output_dir',
    default=str(project_root / 'data' / 'dereverberation_ntt'),
    help='Output path.'
)
@click.option(
    '--iterations',
    default=5,
    help='Iterations of WPE'
)
@click.option(
    '--taps',
    default=10,
    help='Number of filter taps of WPE'
)
def main(path_to_pkg, files, output_dir, taps=10, delay=3, iterations=5):
    """
    A small command line wrapper around the NTT-WPE matlab file.
    http://www.kecl.ntt.co.jp/icl/signal/wpe/

    Multiple FILES are interpreted as one multi-channel recording
    (one file per channel); a single file is processed as-is.
    """

    if len(files) > 1:
        # One file per channel: stack to (channels, samples).
        signal_list = [
            sf.read(str(file))[0]
            for file in files
        ]
        y = np.stack(signal_list, axis=0)
        sampling_rate = sf.read(str(files[0]))[1]
    else:
        # BUGFIX: `files` is a tuple (click nargs=-1); read the single
        # file it contains instead of passing the tuple to sf.read.
        y, sampling_rate = sf.read(str(files[0]))

    wrapper = NTTWrapper(path_to_pkg)
    # BUGFIX: pass arguments by keyword. The previous positional call
    # handed `delay` to `taps`, `iterations` to `delay` and `taps` to
    # `iterations`, because the signature is (y, taps, delay, iterations,
    # sampling_rate, ...).
    x = wrapper(
        y,
        taps=taps,
        delay=delay,
        iterations=iterations,
        sampling_rate=sampling_rate,
        stft_size=512,
        stft_shift=128,
    )

    if len(files) > 1:
        # Write one dereverberated file per input channel.
        for i, file in enumerate(files):
            sf.write(
                str(Path(output_dir) / Path(file).name),
                x[i],
                samplerate=sampling_rate
            )
    else:
        # BUGFIX: use files[0]; Path() cannot take a tuple.
        sf.write(
            str(Path(output_dir) / Path(files[0]).name),
            x,
            samplerate=sampling_rate
        )
209 |
--------------------------------------------------------------------------------
/nara_wpe/test_utils.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import unittest
3 |
4 | import six
5 |
6 |
class QuietTestRunner(object):
    """Minimal test runner that produces no console output.

    Unlike ``unittest.TextTestRunner`` it simply executes the suite and
    returns the populated ``unittest.TestResult``.
    """

    def run(self, suite):
        # Calling a suite with a result object runs all contained tests.
        outcome = unittest.TestResult()
        suite(outcome)
        return outcome
13 |
14 |
def repeat_with_success_at_least(times, min_success):
    """Decorator for multiple trial of the test case.

    The decorated test case is launched multiple times.
    The case is judged as passed at least specified number of trials.
    If the number of successful trials exceeds `min_success`,
    the remaining trials are skipped.

    Args:
        times(int): The number of trials.
        min_success(int): Threshold that the decorated test
            case is regarded as passed.

    """

    assert times >= min_success

    def _repeat_with_success_at_least(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # The decorated function must be a bound test method, so the
            # first positional argument is the TestCase instance.
            assert len(args) > 0
            instance = args[0]
            assert isinstance(instance, unittest.TestCase)
            success_counter = 0
            failure_counter = 0
            results = []

            def fail():
                # Report the overall tally and, if available, the first
                # recorded failure/error message.
                msg = '\nFail: {0}, Success: {1}'.format(
                    failure_counter, success_counter)
                if len(results) > 0:
                    first = results[0]
                    errs = first.failures + first.errors
                    if len(errs) > 0:
                        err_msg = '\n'.join(fail[1] for fail in errs)
                        msg += '\n\nThe first error message:\n' + err_msg
                instance.fail(msg)

            # FIX: use the built-in `range` instead of `six.moves.range`;
            # the package only supports Python 3, so `six` is unnecessary.
            for _ in range(times):
                suite = unittest.TestSuite()
                # Create new instance to call the setup and the teardown only
                # once.
                ins = type(instance)(instance._testMethodName)
                suite.addTest(
                    unittest.FunctionTestCase(
                        lambda: f(ins, *args[1:], **kwargs),
                        setUp=ins.setUp,
                        tearDown=ins.tearDown))

                result = QuietTestRunner().run(suite)
                if result.wasSuccessful():
                    success_counter += 1
                else:
                    results.append(result)
                    failure_counter += 1
                # Early exit: enough successes -> pass immediately.
                if success_counter >= min_success:
                    instance.assertTrue(True)
                    return
                # Early exit: passing is no longer achievable -> fail.
                if failure_counter > times - min_success:
                    fail()
                    return
            fail()
        return wrapper
    return _repeat_with_success_at_least
79 |
80 |
def retry(times):
    """Decorator that requires at least one successful trial.

    The decorated test case is executed up to `times` times and regarded
    as passed as soon as a single run succeeds.

    .. note::
        In current implementation, this decorator grasps the
        failure information of each trial.

    Args:
        times(int): The number of trials.
    """
    # A single success out of `times` attempts is sufficient.
    return repeat_with_success_at_least(times, min_success=1)
96 |
--------------------------------------------------------------------------------
/nara_wpe/tf_wpe.py:
--------------------------------------------------------------------------------
1 | try:
2 | import tensorflow as tf
3 | from tensorflow.contrib import signal as tf_signal
4 | except ModuleNotFoundError:
5 | import warnings
6 | # For doctests, each file will be imported
7 | warnings.warn(
8 | 'Could not import tensorflow, hence tensorflow code in nara_wpe will fail.',
9 | )
10 |
11 |
def _batch_wrapper(inner_function, signals, num_frames, time_axis=-1):
    """Helper function to support batching with signal lengths respected.

    Each batch entry is cut to its own valid length, processed by
    `inner_function` and zero padded back to the maximum length so the
    results can be stacked into one tensor again.

    Args:
        inner_function (function): A function taking a list of the cut
            signals (batch dimension removed) and returning the enhanced
            signal for a single batch entry.
        signals (tuple): Signals needed for the function. Observation must be
            in the first place. All signals must have shape (batch, ..., time)
        num_frames (array): Number of frames for each batch
        time_axis (int): Axis of the time dimension in the batched signals.

    Returns:
        tf.Tensor: Zero padded output of the function.
    """

    max_frames = tf.reduce_max(num_frames)

    # If we remove the batch dimension the time axis shifts by -1 if positive
    if time_axis > 0:
        time_axis -= 1

    def _single_batch(inp):
        # `inp` is [list_of_signals, frames] for one batch entry.
        frames = inp[-1]
        inp = inp[0]
        with tf.name_scope('single_batch'):

            def _pad(x):
                # Zero pad the time axis back up to `max_frames`.
                padding = max_frames - \
                    tf.minimum(frames, tf.shape(x)[time_axis])
                zeros = tf.cast(tf.zeros(()), x.dtype)
                paddings = x.shape.ndims * [(0, 0), ]
                paddings[time_axis] = (0, padding)
                return tf.pad(
                    x,
                    paddings,
                    constant_values=zeros
                )

            def _slice(x):
                # Cut the time axis down to this entry's valid length.
                slices = x.shape.ndims * [slice(None), ]
                slices[time_axis] = slice(frames)
                return x[slices]

            enhanced = inner_function(
                [_slice(i) for i in inp]
            )
            return _pad(enhanced)

    out = tf.map_fn(
        _single_batch, [signals, num_frames], dtype=signals[0].dtype
    )
    out.set_shape(signals[0].shape)
    return out
64 |
65 |
def get_power_online(signal):
    """Calculates power for `signal`

    Args:
        signal (tf.Tensor): Signal with shape (F, D, T).

    Returns:
        tf.Tensor: Power with shape (F,)

    """
    # Average the per-channel power over the remaining (time) axis.
    return tf.reduce_mean(get_power(signal), axis=-1)
79 |
80 |
def get_power_inverse(signal):
    """Calculates inverse power for `signal`

    Args:
        signal (tf.Tensor): Single frequency signal with shape (D, T).

    Returns:
        tf.Tensor: Inverse power with shape (T,)

    """
    power = get_power(signal)
    # Floor the power at a tiny fraction of its maximum before inverting
    # to avoid division by (near) zero.
    floor = 1e-10 * tf.reduce_max(power)
    return tf.reciprocal(tf.maximum(power, floor))
95 |
96 |
def get_power(signal, axis=-2):
    """Calculates power for `signal`

    Args:
        signal (tf.Tensor): Single frequency signal with shape (D, T) or (F, D, T).
        axis: reduce_mean axis
    Returns:
        tf.Tensor: Power with shape (T,) or (F, T)

    """
    # |signal|^2 from real and imaginary parts, averaged over the
    # requested (channel) axis.
    magnitude_squared = tf.real(signal) ** 2 + tf.imag(signal) ** 2
    return tf.reduce_mean(magnitude_squared, axis=axis)
111 |
112 |
113 | #def get_power(signal, psd_context=0):
114 | # """
115 | # Calculates power for single frequency signal.
116 | # In case psd_context is an tuple the two values
117 | # are describing the left and right hand context.
118 | #
119 | # Args:
120 | # signal: (D, T)
121 | # psd_context: tuple or int
122 | # """
123 | # shape = tf.shape(signal)
124 | # if len(signal.get_shape()) == 2:
125 | # signal = tf.reshape(signal, (1, shape[0], shape[1]))
126 | #
127 | # power = tf.reduce_mean(
128 | # tf.real(signal) ** 2 + tf.imag(signal) ** 2,
129 | # axis=-2
130 | # )
131 | #
132 | # if psd_context is not 0:
133 | # if isinstance(psd_context, tuple):
134 | # context = psd_context[0] + 1 + psd_context[1]
135 | # else:
136 | # context = 2 * psd_context + 1
137 | # psd_context = (psd_context, psd_context)
138 | #
139 | # power = tf.pad(
140 | # power,
141 | # ((0, 0), (psd_context[0], psd_context[1])),
142 | # mode='constant'
143 | # )
144 | # print(power)
145 | # power = tf.nn.convolution(
146 | # power,
147 | # tf.ones(context),
148 | # padding='VALID'
149 | # )[psd_context[1]:-psd_context[0]]
150 | #
151 | # denom = tf.nn.convolution(
152 | # tf.zeros_like(power) + 1.,
153 | # tf.ones(context),
154 | # padding='VALID'
155 | # )[psd_context[1]:-psd_context[0]]
156 | # print(power)
157 | # power /= denom
158 | #
159 | # elif psd_context == 0:
160 | # pass
161 | # else:
162 | # raise ValueError(psd_context)
163 | #
164 | # return tf.squeeze(power, axis=0)
165 |
166 |
def get_correlations(Y, inverse_power, taps, delay):
    """Calculates weighted correlations of a window of length taps

    Args:
        Y (tf.Tensor): Complex-valued STFT signal with shape (F, D, T)
        inverse_power (tf.Tensor): Weighting factor with shape (F, T)
        taps (int): Length of correlation window
        delay (int): Delay for the weighting factor

    Returns:
        tf.Tensor: Correlation matrix of shape (F, taps*D, taps*D)
        tf.Tensor: Correlation vector of shape (F, taps, D, D)
    """
    dyn_shape = tf.shape(Y)
    F = dyn_shape[0]
    D = dyn_shape[1]
    T = dyn_shape[2]

    # Frame the signal into overlapping windows of `taps` frames (hop 1),
    # keep only windows that are at least `delay + taps` frames in the
    # past, and reverse the tap order.
    Psi = tf_signal.frame(Y, taps, 1, axis=-1)[..., :T - delay - taps + 1, ::-1]
    # Weight the conjugated windows with the inverse power of the frame
    # they are used to predict.
    Psi_conj_norm = (
        tf.cast(inverse_power[:, None, delay + taps - 1:, None], Psi.dtype)
        * tf.conj(Psi)
    )

    correlation_matrix = tf.einsum('fdtk,fetl->fkdle', Psi_conj_norm, Psi)
    correlation_vector = tf.einsum(
        'fdtk,fet->fked', Psi_conj_norm, Y[..., delay + taps - 1:]
    )

    # Collapse (taps, D) x (taps, D) into a (taps*D, taps*D) matrix per
    # frequency; the correlation vector is returned unreshaped.
    correlation_matrix = tf.reshape(correlation_matrix, (F, taps * D, taps * D))
    return correlation_matrix, correlation_vector
198 |
199 |
def get_correlations_for_single_frequency(Y, inverse_power, taps, delay):
    """Weighted correlations of a window of length `taps` for one frequency.

    Args:
        Y (tf.Tensor): Complex-valued STFT signal with shape (D, T)
        inverse_power (tf.Tensor): Weighting factor with shape (T,)
        taps (int): Length of correlation window
        delay (int): Delay for the weighting factor

    Returns:
        tf.Tensor: Correlation matrix of shape (taps*D, taps*D)
        tf.Tensor: Correlation vector of shape (D, taps*D)
    """
    # Add a singleton frequency axis, delegate to the batched variant,
    # then strip the frequency axis from both results again.
    R, r = get_correlations(Y[None], inverse_power[None], taps, delay)
    return R[0], r[0]
217 |
218 |
def get_filter_matrix_conj(
        Y, correlation_matrix, correlation_vector, taps, delay, mode='solve'):
    """Calculate (conjugate) filter matrix based on correlations for one freq.

    Args:
        Y (tf.Tensor): Complex-valued STFT signal of shape (D, T)
        correlation_matrix (tf.Tensor): Correlation matrix (taps*D, taps*D)
        correlation_vector (tf.Tensor): Correlation vector (D, taps*D)
        taps (int): Number of filter taps
        delay (int): Delay
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x

    Raises:
        ValueError: Unknown mode specified

    Returns:
        tf.Tensor: (Conjugate) filter matrix with shape (taps, D, D)
    """

    D = tf.shape(Y)[0]

    correlation_vector = tf.reshape(correlation_vector, (D * D * taps, 1))
    # Index permutations converting between the (D, taps, D) and
    # (taps, D, D) orderings of the stacked filter coefficients.
    selector = \
        tf.reshape(
            tf.transpose(
                tf.reshape(tf.range(D * D * taps), (D, taps, D)), (1, 0, 2)), (-1,))
    inv_selector = \
        tf.reshape(
            tf.transpose(
                tf.reshape(tf.range(D * D * taps), (taps, D, D)), (1, 0, 2)), (-1,))

    correlation_vector = tf.gather(correlation_vector, inv_selector)

    if mode == 'inv':
        # Linear algebra is pinned to the CPU; presumably the GPU kernels
        # for complex matrix inversion were slow/unavailable -- TODO confirm.
        with tf.device('/cpu:0'):
            inv_correlation_matrix = tf.matrix_inverse(correlation_matrix)
        stacked_filter_conj = tf.einsum(
            'ab,cb->ca',
            inv_correlation_matrix, tf.reshape(correlation_vector, (D, D * taps))
        )
        stacked_filter_conj = tf.reshape(stacked_filter_conj, (D * D * taps, 1))
    elif mode == 'solve':
        with tf.device('/cpu:0'):
            stacked_filter_conj = tf.reshape(
                tf.matrix_solve(
                    tf.tile(correlation_matrix[None, ...], [D, 1, 1]),
                    tf.reshape(correlation_vector, (D, D * taps, 1))
                ),
                (D * D * taps, 1)
            )
    else:
        # BUGFIX: the message previously read '"inv" and solve"' -- the
        # opening quote around "solve" was missing.
        raise ValueError(
            'Unknown mode {}. Possible are "inv" and "solve"'.format(mode))
    stacked_filter_conj = tf.gather(stacked_filter_conj, selector)

    filter_matrix_conj = tf.transpose(
        tf.reshape(stacked_filter_conj, (taps, D, D)),
        (0, 2, 1)
    )
    return filter_matrix_conj
281 |
282 |
def perform_filter_operation(Y, filter_matrix_conj, taps, delay):
    """Apply the (conjugate) dereverberation filter to the observation.

    Subtracts from Y the accumulated per-tap contributions, each being the
    filter for tap tau applied to Y shifted right by (delay + tau) frames,
    i.e. it removes the estimated reverberation tail.

    Args:
        Y (tf.Tensor): Complex-valued STFT signal of shape (D, T)
        filter_matrix_conj (tf.Tensor): Conjugate filter matrix of shape
            (taps, D, D)
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        tf.Tensor: Dereverberated signal of shape (D, T)

    # >>> D, T, taps, delay = 1, 10, 2, 1
    # >>> tf.enable_eager_execution()
    # >>> Y = tf.ones([D, T])
    # >>> filter_matrix_conj = tf.ones([taps, D, D])
    # >>> X = perform_filter_operation_v2(Y, filter_matrix_conj, taps, delay)
    # >>> X.shape
    # TensorShape([Dimension(1), Dimension(10)])
    # >>> X.numpy()
    # array([[ 1.,  0., -1., -1., -1., -1., -1., -1., -1., -1.]], dtype=float32)
    """
    dyn_shape = tf.shape(Y)
    T = dyn_shape[1]

    def add_tap(accumulated, tau_minus_delay):
        # One tap's contribution: filter applied to the truncated signal,
        # then zero-padded on the left so the result stays aligned to (D, T).
        new = tf.einsum(
            'de,dt',
            filter_matrix_conj[tau_minus_delay, :, :],
            Y[:, :(T - delay - tau_minus_delay)]
        )
        paddings = tf.convert_to_tensor([[0, 0], [delay + tau_minus_delay, 0]])
        new = tf.pad(new, paddings, "CONSTANT")
        return accumulated + new

    # Accumulate the reverb-tail estimate over all taps.
    reverb_tail = tf.foldl(
        add_tap, tf.range(0, taps),
        initializer=tf.zeros_like(Y)
    )
    return Y - reverb_tail
314 |
315 |
def single_frequency_wpe(Y, taps=10, delay=3, iterations=3, mode='inv'):
    """WPE for a single frequency.

    Args:
        Y: Complex valued STFT signal with shape (D, T)
        taps: Number of filter taps
        delay: Delay as a guard interval, such that X does not become zero.
        iterations: Number of alternations between power estimation and
            filter estimation.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x

    Returns:
        Tuple of
            enhanced (tf.Tensor): Dereverberated signal with shape (D, T)
            inverse_power (tf.Tensor): Latest inverse power estimate with
                shape (T)
    """

    enhanced = Y
    for _ in range(iterations):
        # Alternate: estimate the inverse power from the current enhanced
        # signal, then re-estimate and re-apply the dereverberation filter
        # (always to the original observation Y).
        inverse_power = get_power_inverse(enhanced)
        correlation_matrix, correlation_vector = \
            get_correlations_for_single_frequency(Y, inverse_power, taps, delay)
        filter_matrix_conj = get_filter_matrix_conj(
            Y, correlation_matrix, correlation_vector,
            taps, delay, mode=mode
        )
        enhanced = perform_filter_operation(Y, filter_matrix_conj, taps, delay)
    return enhanced, inverse_power
343 |
344 |
def wpe(Y, taps=10, delay=3, iterations=3, mode='inv'):
    """WPE for all frequencies at once. Use this for regular processing.

    Args:
        Y (tf.Tensor): Observed signal with shape (F, D, T)
        taps (int, optional): Defaults to 10. Number of filter taps.
        delay (int, optional): Defaults to 3.
        iterations (int, optional): Defaults to 3.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x

    Returns:
        tf.Tensor: Dereverberated signal with shape (F, D, T)
    """

    def iteration_over_frequencies(y):
        enhanced, inverse_power = single_frequency_wpe(
            y, taps, delay, iterations, mode=mode)
        return (enhanced, inverse_power)

    # Each frequency bin is processed independently.
    enhanced, inverse_power = tf.map_fn(
        iteration_over_frequencies, Y, dtype=(Y.dtype, Y.dtype.real_dtype)
    )

    return enhanced
373 |
374 |
def batched_wpe(Y, num_frames, taps=10, delay=3, iterations=3, mode='inv'):
    """Batched version of iterative WPE.

    Args:
        Y (tf.Tensor): Observed signal with shape (B, F, D, T)
        num_frames (tf.Tensor): Number of frames for each signal in the batch
        taps (int, optional): Defaults to 10. Number of filter taps.
        delay (int, optional): Defaults to 3.
        iterations (int, optional): Defaults to 3.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x

    Returns:
        tf.Tensor: Dereverberated signal of shape (B, F, D, T).
    """

    def _apply_wpe(signals):
        (observation,) = signals
        return wpe(observation, taps, delay, iterations, mode)

    return _batch_wrapper(_apply_wpe, [Y], num_frames)
397 |
398 |
def wpe_step(Y, inverse_power, taps=10, delay=3, mode='inv', Y_stats=None):
    """Single step of 'wpe'. More suited for backpropagation.

    Args:
        Y (tf.Tensor): Complex valued STFT signal with shape (F, D, T)
        inverse_power (tf.Tensor): Power signal with shape (F, T)
        taps (int, optional): Filter order
        delay (int, optional): Delay as a guard interval, such that X does not become zero.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x
        Y_stats (tf.Tensor or None, optional): Complex valued STFT signal
            with shape (F, D, T) used to calculate the signal statistics
            (i.e. correlation matrix/vector).
            If None, Y is used. Otherwise it's usually a segment of Y

    Returns:
        Dereverberated signal of shape (F, D, T)
    """
    with tf.name_scope('WPE'):
        with tf.name_scope('correlations'):
            if Y_stats is None:
                Y_stats = Y
            correlation_matrix, correlation_vector = get_correlations(
                Y_stats, inverse_power, taps, delay
            )

        # Estimate and apply the dereverberation filter for one frequency.
        def step(inp):
            (Y_f, correlation_matrix_f, correlation_vector_f) = inp
            with tf.name_scope('filter_matrix'):
                filter_matrix_conj = get_filter_matrix_conj(
                    Y_f,
                    correlation_matrix_f, correlation_vector_f,
                    taps, delay, mode=mode
                )
            with tf.name_scope('apply_filter'):
                enhanced = perform_filter_operation(
                    Y_f, filter_matrix_conj, taps, delay)
            return enhanced

        # All frequency bins are independent; map over them in parallel.
        enhanced = tf.map_fn(
            step,
            (Y, correlation_matrix, correlation_vector),
            dtype=Y.dtype,
            parallel_iterations=100
        )

        return enhanced
447 |
448 |
def batched_wpe_step(
        Y, inverse_power, num_frames, taps=10, delay=3, mode='inv', Y_stats=None):
    """Batched single WPE step. More suited for backpropagation.

    Args:
        Y (tf.Tensor): Complex valued STFT signal with shape (B, F, D, T)
        inverse_power (tf.Tensor): Power signal with shape (B, F, T)
        num_frames (tf.Tensor): Number of frames for each signal in the batch
        taps (int, optional): Filter order
        delay (int, optional): Delay as a guard interval, such that X does not become zero.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x
        Y_stats (tf.Tensor or None, optional): Complex valued STFT signal
            (per batch entry, shaped like Y) used to calculate the signal
            statistics (i.e. correlation matrix/vector).
            If None, Y is used. Otherwise it's usually a segment of Y

    Returns:
        Dereverberated signal of shape B, (F, D, T)
    """
    if Y_stats is None:
        Y_stats = Y

    def _apply_step(signals):
        observation, inv_power, stats = signals
        return wpe_step(observation, inv_power, taps, delay, mode, stats)

    return _batch_wrapper(_apply_step, [Y, inverse_power, Y_stats], num_frames)
479 |
480 |
def block_wpe_step(
        Y, inverse_power, taps=10, delay=3, mode='inv',
        block_length_in_seconds=2., forgetting_factor=0.7,
        fft_shift=256, sampling_rate=16000):
    """Applies wpe in a block-wise fashion.

    Args:
        Y (tf.Tensor): Complex valued STFT signal with shape (F, D, T)
        inverse_power (tf.Tensor): Power signal with shape (F, T)
        taps (int, optional): Defaults to 10.
        delay (int, optional): Defaults to 3.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x
        block_length_in_seconds (float, optional): Length of each block in
            seconds
        forgetting_factor (float, optional): Forgetting factor for the signal
            statistics between the blocks
        fft_shift (int, optional): Shift used for the STFT.
        sampling_rate (int, optional): Sampling rate of the observed signal.

    Returns:
        tf.Tensor: Dereverberated signal with shape (F, D, T).
    """
    frames_per_block = block_length_in_seconds * sampling_rate // fft_shift
    frames_per_block = tf.cast(frames_per_block, tf.int32)
    # Split signal and weights into non-overlapping blocks (end-padded).
    framed_Y = tf_signal.frame(
        Y, frames_per_block, frames_per_block, pad_end=True)
    framed_inverse_power = tf_signal.frame(
        inverse_power, frames_per_block, frames_per_block, pad_end=True)
    num_blocks = tf.shape(framed_Y)[-2]

    enhanced_arr = tf.TensorArray(
        framed_Y.dtype, size=num_blocks, clear_after_read=True)
    start_block = tf.constant(0)
    # Statistics are initialized from the first block.
    correlation_matrix, correlation_vector = get_correlations(
        framed_Y[..., start_block, :], framed_inverse_power[..., start_block, :],
        taps, delay
    )
    num_bins = Y.shape[0]
    # TF1 static dimension; may be unknown (None) at graph-build time, in
    # which case the dynamic shape is used instead.
    num_channels = Y.shape[1].value
    if num_channels is None:
        num_channels = tf.shape(Y)[1]
    num_frames = tf.shape(Y)[-1]

    def cond(k, *_):
        return k < num_blocks

    with tf.name_scope('block_WPE'):
        def block_step(
                k, enhanced, correlation_matrix_tm1, correlation_vector_tm1):

            # Block 0 reuses the initial statistics unchanged.
            def _init_step():
                return correlation_matrix_tm1, correlation_vector_tm1

            # Later blocks blend previous and current statistics with an
            # exponential forgetting factor.
            def _update_step():
                correlation_matrix, correlation_vector = get_correlations(
                    framed_Y[..., k, :], framed_inverse_power[..., k, :],
                    taps, delay
                )
                return (
                    (1. - forgetting_factor) * correlation_matrix_tm1
                    + forgetting_factor * correlation_matrix,
                    (1. - forgetting_factor) * correlation_vector_tm1
                    + forgetting_factor * correlation_vector
                )

            correlation_matrix, correlation_vector = tf.case(
                ((tf.equal(k, 0), _init_step),), default=_update_step
            )

            # Estimate and apply the filter for one frequency bin.
            def step(inp):
                (Y_f, inverse_power_f,
                 correlation_matrix_f, correlation_vector_f) = inp
                with tf.name_scope('filter_matrix'):
                    filter_matrix_conj = get_filter_matrix_conj(
                        Y_f,
                        correlation_matrix_f, correlation_vector_f,
                        taps, delay, mode=mode
                    )
                with tf.name_scope('apply_filter'):
                    enhanced_f = perform_filter_operation(
                        Y_f, filter_matrix_conj, taps, delay)
                return enhanced_f

            enhanced_block = tf.map_fn(
                step,
                (framed_Y[..., k, :], framed_inverse_power[..., k, :],
                 correlation_matrix, correlation_vector),
                dtype=framed_Y.dtype,
                parallel_iterations=100
            )

            enhanced = enhanced.write(k, enhanced_block)
            return k + 1, enhanced, correlation_matrix, correlation_vector

        _, enhanced_arr, _, _ = tf.while_loop(
            cond, block_step,
            (start_block, enhanced_arr, correlation_matrix, correlation_vector)
        )

    # Re-assemble the blocks into one continuous signal and drop the frames
    # introduced by the end-padded framing.
    enhanced = enhanced_arr.stack()
    enhanced = tf.transpose(enhanced, (1, 2, 0, 3))
    enhanced = tf.reshape(enhanced, (num_bins, num_channels, -1))

    return enhanced[..., :num_frames]
584 |
585 |
def batched_block_wpe_step(
        Y, inverse_power, num_frames, taps=10, delay=3, mode='inv',
        block_length_in_seconds=2., forgetting_factor=0.7,
        fft_shift=256, sampling_rate=16000):
    """Batched version of the block-wise WPE step.

    Args:
        Y (tf.Tensor): Complex valued STFT signal with shape (B, F, D, T)
        inverse_power (tf.Tensor): Power signal with shape (B, F, T)
        num_frames (tf.Tensor): Number of frames for each signal in the batch
        taps (int, optional): Filter order
        delay (int, optional): Delay as a guard interval, such that X does not become zero.
        mode (str, optional): Specifies how R^-1@r is calculated:
            "inv" calculates the inverse of R directly and then uses matmul
            "solve" solves Rx=r for x
        block_length_in_seconds (float, optional): Length of each block in
            seconds
        forgetting_factor (float, optional): Forgetting factor for the signal
            statistics between the blocks
        fft_shift (int, optional): Shift used for the STFT.
        sampling_rate (int, optional): Sampling rate of the observed signal.

    Returns:
        Dereverberated signal of shape B, (F, D, T)
    """
    def _apply_block_step(signals):
        observation, inv_power = signals
        return block_wpe_step(
            observation, inv_power, taps, delay,
            mode, block_length_in_seconds, forgetting_factor,
            fft_shift, sampling_rate)

    return _batch_wrapper(_apply_block_step, [Y, inverse_power], num_frames)
620 |
621 |
def online_wpe_step(
        input_buffer, power_estimate, inv_cov, filter_taps,
        alpha, taps, delay
):
    """
    One step of online dereverberation: predict the current frame and
    recursively update the inverse covariance and the filter taps.

    Args:
        input_buffer (tf.Tensor): Buffer of shape (taps+delay+1, F, D)
        power_estimate (tf.Tensor): Estimate for the current PSD
        inv_cov (tf.Tensor): Current estimate of R^-1
        filter_taps (tf.Tensor): Current estimate of filter taps
            (F, taps*D, D)
        alpha (float): Smoothing factor
        taps (int): Number of filter taps
        delay (int): Delay in frames

    Returns:
        tf.Tensor: Dereverberated frame of shape (F, D)
        tf.Tensor: Updated estimate of R^-1
        tf.Tensor: Updated estimate of the filter taps
    """
    F = input_buffer.shape[-2]
    D = tf.shape(input_buffer)[-1]
    # Stack the delayed frames (reversed in time) into one (F, taps*D) window.
    window = input_buffer[:-delay - 1][::-1]
    window = tf.reshape(
        tf.transpose(window, (1, 2, 0)), (F, taps * D)
    )
    window_conj = tf.conj(window)
    # Prediction: newest frame minus the filtered reverberation tail.
    pred = (
        input_buffer[-1] -
        tf.einsum('fid,fi->fd', tf.conj(filter_taps), window)
    )

    # Gain computation for the recursive statistics update.
    nominator = tf.einsum('fij,fj->fi', inv_cov, window)
    denominator = tf.cast(alpha * power_estimate, window.dtype)
    denominator += tf.einsum('fi,fi->f', window_conj, nominator)
    kalman_gain = nominator / denominator[:, None]

    # Rank-one update of the inverse covariance, discounted by alpha.
    inv_cov_k = inv_cov - tf.einsum('fj,fjm,fi->fim', window_conj, inv_cov, kalman_gain)
    inv_cov_k /= alpha

    # Correct the filter taps with the gain-weighted prediction error.
    filter_taps_k = (
        filter_taps +
        tf.einsum('fi,fm->fim', kalman_gain, tf.conj(pred))
    )
    return pred, inv_cov_k, filter_taps_k
668 |
669 |
def recursive_wpe(
        Y, power_estimate, alpha, taps=10, delay=2,
        only_use_final_filters=False):
    """Applies WPE in a framewise recursive fashion.

    Args:
        Y (tf.Tensor): Observed signal of shape (T, F, D)
        power_estimate (tf.Tensor): Estimate for the clean signal PSD of shape (T, F)
        alpha (float): Smoothing factor for the recursion
        taps (int, optional): Number of filter taps.
        delay (int, optional): Delay
        only_use_final_filters (bool, optional): Applies only the final
            estimated filter coefficients to the whole signal. This is for
            debugging purposes only and makes this method an offline one.

    Returns:
        tf.Tensor: Enhanced signal of shape (T, F, D)
    """

    num_frames = tf.shape(Y)[0]
    num_bins = Y.shape[1]
    num_ch = tf.shape(Y)[-1]
    dtype = Y.dtype
    k = delay + taps

    # Initial state: identity inverse covariance, all-zero filter taps.
    inv_cov_tm1 = tf.eye(num_ch * taps, batch_shape=[num_bins], dtype=dtype)
    filter_taps_tm1 = tf.zeros((num_bins, num_ch * taps, num_ch), dtype=dtype)
    enhanced_arr = tf.TensorArray(dtype, size=num_frames, name='dereverberated')
    # Zero-pad the past so the very first frames see a full filter history.
    Y = tf.pad(Y, ((delay + taps, 0), (0, 0), (0, 0)))

    def dereverb_step(k_, inv_cov_tm1, filter_taps_tm1, enhanced):
        # `pos` is the index in the unpadded signal; the buffer covers the
        # current frame plus the taps+delay preceding frames.
        pos = k_ - delay - taps
        input_buffer = Y[pos:k_ + 1]
        pred, inv_cov_k, filter_taps_k = online_wpe_step(
            input_buffer, power_estimate[pos],
            inv_cov_tm1, filter_taps_tm1, alpha, taps, delay
        )
        enhanced_k = enhanced.write(pos, pred)
        return k_ + 1, inv_cov_k, filter_taps_k, enhanced_k

    def cond(k, *_):
        return tf.less(k, num_frames + delay + taps)

    _, _, final_filter_taps, enhanced = tf.while_loop(
        cond, dereverb_step, (k, inv_cov_tm1, filter_taps_tm1, enhanced_arr))

    # Only for testing / oracle purposes
    def dereverb_with_filters(k_, filter_taps, enhanced):
        window = Y[k_ - delay - taps:k_ - delay][::-1]
        window = tf.reshape(
            tf.transpose(window, (1, 2, 0)), (-1, taps * num_ch)
        )
        pred = (
            Y[k_] -
            tf.einsum('lim,li->lm', tf.conj(filter_taps), window)
        )
        enhanced_k = enhanced.write(k_ - delay - taps, pred)
        return k_ + 1, filter_taps, enhanced_k

    if only_use_final_filters:
        # Debug path: redo the whole pass with the final (converged) filters.
        k = tf.constant(0) + delay + taps
        enhanced_arr = tf.TensorArray(dtype, size=num_frames)
        _, _, enhanced = tf.while_loop(
            cond, dereverb_with_filters, (k, final_filter_taps, enhanced_arr))

    return enhanced.stack()
736 |
737 |
def batched_recursive_wpe(
        Y, power_estimate, alpha, num_frames, taps=10, delay=2,
        only_use_final_filters=False):
    """Batched version of the framewise recursive WPE.

    Args:
        Y (tf.Tensor): Observed signal of shape (B, T, F, D)
        power_estimate (tf.Tensor): Estimate for the clean signal PSD of shape (B, T, F)
        alpha (float): Smoothing factor for the recursion
        num_frames (tf.Tensor): Number of frames for each signal in the batch
        taps (int, optional): Number of filter taps.
        delay (int, optional): Delay
        only_use_final_filters (bool, optional): Applies only the final
            estimated filter coefficients to the whole signal. This is for
            debugging purposes only and makes this method an offline one.

    Returns:
        Dereverberated signal of shape (B, T, F, D)
    """
    def _apply_recursive(signals):
        observation, psd = signals
        return recursive_wpe(
            observation, psd, alpha, taps, delay, only_use_final_filters)

    return _batch_wrapper(
        _apply_recursive, [Y, power_estimate], num_frames, time_axis=1)
765 |
--------------------------------------------------------------------------------
/nara_wpe/torch_wpe.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional
6 | from nara_wpe.wpe import segment_axis
7 |
8 |
def torch_moveaxis(x: torch.tensor, source, destination):
    """Move a single axis of a tensor from ``source`` to ``destination``.

    Torch counterpart of ``np.moveaxis`` for one axis; negative indices
    are supported for both arguments.

    >>> torch_moveaxis(torch.ones(2, 25), 1, 0).shape
    torch.Size([25, 2])
    >>> torch_moveaxis(torch.ones(2, 25), -1, -2).shape
    torch.Size([25, 2])
    >>> torch_moveaxis(torch.ones(2, 25) + 1j, -2, -1).shape
    torch.Size([25, 2])
    """
    rank = x.dim()
    order = list(range(rank))
    # Remove the source axis and re-insert it at the destination position.
    moved = order.pop(source)
    order.insert(destination % rank, moved)
    return x.permute(*order)
28 |
29 |
def build_y_tilde(Y, taps, delay):
    """
    Build the stacked, delayed observation ``Y_tilde`` of shape
    (..., taps * D, T) from ``Y`` of shape (..., D, T).

    Note: The returned y_tilde consumes a similar amount of memory as Y, because
    of tricks with strides. Usually the memory consumption is taps times
    smaller than the memory consumption of a contiguous array,

    >>> T, D = 20, 2
    >>> Y = torch.arange(start=1, end=T * D + 1).reshape([T, D]).t()
    >>> # Y = torch.arange(start=1, end=T * D + 1).to(dtype=torch.complex128).reshape([T, D]).t()
    >>> print(Y.numpy())
    [[ 1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39]
     [ 2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40]]
    >>> taps, delay = 4, 2
    >>> Y_tilde = build_y_tilde(Y, taps, delay)
    >>> print(Y_tilde.shape, (taps*D, T))
    torch.Size([8, 20]) (8, 20)
    >>> print(Y_tilde.numpy())
    [[ 0  0  0  0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29]
     [ 0  0  0  0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30]
     [ 0  0  0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31]
     [ 0  0  0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32]
     [ 0  0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33]
     [ 0  0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34]
     [ 0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35]
     [ 0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36]]
    >>> Y_tilde = build_y_tilde(Y, taps, 0)
    >>> print(Y_tilde.shape, (taps*D, T), Y_tilde.stride())
    torch.Size([8, 20]) (8, 20) (1, 2)
    >>> print('Pseudo size:', np.prod(Y_tilde.size()) * Y_tilde.element_size())
    Pseudo size: 1280
    >>> print('Real size:', Y_tilde.storage().size() * Y_tilde.storage().element_size())
    Real size: 368
    >>> print(Y_tilde.numpy())
    [[ 0  0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33]
     [ 0  0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34]
     [ 0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35]
     [ 0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36]
     [ 0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35 37]
     [ 0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38]
     [ 1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39]
     [ 2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36 38 40]]

    >>> print(Y_tilde.shape, Y_tilde.stride())
    torch.Size([8, 20]) (1, 2)
    >>> print(Y_tilde[::3].shape, Y_tilde[::3].stride())
    torch.Size([3, 20]) (3, 2)
    >>> print(Y_tilde[::3].shape, Y_tilde[::3].contiguous().stride())
    torch.Size([3, 20]) (20, 1)
    >>> print(Y_tilde[::3].numpy())
    [[ 0  0  0  1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33]
     [ 0  0  2  4  6  8 10 12 14 16 18 20 22 24 26 28 30 32 34 36]
     [ 1  3  5  7  9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39]]

    The first columns are zero because of the delay.

    """
    # S: independent leading axes, D: channels, T: frames.
    S = Y.shape[:-2]
    D = Y.shape[-2]
    T = Y.shape[-1]

    def pad(x, axis=-1, pad_width=taps + delay - 1):
        # Zero-pad `x` on the left of `axis` so all taps have a history.
        npad = np.zeros([x.ndimension(), 2], dtype=int)
        npad[axis, 0] = pad_width
        # x_np = (np.pad(x.numpy(),
        #               pad_width=npad,
        #               mode='constant',
        #               constant_values=0))
        # torch.nn.functional.pad expects pad widths for the *last*
        # dimension first, hence npad is reversed before flattening.
        x = torch.nn.functional.pad(
            x,
            pad=npad[::-1].ravel().tolist(),
            mode='constant',
            value=0,
        )
        # assert x_np.shape == x.shape, (x_np.shape, x.shape)
        return x

    # Y_ = segment_axis(pad(Y), K, 1, axis=-1)
    # Y_ = np.flip(Y_, axis=-1)
    # if delay > 0:
    #     Y_ = Y_[..., :-delay, :]
    # # Y_: ... x D x T x K
    # Y_ = np.moveaxis(Y_, -1, -3)
    # # Y_: ... x K x D x T
    # Y_ = np.reshape(Y_, [*S, K * D, T])
    # # Y_: ... x KD x T

    # ToDo: write the shape
    # flip -> contiguous -> flip materializes a time-reversed copy so that
    # segment_axis can build the tap windows with positive strides only.
    Y_ = pad(Y)
    Y_ = torch_moveaxis(Y_, -1, -2)
    Y_ = torch.flip(Y_, dims=[-1 % Y_.ndimension()])
    Y_ = Y_.contiguous()  # Y_ = np.ascontiguousarray(Y_)
    Y_ = torch.flip(Y_, dims=[-1 % Y_.ndimension()])
    Y_ = segment_axis(Y_, taps, 1, axis=-2)

    # Pytorch does not support negative strides.
    # Without this flip, the output of this function does not match the
    # analytical form, but the output of WPE will be equal.
    # Y_ = torch.flip(Y_, dims=[-2 % Y_.ndimension()])

    if delay > 0:
        Y_ = Y_[..., :-delay, :, :]
    Y_ = torch.reshape(Y_, list(S) + [T, taps * D])
    Y_ = torch_moveaxis(Y_, -2, -1)

    return Y_
136 |
137 |
def get_power_inverse(signal, psd_context=0):
    """
    Inverse of the channel-averaged power for a single frequency bin.

    Assumes single frequency bin with shape (D, T).

    Args:
        signal: Tensor of shape (D, T) (channels x frames).
        psd_context: Context size for smoothing the power estimate.
            Only ``0`` (no smoothing) is currently implemented.

    Returns:
        Tensor of shape (T,): elementwise inverse of the power, floored
        at ``1e-10 * max(power)`` to avoid division by zero.

    >>> s = 1 / torch.tensor([np.arange(1, 6).astype(np.complex128)]*3)
    >>> get_power_inverse(s).numpy()
    array([ 1.,  4.,  9., 16., 25.])
    """
    power = (signal.abs() ** 2).mean(dim=-2)

    if np.isposinf(psd_context):
        raise NotImplementedError(psd_context)
    elif psd_context > 0:
        raise NotImplementedError(psd_context)
    elif psd_context == 0:
        pass
    else:
        raise ValueError(psd_context)

    # Numerical floor relative to the loudest frame.
    floor = 1e-10 * torch.max(power)
    return 1 / torch.max(power, floor)
174 |
175 |
def transpose(x):
    # Swap the last two axes (matrix transpose over the trailing dims).
    return torch.transpose(x, -2, -1)
178 |
179 |
def hermite(x):
    # Hermitian (conjugate) transpose over the last two axes.
    return torch.conj(x.transpose(-2, -1))
182 |
183 |
def wpe_v6(Y, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode='full'):
    """
    Short of wpe_v7 with no extern references.
    Applicable in for-loops.

    Args:
        Y: Complex valued STFT signal with shape (D, T).
        taps: Filter order
        delay: Delay as a guard interval, such that X does not become zero.
        iterations: Number of alternations between power estimation and
            filter estimation.
        psd_context: Context for the power estimate; only ``0`` is
            currently implemented (see get_power_inverse).
        statistics_mode: 'full' uses all frames for the correlation
            statistics; 'valid' uses only frames with a complete filter
            history (the first delay + taps - 1 frames are dropped).

    Returns:
        Dereverberated signal with the same shape as Y.

    >>> T = np.random.randint(100, 120)
    >>> D = np.random.randint(2, 6)
    >>> K = np.random.randint(3, 5)
    >>> delay = np.random.randint(1, 3)

    # Real test:
    >>> Y = np.random.normal(size=(D, T))
    >>> from nara_wpe import wpe as np_wpe
    >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full')
    >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy()
    >>> np.testing.assert_allclose(actual, desired, atol=1e-6)

    # Complex test:
    >>> Y = np.random.normal(size=(D, T)) + 1j * np.random.normal(size=(D, T))
    >>> from nara_wpe import wpe as np_wpe
    >>> desired = np_wpe.wpe_v6(Y, K, delay, statistics_mode='full')
    >>> actual = wpe_v6(torch.tensor(Y), K, delay, statistics_mode='full').numpy()
    >>> np.testing.assert_allclose(actual, desired, atol=1e-6)
    """

    # `s` selects the frames that enter the correlation statistics.
    if statistics_mode == 'full':
        s = Ellipsis
    elif statistics_mode == 'valid':
        s = (Ellipsis, slice(delay + taps - 1, None))
    else:
        raise ValueError(statistics_mode)

    X = torch.clone(Y)
    Y_tilde = build_y_tilde(Y, taps, delay)
    for iteration in range(iterations):
        inverse_power = get_power_inverse(X, psd_context=psd_context)
        Y_tilde_inverse_power = Y_tilde * inverse_power[..., None, :]
        # Weighted correlation matrix R and cross-correlation P of the
        # stacked delayed observations.
        R = torch.matmul(Y_tilde_inverse_power[s], hermite(Y_tilde[s]))
        P = torch.matmul(Y_tilde_inverse_power[s], hermite(Y[s]))
        # G = _stable_solve(R, P)
        G = torch.linalg.solve(R, P)
        # Subtract the predicted reverberation tail from the observation.
        X = Y - torch.matmul(hermite(G), Y_tilde)

    return X
228 |
229 |
def wpe_v8(
        Y,
        taps=10,
        delay=3,
        iterations=3,
        psd_context=0,
        statistics_mode='full',
        inplace=False
):
    """
    Loopy Multiple Input Multiple Output Weighted Prediction Error [1, 2]
    implementation.

    Loops over the independent leading axes of ``Y`` and applies
    :func:`wpe_v6` to each (D, T) slice, which keeps the memory footprint
    small.

    Args:
        Y: Complex valued STFT signal with shape (..., D, T).
        taps: Filter order
        delay: Delay as a guard interval, such that X does not become zero.
        iterations: Number of WPE iterations.
        psd_context: Defines the number of elements in the time window
            to improve the power estimation. Total number of elements will
            be (psd_context + 1 + psd_context).
        statistics_mode: Either 'full' or 'valid'.
            'full': Pad the observation with zeros on the left for the
                estimation of the correlation matrix and vector.
            'valid': Only calculate correlation matrix and vector on valid
                slices of the observation.
        inplace: Whether to change Y inplace. Has only advantages, when Y
            has independent axes, because the core WPE algorithm does not
            support an inplace modification of the observation.
            May be relevant when Y is so large that you do not want to
            double the memory consumption.

    Returns:
        Estimated signal with the same shape as Y

    [1] "Generalization of multi-channel linear prediction methods for blind MIMO
    impulse response shortening", Yoshioka, Takuya and Nakatani, Tomohiro, 2012
    [2] NARA-WPE: A Python package for weighted prediction error dereverberation in
    Numpy and Tensorflow for online and offline processing, Drude, Lukas and
    Heymann, Jahn and Boeddeker, Christoph and Haeb-Umbach, Reinhold, 2018

    """
    if Y.ndim < 2:
        raise NotImplementedError(
            'Input shape has to be (..., D, T) and not {}.'.format(Y.shape)
        )

    _wpe = functools.partial(
        wpe_v6,
        taps=taps,
        delay=delay,
        iterations=iterations,
        psd_context=psd_context,
        statistics_mode=statistics_mode,
    )

    if Y.ndim == 2:
        result = _wpe(Y)
        if inplace:
            Y[...] = result
        return result

    # ndim >= 3: process every (D, T) slice independently.
    result = Y if inplace else torch.empty_like(Y)
    for index in np.ndindex(Y.shape[:-2]):
        result[index] = _wpe(Y=Y[index])
    return result
309 |
310 |
--------------------------------------------------------------------------------
/nara_wpe/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the STFT function and related helper functions.
3 | """
4 | import numpy as np
5 | from math import ceil
6 | import scipy
7 |
8 | from scipy import signal
9 | from numpy.fft import rfft, irfft
10 |
11 | import string
12 |
13 | from nara_wpe.wpe import segment_axis as segment_axis_v2
14 |
15 |
16 | # http://stackoverflow.com/a/3153267
def roll_zeropad(a, shift, axis=None):
    """
    Shift array elements along an axis, filling vacated slots with zeros.

    Unlike ``np.roll``, elements shifted off the end are discarded and
    zeros are shifted in instead of wrapping around.

    Args:
        a: array_like
            Input array.
        shift: int
            The number of places by which elements are shifted; negative
            values shift towards the start.
        axis (int): optional,
            The axis along which elements are shifted. By default, the array
            is flattened before shifting, after which the original
            shape is restored.

    Returns:
        ndarray: Output array, with the same shape as `a`.

    Note:
        roll : Elements that roll off one end come back on the other.
        rollaxis : Roll the specified axis backwards, until it lies in a
        given position.

    Examples:
        >>> x = np.arange(10)
        >>> roll_zeropad(x, 2)
        array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
        >>> roll_zeropad(x, -2)
        array([2, 3, 4, 5, 6, 7, 8, 9, 0, 0])

        >>> x2 = np.reshape(x, (2,5))
        >>> roll_zeropad(x2, 1, axis=0)
        array([[0, 0, 0, 0, 0],
               [0, 1, 2, 3, 4]])
        >>> roll_zeropad(x2, -2, axis=1)
        array([[2, 3, 4, 0, 0],
               [7, 8, 9, 0, 0]])
        >>> roll_zeropad(x2, 50)
        array([[0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0]])
    """
    a = np.asanyarray(a)
    if shift == 0:
        return a

    flattened = axis is None
    n = a.size if flattened else a.shape[axis]

    if abs(shift) > n:
        # Everything is shifted out; the result is all zeros.
        result = np.zeros_like(a)
    elif shift < 0:
        cut = -shift  # number of elements dropped from the front
        body = a.take(np.arange(cut, n), axis)
        pad = np.zeros_like(a.take(np.arange(cut), axis))
        result = np.concatenate((body, pad), axis)
    else:
        body = a.take(np.arange(n - shift), axis)
        pad = np.zeros_like(a.take(np.arange(shift), axis))
        result = np.concatenate((pad, body), axis)

    # When the array was flattened (axis=None), restore the input shape.
    return result.reshape(a.shape) if flattened else result
105 |
106 |
def stft(
        time_signal,
        size,
        shift,
        axis=-1,
        window=signal.windows.blackman,
        window_length=None,
        fading=True,
        pad=True,
        symmetric_window=False,
):
    """
    Calculates the short time Fourier transformation of a multi channel multi
    speaker time signal. It is able to add additional zeros for fade-in and
    fade out and should yield an STFT signal which allows perfect
    reconstruction.

    Args:
        time_signal: Multi channel time signal with dimensions
            AA x ... x AZ x T x BA x ... x BZ.
        size: Scalar FFT-size.
        shift: Scalar FFT-shift, the step between successive frames in
            samples. Typically shift is a fraction of size.
        axis: Scalar axis of time.
            Default: -1, i.e. the last dimension.
        window: Window function handle, or the name of a window in
            `scipy.signal.windows`. Default is the blackman window.
        fading: Pads the signal with zeros for better reconstruction.
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the window is padded with zeros.
            The default is to use the fft-size as a window size.
        pad: If true zero pad the signal to match the shape, else cut
        symmetric_window: symmetric or periodic window. Assume window is
            periodic. Since the implementation of the windows in scipy.signal have a
            curious behaviour for odd window_length. Use window(len+1)[:-1]. Since
            is equal to the behaviour of MATLAB.

    Returns:
        Single channel complex STFT signal with dimensions
        AA x ... x AZ x T' times size/2+1 times BA x ... x BZ.

    Raises:
        ValueError: If the segmented signal and the window cannot be
            combined (e.g. inconsistent `size`/`shift`/`window_length`).
    """
    time_signal = np.array(time_signal)

    # Normalize a negative axis (e.g. -1) to a positive index.
    axis = axis % time_signal.ndim

    if window_length is None:
        window_length = size

    # Pad with zeros to have enough samples for the window function to fade.
    if fading:
        pad_width = np.zeros((time_signal.ndim, 2), dtype=int)
        pad_width[axis, :] = window_length - shift
        time_signal = np.pad(time_signal, pad_width, mode='constant')

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if symmetric_window:
        window = window(window_length)
    else:
        # Periodic window, see https://github.com/scipy/scipy/issues/4551
        window = window(window_length + 1)[:-1]

    time_signal_seg = segment_axis_v2(
        time_signal,
        window_length,
        shift=shift,
        axis=axis,
        end='pad' if pad else 'cut'
    )

    # Build an einsum mapping that multiplies the window onto the new
    # (frequency) axis, which is directly after the original time axis.
    letters = string.ascii_lowercase[:time_signal_seg.ndim]
    mapping = letters + ',' + letters[axis + 1] + '->' + letters

    try:
        # ToDo: Implement this more memory efficient
        return rfft(
            np.einsum(mapping, time_signal_seg, window),
            n=size,
            axis=axis + 1
        )
    except ValueError as e:
        # Re-raise with the shapes involved. Previously the `axis+1` part
        # was a plain string with literal braces ('{axis+1}') and a
        # separator after `size` was missing; both are fixed here.
        raise ValueError(
            'Could not calculate the stft, something does not match.\n'
            'mapping: {}, '
            'time_signal_seg.shape: {}, '
            'window.shape: {}, '
            'size: {}, '
            'axis+1: {}'.format(
                mapping, time_signal_seg.shape, window.shape, size, axis + 1
            )
        ) from e
201 |
202 |
203 | def _samples_to_stft_frames(
204 | samples,
205 | size,
206 | shift,
207 | *,
208 | pad=True,
209 | fading=False,
210 | ):
211 | """
212 | Calculates number of STFT frames from number of samples in time domain.
213 |
214 | Args:
215 | samples: Number of samples in time domain.
216 | size: FFT size.
217 | window_length often equal to FFT size. The name size should be
218 | marked as deprecated and replaced with window_length.
219 | shift: Hop in samples.
220 | pad: See stft.
221 | fading: See stft. Note to keep old behavior, default value is False.
222 |
223 | Returns:
224 | Number of STFT frames.
225 |
226 | >>> _samples_to_stft_frames(19, 16, 4)
227 | 2
228 | >>> _samples_to_stft_frames(20, 16, 4)
229 | 2
230 | >>> _samples_to_stft_frames(21, 16, 4)
231 | 3
232 |
233 | >>> stft(np.zeros(19), 16, 4, fading=False).shape
234 | (2, 9)
235 | >>> stft(np.zeros(20), 16, 4, fading=False).shape
236 | (2, 9)
237 | >>> stft(np.zeros(21), 16, 4, fading=False).shape
238 | (3, 9)
239 |
240 | >>> _samples_to_stft_frames(19, 16, 4, fading=True)
241 | 8
242 | >>> _samples_to_stft_frames(20, 16, 4, fading=True)
243 | 8
244 | >>> _samples_to_stft_frames(21, 16, 4, fading=True)
245 | 9
246 |
247 | >>> stft(np.zeros(19), 16, 4).shape
248 | (8, 9)
249 | >>> stft(np.zeros(20), 16, 4).shape
250 | (8, 9)
251 | >>> stft(np.zeros(21), 16, 4).shape
252 | (9, 9)
253 |
254 | >>> _samples_to_stft_frames(21, 16, 3, fading=True)
255 | 12
256 | >>> stft(np.zeros(21), 16, 3).shape
257 | (12, 9)
258 | >>> _samples_to_stft_frames(21, 16, 3)
259 | 3
260 | >>> stft(np.zeros(21), 16, 3, fading=False).shape
261 | (3, 9)
262 | """
263 | if fading:
264 | samples = samples + 2 * (size - shift)
265 |
266 | # I changed this from np.ceil to math.ceil, to yield an integer result.
267 | frames = (samples - size + shift) / shift
268 | if pad:
269 | return ceil(frames)
270 | return int(frames)
271 |
272 |
273 | def _stft_frames_to_samples(frames, size, shift):
274 | """
275 | Calculates samples in time domain from STFT frames
276 |
277 | Args:
278 | frames: Number of STFT frames.
279 | size: FFT size.
280 | shift: Hop in samples.
281 |
282 | Returns:
283 | Number of samples in time domain.
284 | """
285 | return frames * shift + size - shift
286 |
287 |
def _biorthogonal_window_brute_force(analysis_window, shift,
                                     use_amplitude=False):
    """
    Computes the biorthogonal synthesis window for an analysis window.

    The biorthogonal window (synthesis_window) must verify the criterion:
    synthesis_window * analysis_window plus it's shifts must be one.
    1 == sum m from -inf to inf over (synthesis_window(n - mB) * analysis_window(n - mB))
    B ... shift
    n ... time index
    m ... shift index

    Args:
        analysis_window: Analysis window as a 1D array.
        shift: Frame shift in samples.
        use_amplitude: If True, normalize by the summed window amplitudes
            instead of the summed squared windows.

    Returns:
        Synthesis window with the same shape as `analysis_window`.
    """
    size = len(analysis_window)

    # Number of neighbouring frames that overlap a given sample on each side.
    influence_width = (size - 1) // shift

    # Sum of the (squared) window over all overlapping shifts.
    contribution = analysis_window if use_amplitude else analysis_window ** 2
    denominator = np.zeros_like(analysis_window)
    for m in range(-influence_width, influence_width + 1):
        denominator += roll_zeropad(contribution, shift * m)

    if use_amplitude:
        return 1 / denominator
    return analysis_window / denominator
321 |
322 |
323 | _biorthogonal_window_fastest = _biorthogonal_window_brute_force
324 |
325 |
def istft(
        stft_signal,
        size=1024,
        shift=256,
        window=signal.windows.blackman,
        fading=True,
        window_length=None,
        symmetric_window=False,
):
    """
    Calculated the inverse short time Fourier transform to exactly reconstruct
    the time signal.

    Notes:
        Be careful if you make modifications in the frequency domain (e.g.
        beamforming) because the synthesis window is calculated according to
        the unmodified! analysis window.

    Args:
        stft_signal: Single channel complex STFT signal
            with dimensions (..., frames, size/2+1).
        size: Scalar FFT-size.
        shift: Scalar FFT-shift. Typically shift is a fraction of size.
        window: Window function handle, or the name of a window in
            `scipy.signal.windows`. Must match the window used in `stft`.
        fading: Removes the additional padding, if done during STFT.
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the window is padded with zeros.
            The default is to use the fft-size as a window size.
        symmetric_window: symmetric or periodic window. Assume window is
            periodic. Since the implementation of the windows in scipy.signal have a
            curious behaviour for odd window_length. Use window(len+1)[:-1]. Since
            is equal to the behaviour of MATLAB.

    Returns:
        Real valued time signal with dimensions (..., samples).
    """
    # Note: frame_axis and frequency_axis would make this function much more
    # complicated
    stft_signal = np.array(stft_signal)

    # The last axis must hold the rfft bins of an FFT of length `size`.
    assert stft_signal.shape[-1] == size // 2 + 1, str(stft_signal.shape)

    if window_length is None:
        window_length = size

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if symmetric_window:
        window = window(window_length)
    else:
        # Periodic window, see https://github.com/scipy/scipy/issues/4551
        window = window(window_length + 1)[:-1]

    # Replace the analysis window by its biorthogonal synthesis window so
    # that overlap-add reconstructs the original signal exactly.
    window = _biorthogonal_window_fastest(window, shift)

    # window = _biorthogonal_window_fastest(
    #     window, shift, use_amplitude_for_biorthogonal_window)
    # if disable_sythesis_window:
    #     window = np.ones_like(window)

    # Allocate the output: frames * shift hop samples plus the tail of the
    # last window.
    time_signal = np.zeros(
        list(stft_signal.shape[:-2]) +
        [stft_signal.shape[-2] * shift + window_length - shift]
    )

    # Get the correct view to time_signal
    time_signal_seg = segment_axis_v2(
        time_signal, window_length, shift, end=None
    )

    # Unbuffered inplace add: overlapping segments of the view alias the
    # same memory, so np.add.at is required instead of `+=`.
    np.add.at(
        time_signal_seg,
        Ellipsis,
        window * np.real(irfft(stft_signal))[..., :window_length]
    )
    # The [..., :window_length] is the inverse of the window padding in rfft.

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[
            ..., window_length - shift:time_signal.shape[-1] - (window_length - shift)]

    return time_signal
411 |
412 |
def istft_single_channel(stft_signal, size=1024, shift=256,
        window=signal.windows.blackman, fading=True, window_length=None,
        use_amplitude_for_biorthogonal_window=False,
        disable_sythesis_window=False):
    """
    Calculated the inverse short time Fourier transform to exactly reconstruct
    the time signal.

    Notes:
        Be careful if you make modifications in the frequency domain (e.g.
        beamforming) because the synthesis window is calculated according to the
        unmodified! analysis window.

    Args:
        stft_signal: Single channel complex STFT signal
            with dimensions frames times size/2+1.
        size: Scalar FFT-size.
        shift: Scalar FFT-shift. Typically shift is a fraction of size.
        window: Window function handle.
        fading: Removes the additional padding, if done during STFT.
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the window is padded with zeros.
        use_amplitude_for_biorthogonal_window: Passed on to the synthesis
            window computation.
        disable_sythesis_window: If True use a rectangular synthesis window.

    Returns:
        Single channel real valued time signal.
    """
    assert stft_signal.shape[1] == size // 2 + 1, str(stft_signal.shape)

    # Build the analysis window, zero padded up to the FFT size if a shorter
    # window length was requested.
    if window_length is None:
        analysis_window = window(size)
    else:
        analysis_window = np.pad(
            window(window_length), (0, size - window_length), mode='constant'
        )

    synthesis_window = _biorthogonal_window_fastest(
        analysis_window, shift, use_amplitude_for_biorthogonal_window)
    if disable_sythesis_window:
        synthesis_window = np.ones_like(synthesis_window)

    num_samples = stft_signal.shape[0] * shift + size - shift
    time_signal = np.zeros(num_samples)

    # Overlap-add of the windowed inverse FFT of every frame.
    for frame_index, start in enumerate(
            range(0, num_samples - size + shift, shift)):
        frame = np.real(irfft(stft_signal[frame_index]))
        time_signal[start:start + size] += synthesis_window * frame

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[size - shift:num_samples - (size - shift)]

    return time_signal
463 |
464 |
def stft_to_spectrogram(stft_signal):
    """
    Calculates the power spectrum (spectrogram) of an stft signal.
    The output is guaranteed to be real.

    Args:
        stft_signal: Complex STFT signal with dimensions
            #time_frames times #frequency_bins.

    Returns:
        Real spectrogram (real**2 + imag**2) with same dimensions as input.
    """
    real_part = stft_signal.real
    imag_part = stft_signal.imag
    return real_part * real_part + imag_part * imag_part
479 |
480 |
def spectrogram(time_signal, *args, **kwargs):
    """
    Thin wrapper of stft with power spectrum calculation.

    Args:
        time_signal: Time domain signal, forwarded to `stft`.
        *args: Positional arguments forwarded to `stft`.
        **kwargs: Keyword arguments forwarded to `stft`.

    Returns:
        Real valued power spectrum of the STFT of `time_signal`.
    """
    stft_signal = stft(time_signal, *args, **kwargs)
    return stft_to_spectrogram(stft_signal)
494 |
495 |
def spectrogram_to_energy_per_frame(spectrogram):
    """
    The energy per frame is sometimes used as an additional feature to the MFCC
    features. Here, it is calculated from the power spectrum.

    Args:
        spectrogram: Real valued power spectrum.

    Returns:
        Real valued energy per frame.
    """
    energy = np.sum(spectrogram, axis=1)

    # Replace exact zeros with eps so a subsequent log cannot yield -inf.
    return np.where(energy == 0, np.finfo(float).eps, energy)
512 |
513 |
def get_stft_center_frequencies(size=1024, sample_rate=16000):
    """
    It is often necessary to know, which center frequency is
    represented by each frequency bin index.

    Args:
        size: Scalar FFT-size.
        sample_rate: Scalar sample frequency in Hertz.

    Returns:
        Array of all relevant center frequencies (one per rfft bin,
        i.e. size/2 + 1 entries).
    """
    bin_indices = np.arange(0, size / 2 + 1)
    return bin_indices * sample_rate / size
528 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel", "Cython", "numpy", "scipy"]
3 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | addopts =
3 | --doctest-modules
4 | --doctest-continue-on-failure
5 | --junitxml=junit/test-results.xml
6 | --cov=nara_wpe
7 | --cov-report=xml
8 | --cov-report=html
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""A setuptools based setup module.

See:
https://packaging.python.org/en/latest/distributing.html
https://github.com/pypa/sampleproject
"""

import setuptools as st
from codecs import open
from os import path


here = path.abspath(path.dirname(__file__))

# Get the long description from the relevant file
with open(path.join(here, 'README.rst'), encoding='utf-8') as f:
    long_description = f.read()

st.setup(
    name='nara_wpe',

    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version='0.0.11',

    description='Weighted Prediction Error for speech dereverberation',
    long_description=long_description,

    # The project's main homepage.
    url='https://github.com/fgnt/nara_wpe',

    # Author details
    author='Department of Communications Engineering, Paderborn University',
    author_email='sek@nt.upb.de',

    # Choose your license
    license='MIT',

    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        'Development Status :: 3 - Alpha',

        # Indicate who your project is intended for
        'Intended Audience :: Developers',
        'Topic :: Software Development :: Build Tools',

        # Pick your license as you wish (should match "license" above)
        'License :: OSI Approved :: MIT License',

        # Specify the Python versions you support here. In particular, ensure
        # that you indicate whether you support Python 2, Python 3 or both.
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
        'Programming Language :: Python :: 3.11',
        'Programming Language :: Python :: 3.12',
    ],

    # Match the supported versions declared in the classifiers above and
    # tested in CI, so pip refuses to install on unsupported interpreters.
    python_requires='>=3.7',

    # What does your project relate to?
    keywords='speech',

    # You can just specify the packages manually here if your project is
    # simple. Or you can use find_packages().
    packages=st.find_packages(exclude=['contrib', 'docs', 'tests*']),

    # List run-time dependencies here. These will be installed by pip when
    # your project is installed. For an analysis of "install_requires" vs pip's
    # requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires=[
        'pathlib2;python_version<"3.0"',
        'numpy',
        # scipy is imported unconditionally at runtime (nara_wpe/utils.py),
        # so it must be a hard dependency, not only a 'test' extra.
        'scipy',
        'tqdm',
        'soundfile',
        'bottleneck',
        'click'
    ],

    # List additional groups of dependencies here (e.g. development
    # dependencies). You can install these using the following syntax,
    # for example:
    # $ pip install -e .[dev,test]
    extras_require={
        'dev': ['check-manifest'],
        'test': [
            'pytest',
            'coverage',
            'jupyter',
            'matplotlib',
            'scipy',
            'tensorflow==1.12.0;python_version<"3.7"',  # Python 3.7 has no tensorflow==1.12.0
            'pytest-cov',
            'codecov',
            'pandas',
            'torch',
            'cached_property',
            'zmq',
            'pyzmq',  # Required to install pymatbridge
            'pymatbridge',
        ],
    },

    # If there are data files included in your packages that need to be
    # installed, specify them here.  If using Python 2.6 or less, then these
    # have to be included in MANIFEST.in as well.
    # package_data={
    #     'nt_feature_extraction': ['package_data.dat'],
    # },

    # Although 'package_data' is the preferred approach, in some case you may
    # need to place data files outside of your packages. See:
    # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa
    # In this case, 'data_file' will be installed into '/my_data'
    # data_files=[('my_data', ['data/data_file'])],

    # https://stackoverflow.com/questions/2379898/make-distutils-look-for-numpy-header-files-in-the-correct-place
    # include_dirs = [np.get_include()],

    # To provide executable scripts, use entry points in preference to the
    # "scripts" keyword. Entry points provide cross-platform support and allow
    # pip to create the appropriate form of executable for the target platform.
    # entry_points={
    #     'console_scripts': [
    #         'nt_feature_extraction=nt_feature_extraction:main',
    #     ],
    # },
)
134 |
--------------------------------------------------------------------------------
/tests/test_notebooks.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import subprocess
4 | import tempfile
5 | import unittest
6 | import pytest
7 |
8 | import nbformat
9 |
10 | from nara_wpe import project_root
11 |
12 |
def _notebook_run(path):
    """Execute a notebook via nbconvert and collect output.
    :returns (parsed nb object, execution errors)
    """
    dirname = os.path.dirname(str(path))
    # NOTE(review): this changes the CWD of the whole test process and never
    # restores it, so later tests inherit the notebook's directory — verify
    # this is intended (the notebooks load data via relative paths).
    os.chdir(dirname)
    with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
        # nbconvert re-opens fout.name to write the executed notebook while
        # the file object is still open; this works on POSIX but
        # presumably not on Windows — confirm if Windows support is needed.
        args = ["jupyter", "nbconvert", "--to", "notebook", "--execute",
                "--ExecutePreprocessor.timeout=360",
                "--output", fout.name, str(path)]
        subprocess.check_call(args)

        # Read back the executed notebook from the temporary file.
        fout.seek(0)
        nb = nbformat.read(fout, nbformat.current_nbformat)

        # Collect every "error" output of every code cell.
        errors = [
            output for cell in nb.cells if "outputs" in cell
            for output in cell["outputs"]
            if output.output_type == "error"
        ]

    return nb, errors
35 |
36 |
37 | root = project_root / 'examples'
38 |
39 |
# Fixed: the skip condition was inverted (`>=` skipped exactly the versions
# the reason string says are supported) and the module-level function took a
# stray `self` parameter, which pytest treats as a missing fixture.
@pytest.mark.skipif(sys.version_info < (3, 6, 0), reason='Only with Python 3.6+')
def test_wpe_numpy_offline():
    """Execute the offline NumPy WPE demo notebook and fail on any cell error."""
    nb, errors = _notebook_run(root / 'WPE_Numpy_offline.ipynb')
    assert errors == []
44 |
45 |
# Fixed: the skip condition was inverted (`>=` skipped exactly the versions
# the reason string says are supported) and the module-level function took a
# stray `self` parameter, which pytest treats as a missing fixture.
@pytest.mark.skipif(sys.version_info < (3, 6, 0), reason='Only with Python 3.6+')
def test_wpe_numpy_online():
    """Execute the online NumPy WPE demo notebook and fail on any cell error."""
    nb, errors = _notebook_run(root / 'WPE_Numpy_online.ipynb')
    assert errors == []
50 |
51 |
@pytest.mark.skipif(not((3, 6, 0) <= sys.version_info < (3, 7, 0)), reason='Only with Python 3.6')
def test_wpe_tensorflow_offline():
    # Run the offline Tensorflow demo notebook; any cell error fails the test.
    notebook_path = root / 'WPE_Tensorflow_offline.ipynb'
    _, errors = _notebook_run(notebook_path)
    assert errors == []
56 |
57 |
@pytest.mark.skipif(not((3, 6, 0) <= sys.version_info < (3, 7, 0)), reason='Only with Python 3.6')
def test_wpe_tensorflow_online():
    # Run the online Tensorflow demo notebook; any cell error fails the test.
    notebook_path = root / 'WPE_Tensorflow_online.ipynb'
    _, errors = _notebook_run(notebook_path)
    assert errors == []
62 |
63 |
64 | # def test_NTT_wrapper(self):
65 | # nb, errors = _notebook_run(self.root / 'NTT_wrapper_offline.ipynb')
66 | # assert errors == []
67 |
--------------------------------------------------------------------------------
/tests/test_tf_wpe.py:
--------------------------------------------------------------------------------
1 | """
2 | Run individual test file with i.e.
3 | python nara_wpe/tests/test_tf_wpe.py
4 | """
5 | import sys
6 | import pytest
7 | import numpy as np
8 | from nara_wpe import wpe
9 | if sys.version_info < (3, 7):
10 | from nara_wpe import tf_wpe
11 | import tensorflow as tf
12 | else:
13 | # Dummy tf
14 | class tf:
15 | class test:
16 | class TestCase:
17 | pass
18 | from nara_wpe.test_utils import retry
19 |
20 |
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="Python 3.7+ has no tensorflow==1.12. Help wanted for tensorflow 2 support.")
class TestWPE(tf.test.TestCase):
    """Compare the TensorFlow 1.x WPE implementation against the NumPy one.

    Every test computes the same quantity with ``nara_wpe.wpe`` (NumPy) and
    ``nara_wpe.tf_wpe`` (TF graph mode) and asserts that the results agree.
    """

    def setUp(self):
        # Fixed seed so the randomized shapes and signals are reproducible.
        np.random.seed(0)
        self.T = np.random.randint(100, 120)  # number of frames
        self.D = np.random.randint(2, 8)      # number of channels
        self.K = np.random.randint(3, 5)      # number of filter taps
        self.delay = np.random.randint(0, 2)  # prediction delay
        # Complex Gaussian test signal of shape (D, T).
        self.Y = np.random.normal(size=(self.D, self.T)) \
            + 1j * np.random.normal(size=(self.D, self.T))

    def test_inverse_power(self):
        # TF inverse power estimate must match the NumPy reference exactly.
        np_inv_power = wpe.get_power_inverse(self.Y)

        with self.test_session() as sess:
            tf_signal = tf.placeholder(tf.complex128, shape=[None, None])
            tf_res = tf_wpe.get_power_inverse(tf_signal)
            tf_inv_power = sess.run(tf_res, {tf_signal: self.Y})

        np.testing.assert_allclose(np_inv_power, tf_inv_power)

    def test_correlations(self):
        # Correlation matrix and vector must match the NumPy reference.
        np_inv_power = wpe.get_power_inverse(self.Y)
        np_corr = wpe.get_correlations_narrow(
            self.Y, np_inv_power, self.K, self.delay
        )

        with tf.Graph().as_default(), tf.Session() as sess:
            tf_signal = tf.placeholder(tf.complex128, shape=[None, None])
            tf_inverse_power = tf_wpe.get_power_inverse(tf_signal)
            tf_res = tf_wpe.get_correlations_for_single_frequency(
                tf_signal, tf_inverse_power, self.K, self.delay
            )
            tf_corr = sess.run(tf_res, {tf_signal: self.Y})

        np.testing.assert_allclose(np_corr[0], tf_corr[0])
        np.testing.assert_allclose(np_corr[1], tf_corr[1])

    # Retried because the tolerance of assert_allclose can be exceeded for
    # unlucky random draws.
    @retry(5)
    def test_filter_matrix(self):
        # Filter matrix computed in TF must match the NumPy v5 reference.
        np_inv_power = wpe.get_power_inverse(self.Y)
        np_filter_matrix = wpe.get_filter_matrix_conj_v5(
            self.Y, np_inv_power, self.K, self.delay
        )

        with tf.Graph().as_default(), tf.Session() as sess:
            tf_signal = tf.placeholder(tf.complex128, shape=[None, None])
            tf_inverse_power = tf_wpe.get_power_inverse(tf_signal)
            tf_matrix, tf_vector = tf_wpe.get_correlations_for_single_frequency(
                tf_signal, tf_inverse_power, self.K, self.delay
            )
            tf_filter = tf_wpe.get_filter_matrix_conj(
                tf_signal, tf_matrix, tf_vector,
                self.K, self.delay
            )
            tf_filter_matrix, tf_inv_power_2 = sess.run(
                [tf_filter, tf_inverse_power], {tf_signal: self.Y}
            )

        np.testing.assert_allclose(np_inv_power, tf_inv_power_2)
        np.testing.assert_allclose(np_filter_matrix, tf_filter_matrix)

    @retry(5)
    def test_filter_operation(self):
        # Applying the filter in TF must match the NumPy v4 reference.
        np_inv_power = wpe.get_power_inverse(self.Y)
        np_filter_matrix = wpe.get_filter_matrix_conj_v5(
            self.Y, np_inv_power, self.K, self.delay
        )
        np_filter_op = wpe.perform_filter_operation_v4(
            self.Y, np_filter_matrix, self.K, self.delay
        )

        with tf.Graph().as_default(), tf.Session() as sess:
            tf_signal = tf.placeholder(tf.complex128, shape=[None, None])
            tf_inverse_power = tf_wpe.get_power_inverse(tf_signal)
            tf_matrix, tf_vector = tf_wpe.get_correlations_for_single_frequency(
                tf_signal, tf_inverse_power, self.K, self.delay
            )
            tf_filter = tf_wpe.get_filter_matrix_conj(
                tf_signal, tf_matrix, tf_vector,
                self.K, self.delay
            )
            tf_filter_op = tf_wpe.perform_filter_operation(
                tf_signal, tf_filter, self.K, self.delay
            )
            tf_filter_op = sess.run(tf_filter_op, {tf_signal: self.Y})

        np.testing.assert_allclose(np_filter_op, tf_filter_op)

    def test_wpe_step(self):
        # One wpe_step on a batch of one must equal single_frequency_wpe.
        with self.test_session() as sess:
            Y = tf.convert_to_tensor(self.Y[None])
            enhanced, inv_power = tf_wpe.single_frequency_wpe(
                Y[0], iterations=3
            )
            step_enhanced = tf_wpe.wpe_step(Y, inv_power[None])
            enhanced, step_enhanced = sess.run(
                [enhanced, step_enhanced]
            )
        np.testing.assert_allclose(enhanced, step_enhanced[0])

    def _get_batch_data(self):
        # Helper: builds a batch of two utterances, the second being a
        # zero-padded 20-frames-shorter copy of the first.
        Y = tf.convert_to_tensor(self.Y[None])
        inv_power = tf_wpe.get_power_inverse(Y[0])[None]
        Y_short = Y[..., :self.T-20]
        inv_power_short = inv_power[..., :self.T-20]
        Y_batch = tf.stack(
            [Y, tf.pad(Y_short, ((0, 0), (0, 0), (0, 20)))]
        )
        inv_power_batch = tf.stack(
            [inv_power, tf.pad(inv_power_short, ((0, 0), (0, 20)))]
        )
        return Y_batch, inv_power_batch

    def test_batched_wpe_step(self):
        # batched_wpe_step with per-utterance num_frames must match
        # per-utterance wpe_step results.
        with self.test_session() as sess:
            Y_batch, inv_power_batch = self._get_batch_data()
            enhanced_ref_1 = tf_wpe.wpe_step(
                Y_batch[0], inv_power_batch[0]
            )
            enhanced_ref_2 = tf_wpe.wpe_step(
                Y_batch[0, ...,:self.T-20], inv_power_batch[0, ...,:self.T-20]
            )
            step_enhanced = tf_wpe.batched_wpe_step(
                Y_batch, inv_power_batch,
                num_frames=tf.convert_to_tensor([self.T, self.T-20])
            )
            enhanced, ref1, ref2 = sess.run(
                [step_enhanced, enhanced_ref_1, enhanced_ref_2]
            )
        np.testing.assert_allclose(enhanced[0], ref1)
        np.testing.assert_allclose(enhanced[1, ..., :-20], ref2)

    def test_wpe(self):
        # Single-iteration TF WPE must match the NumPy wpe_v7 reference.
        with self.test_session() as sess:
            Y = tf.convert_to_tensor(self.Y)
            enhanced, inv_power = tf_wpe.single_frequency_wpe(
                Y, iterations=1
            )
            enhanced = sess.run(enhanced)
        ref = wpe.wpe_v7(self.Y, iterations=1, statistics_mode='valid')
        np.testing.assert_allclose(enhanced, ref)

    def test_batched_wpe(self):
        # batched_wpe must reproduce the per-utterance wpe results.
        with self.test_session() as sess:
            Y_batch, _ = self._get_batch_data()
            enhanced_ref_1 = tf_wpe.wpe(Y_batch[0])
            enhanced_ref_2 = tf_wpe.wpe(Y_batch[0, ..., :self.T-20])
            step_enhanced = tf_wpe.batched_wpe(
                Y_batch,
                num_frames=tf.convert_to_tensor([self.T, self.T-20])
            )
            enhanced, ref1, ref2 = sess.run(
                [step_enhanced, enhanced_ref_1, enhanced_ref_2]
            )
        np.testing.assert_allclose(enhanced[0], ref1)
        np.testing.assert_allclose(enhanced[1, ..., :-20], ref2)

    def test_batched_block_wpe_step(self):
        # batched_block_wpe_step must reproduce per-utterance block_wpe_step.
        with self.test_session() as sess:
            Y_batch, inv_power_batch = self._get_batch_data()
            enhanced_ref_1 = tf_wpe.block_wpe_step(
                Y_batch[0], inv_power_batch[0]
            )
            enhanced_ref_2 = tf_wpe.block_wpe_step(
                Y_batch[0, ..., :self.T-20], inv_power_batch[0, ..., :self.T-20]
            )
            step_enhanced = tf_wpe.batched_block_wpe_step(
                Y_batch, inv_power_batch,
                num_frames=tf.convert_to_tensor([self.T, self.T-20])
            )
            enhanced, ref1, ref2 = sess.run(
                [step_enhanced, enhanced_ref_1, enhanced_ref_2]
            )
        np.testing.assert_allclose(enhanced[0], ref1)
        np.testing.assert_allclose(enhanced[1, ..., :-20], ref2)

    @retry(5)
    def test_recursive_wpe(self):
        # The recursive (online) WPE using only the final filters should
        # converge to the offline wpe_step result on late frames.
        with self.test_session() as sess:
            T = 5000
            D = 2
            K = 1
            delay = 3
            # NOTE(review): `delay` (=3) is defined but both calls below pass
            # delay=D (=2). The comparison is consistent between the two
            # calls, but this looks unintended — verify.
            Y = np.random.normal(size=(D, T)) \
                + 1j * np.random.normal(size=(D, T))
            Y = tf.convert_to_tensor(Y[None])
            power = tf.reduce_mean(tf.real(Y) ** 2 + tf.imag(Y) ** 2, axis=1)
            inv_power = tf.reciprocal(power)
            step_enhanced = tf_wpe.wpe_step(
                Y, inv_power, taps=K, delay=D)
            recursive_enhanced = tf_wpe.recursive_wpe(
                tf.transpose(Y, (2, 0, 1)),
                tf.transpose(power),
                1.,
                taps=K,
                delay=D,
                only_use_final_filters=True
            )
            recursive_enhanced = tf.transpose(recursive_enhanced, (1, 2, 0))
            recursive_enhanced, step_enhanced = sess.run(
                [recursive_enhanced, step_enhanced]
            )
        # Only the last frames are compared; early frames differ while the
        # recursive filter is still adapting.
        np.testing.assert_allclose(
            recursive_enhanced[..., -200:],
            step_enhanced[..., -200:],
            atol=0.01, rtol=0.2
        )

    def test_batched_recursive_wpe(self):
        # batched_recursive_wpe must reproduce per-utterance recursive_wpe.
        with self.test_session() as sess:
            Y_batch, inv_power_batch = self._get_batch_data()
            Y_batch = tf.transpose(Y_batch, (0, 3, 1, 2))
            inv_power_batch = tf.transpose(inv_power_batch, (0, 2, 1))
            enhanced_ref_1 = tf_wpe.recursive_wpe(
                Y_batch[0], inv_power_batch[0], 0.999
            )
            enhanced_ref_2 = tf_wpe.recursive_wpe(
                Y_batch[0, :self.T-20], inv_power_batch[0, :self.T-20],
                0.999
            )
            step_enhanced = tf_wpe.batched_recursive_wpe(
                Y_batch, inv_power_batch, 0.999,
                num_frames=tf.convert_to_tensor([self.T, self.T-20])
            )
            enhanced, ref1, ref2 = sess.run(
                [step_enhanced, enhanced_ref_1, enhanced_ref_2]
            )
        np.testing.assert_allclose(enhanced[0], ref1)
        np.testing.assert_allclose(enhanced[1, :-20], ref2)
252 |
253 | if __name__ == '__main__':
254 | if sys.version_info < (3, 7):
255 | tf.test.main()
256 |
--------------------------------------------------------------------------------
/tests/test_wpe.py:
--------------------------------------------------------------------------------
1 | """
2 | Run all tests with:
3 | nosetests -w tests/
4 | """
5 |
6 | import unittest
7 | import numpy.testing as tc
8 | import numpy as np
9 | from nara_wpe import wpe
10 | from nara_wpe.test_utils import retry
11 |
12 |
class TestWPE(unittest.TestCase):
    """Consistency tests for the WPE dereverberation implementations in
    ``nara_wpe.wpe`` (v0 ... v8).

    Each test checks that two implementation variants produce (numerically)
    identical results on a random complex observation, or that a documented
    difference (``statistics_mode``) actually shows up.
    """

    def setUp(self):
        # Randomized problem sizes so repeated runs cover varying shapes:
        # T frames, D channels, K filter taps and the prediction delay.
        self.T = np.random.randint(100, 120)
        self.D = np.random.randint(2, 6)
        self.K = np.random.randint(3, 5)
        self.delay = np.random.randint(1, 3)
        # Complex Gaussian observation of shape (D, T).
        self.Y = np.random.normal(size=(self.D, self.T)) \
            + 1j * np.random.normal(size=(self.D, self.T))

    def test_correlations_v1_vs_v2_toy_example(self):
        """v1 and v2 correlation estimates agree on a tiny fixed input."""
        K = 3
        delay = 1
        Y = np.asarray(
            [
                [11, 12, 13, 14],
                [41, 22, 23, 24]
            ], dtype=np.float32
        )
        inverse_power = wpe.get_power_inverse(Y)
        R_desired, r_desired = wpe.get_correlations(Y, inverse_power, K, delay)
        R_actual, r_actual = wpe.get_correlations_v2(Y, inverse_power, K, delay)
        tc.assert_allclose(R_actual, R_desired)
        tc.assert_allclose(r_actual, r_desired)

    def test_correlations_v1_vs_v2(self):
        """v1 and v2 correlation estimates agree on random complex input."""
        inverse_power = wpe.get_power_inverse(self.Y)
        R_desired, r_desired = wpe.get_correlations(
            self.Y, inverse_power, self.K, self.delay
        )
        R_actual, r_actual = wpe.get_correlations_v2(
            self.Y, inverse_power, self.K, self.delay
        )
        tc.assert_allclose(R_actual, R_desired)
        tc.assert_allclose(r_actual, r_desired)

    @retry(5)
    def test_filter_operation_v1_vs_v4(self):
        """Applying a random filter with v1 and v4 yields the same signal."""
        filter_matrix_conj = np.random.normal(size=(self.K, self.D, self.D)) \
            + 1j * np.random.normal(size=(self.K, self.D, self.D))

        desired = wpe.perform_filter_operation(
            self.Y, filter_matrix_conj, self.K, self.delay
        )
        actual = wpe.perform_filter_operation_v4(
            self.Y, filter_matrix_conj, self.K, self.delay
        )
        tc.assert_allclose(desired, actual)

    def test_correlations_narrow_v1_vs_v5(self):
        """Narrow-band correlation estimates agree between v1 and v5."""
        inverse_power = wpe.get_power_inverse(self.Y)
        R_desired, r_desired = wpe.get_correlations_narrow(
            self.Y, inverse_power, self.K, self.delay
        )
        R_actual, r_actual = wpe.get_correlations_narrow_v5(
            self.Y, inverse_power, self.K, self.delay
        )
        tc.assert_allclose(R_actual, R_desired)
        tc.assert_allclose(r_actual, r_desired)

    def test_correlations_narrow_v1_vs_v6(self):
        """v6 (built on Y_tilde) matches the narrow v1 estimate up to
        conjugation and a reshape of the correlation vector."""
        inverse_power = wpe.get_power_inverse(self.Y)
        R_desired, r_desired = wpe.get_correlations_narrow(
            self.Y, inverse_power, self.K, self.delay
        )

        # v6 consumes only the frames where the full filter context exists.
        s = (Ellipsis, slice(self.delay + self.K - 1, None))
        Y_tilde = wpe.build_y_tilde(self.Y, self.K, self.delay)
        R_actual, r_actual = wpe.get_correlations_v6(
            self.Y[s], Y_tilde[s], inverse_power[s]
        )
        tc.assert_allclose(R_actual.conj(), R_desired)
        tc.assert_allclose(
            r_actual.conj(),
            np.swapaxes(r_desired, 1, 2).reshape(-1, r_desired.shape[-1]),
            rtol=1e-5, atol=1e-5
        )

    @retry(5)
    def test_filter_matrix_conj_v1_vs_v5(self):
        """Filter matrices from v1 and v5 agree."""
        inverse_power = wpe.get_power_inverse(self.Y)

        correlation_matrix, correlation_vector = wpe.get_correlations(
            self.Y, inverse_power, self.K, self.delay
        )
        desired = wpe.get_filter_matrix_conj(
            correlation_matrix, correlation_vector, self.K, self.D
        )
        actual = wpe.get_filter_matrix_conj_v5(
            self.Y, inverse_power, self.K, self.delay
        )
        tc.assert_allclose(actual, desired, atol=1e-6)

    @retry(5)
    def test_filter_matrix_conj_v1_vs_v7(self):
        """Filter matrices from v1 and v7 agree (delay fixed to 0, see ToDo)."""
        # ToDo: Fix 'wpe.get_correlations' to consider delay correctly at the
        # begin and end of the utterance.
        # delay = self.delay
        delay = 0

        inverse_power = wpe.get_power_inverse(self.Y)

        correlation_matrix, correlation_vector = wpe.get_correlations(
            self.Y, inverse_power, self.K, delay
        )
        desired = wpe.get_filter_matrix_conj(
            correlation_matrix, correlation_vector, self.K, self.D
        )

        Y_tilde = wpe.build_y_tilde(self.Y, self.K, delay)
        actual = wpe.get_filter_matrix_v7(
            self.Y, Y_tilde=Y_tilde, inverse_power=inverse_power,
        )
        # v7 returns the conjugated, differently laid-out filter; undo both.
        tc.assert_allclose(
            actual.conj(),
            np.swapaxes(desired, 1, 2).reshape(-1, desired.shape[-1]),
            atol=1e-6
        )

    @retry(5)
    def test_delay_zero_cancels_all(self):
        """With delay 0 WPE can predict the signal perfectly, so everything
        after the filter warm-up region is cancelled to (numerical) zero."""
        delay = 0
        X_hat = wpe.wpe(self.Y, self.K, delay=delay)

        # Beginning is never zero. It is a copy of input signal.
        tc.assert_allclose(
            X_hat[:, delay + self.K - 1:],
            np.zeros_like(X_hat[:, delay + self.K - 1:]),
            atol=1e-6
        )

    @retry(5)
    def test_wpe_v0_vs_v7(self):
        """v0 and v7 agree for both statistics modes; the two modes differ."""
        # ToDo: Fix 'wpe.wpe_v0' to consider delay correctly at the
        # begin and end of the utterance.
        # delay = self.delay
        delay = 0

        desired = wpe.wpe_v0(self.Y, self.K, delay, statistics_mode='full')
        actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v0(self.Y, self.K, delay, statistics_mode='valid')
        actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='valid')
        tc.assert_allclose(actual, desired, atol=1e-6)

        # Sanity check: 'valid' and 'full' must not coincide.
        desired = wpe.wpe_v6(self.Y, self.K, delay, statistics_mode='valid')
        actual = wpe.wpe_v7(self.Y, self.K, delay, statistics_mode='full')
        tc.assert_raises(AssertionError, tc.assert_array_equal, desired, actual)

    @retry(5)
    def test_wpe_v6_vs_v7(self):
        """v6 and v7 agree for both statistics modes; the two modes differ."""
        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
        actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid')
        actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='valid')
        tc.assert_allclose(actual, desired, atol=1e-6)

        # Sanity check: 'valid' and 'full' must not coincide.
        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid')
        actual = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
        tc.assert_raises(AssertionError, tc.assert_array_equal, desired, actual)

    @retry(5)
    def test_wpe_v8(self):
        """v8 agrees with both v6 and v7 in both statistics modes."""
        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='valid')
        actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='valid')
        actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='valid')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
        actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
        actual = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

    @retry(5)
    def test_wpe_multi_freq(self):
        """Stacking the same signal on a leading frequency axis yields the
        single-frequency result for each entry (v0, v7, v6, v8)."""
        desired = wpe.wpe_v0(self.Y, self.K, self.delay, statistics_mode='full')
        desired = [desired, desired]
        actual = wpe.wpe_v0(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
        desired = [desired, desired]
        actual = wpe.wpe_v7(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
        desired = [desired, desired]
        actual = wpe.wpe_v6(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
        desired = [desired, desired]
        actual = wpe.wpe_v8(np.array([self.Y, self.Y]), self.K, self.delay, statistics_mode='full')
        tc.assert_allclose(actual, desired, atol=1e-6)

    @retry(5)
    def test_wpe_batched_multi_freq(self):
        """(batch, freq, D, T) input is handled independently per entry by
        v6/v7/v8, while v0 explicitly rejects it."""
        def to_batched_multi_freq(x):
            # Scaled copies: WPE is scale-invariant per entry, so each batch
            # entry must reproduce the (scaled) single-signal result.
            return np.array([
                [x, x*2],
                [x*3, x*4],
                [x*5, x*6],
            ])
        Y_batched_multi_freq = to_batched_multi_freq(self.Y)

        tc.assert_raises(NotImplementedError, wpe.wpe_v0, Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')

        desired = wpe.wpe_v7(self.Y, self.K, self.delay, statistics_mode='full')
        desired = to_batched_multi_freq(desired)
        actual = wpe.wpe_v7(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
        assert desired.shape == (3, 2, self.D, self.T)
        assert actual.shape == (3, 2, self.D, self.T)
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v6(self.Y, self.K, self.delay, statistics_mode='full')
        desired = to_batched_multi_freq(desired)
        actual = wpe.wpe_v6(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
        assert desired.shape == (3, 2, self.D, self.T)
        assert actual.shape == (3, 2, self.D, self.T)
        tc.assert_allclose(actual, desired, atol=1e-6)

        desired = wpe.wpe_v8(self.Y, self.K, self.delay, statistics_mode='full')
        desired = to_batched_multi_freq(desired)
        actual = wpe.wpe_v8(Y_batched_multi_freq, self.K, self.delay, statistics_mode='full')
        assert desired.shape == (3, 2, self.D, self.T)
        assert actual.shape == (3, 2, self.D, self.T)
        tc.assert_allclose(actual, desired, atol=1e-6)
250 |
--------------------------------------------------------------------------------