├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── python-package.yml
├── README.md
├── doc
│   ├── Makefile
│   ├── make.bat
│   └── source
│       ├── conf.py
│       ├── index.rst
│       └── modules
│           ├── methods.rst
│           └── metrics.rst
├── main.py
├── readthedocs.yml
├── requirements.txt
├── setup.py
├── test
│   ├── __init__.py
│   ├── consts.py
│   ├── test_admm.py
│   ├── test_griffin.py
│   ├── test_lbfgs.py
│   └── test_rtisila.py
└── torch_specinv
    ├── __init__.py
    ├── methods.py
    ├── metrics.py
    └── utils.py

/.github/FUNDING.yml:
--------------------------------------------------------------------------------
 1 | # These are supported funding model platforms
 2 | 
 3 | github: yoyololicon
 4 | custom: PayPal.Me/iamycy
 5 | 
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 3 | 
 4 | name: Python package
 5 | 
 6 | on:
 7 |   push:
 8 |     branches: [ master ]
 9 |   pull_request:
10 |     branches: [ master ]
11 | 
12 | jobs:
13 |   build:
14 | 
15 |     runs-on: ${{ matrix.os }}
16 |     strategy:
17 |       matrix:
18 |         os: [ubuntu-latest, macos-latest]
19 |         python-version: [3.7, 3.8, 3.9]
20 |     env:
21 |       OS: ${{ matrix.os }}
22 |       PYTHON: ${{ matrix.python-version }}
23 |     steps:
24 |     - uses: actions/checkout@v2
25 |     - name: Set up Python ${{ matrix.python-version }}
26 |       uses: actions/setup-python@v2
27 |       with:
28 |         python-version: ${{ matrix.python-version }}
29 |     - name: Install dependencies
30 |       run: |
31 |         python -m pip install --upgrade pip
32 |         python -m pip install flake8 pytest pytest-cov
33 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
34 |     - name: Lint with flake8
35 |       run: |
36 |         # stop the build if there are Python syntax errors or undefined names
37 |         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
38 |         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
39 |         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
40 |     - name: Test with pytest
41 |       run: |
42 |         pytest --cov-report=xml --cov-report term-missing:skip-covered --cov=torch_specinv
43 |     - name: Codecov
44 |       uses: codecov/codecov-action@v3
45 |       with:
46 |         flags: unittests
47 |         env_vars: OS,PYTHON
48 |         name: codecov-umbrella
49 |         fail_ci_if_error: true
50 |         verbose: false
51 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # PyTorch Spectrogram Inversion Documentation
 2 | 
 3 | A major direction of deep learning in audio, especially for generative models, is to work with frequency-domain features,
 4 | because directly modeling the raw time signal is hard.
 5 | But this requires an extra step to convert the predicted spectrogram (magnitude-only in most situations) back to the time domain.
 6 | 
 7 | To spare researchers this post-processing step, this package provides several useful and classic spectrogram
 8 | inversion algorithms. These algorithms were selected for their performance and high parallelizability, and they can even
 9 | be integrated into your model training process.
10 | 
11 | We hope this tool can serve as a standard, enabling fair comparison of different audio generation models.
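Because the estimators are built from differentiable PyTorch operations, a time-domain loss can back-propagate through the inversion into a predicted spectrogram, as the unit tests verify. A minimal sketch (the random signal, FFT size, and iteration count below are only illustrative):

```python
import torch
import torch.nn.functional as F
from torch_specinv import griffin_lim

target = torch.randn(4410)                        # stand-in for a reference waveform
pred_mag = torch.stft(torch.randn(4410), 512,
                      return_complex=True).abs()  # stand-in for a model's predicted magnitude
pred_mag.requires_grad_()

yhat = griffin_lim(pred_mag, max_iter=4)
loss = F.mse_loss(yhat, target[:yhat.shape[0]])
loss.backward()                                   # pred_mag.grad is now populated
```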
12 | 
13 | ## Installation
14 | 
15 | ### PyPi
16 | 
17 | First [Install PyTorch](https://pytorch.org/get-started/locally/) with the desired cpu/gpu support and version >= 1.6.0.
18 | Then install via pip
19 | ```
20 | pip install torch_specinv
21 | ```
22 | or
23 | ```
24 | pip install git+https://github.com/yoyololicon/spectrogram-inversion
25 | ```
26 | to get the latest version.
27 | 
28 | 
29 | ## Getting Started
30 | The following example estimates the time signal given only the magnitude information of an audio file.
31 | 
32 | ```python
33 | import torch
34 | import librosa
35 | from torch_specinv import griffin_lim
36 | from torch_specinv.metrics import sc as SC
37 | 
38 | y, sr = librosa.load(librosa.util.example_audio_file())
39 | y = torch.from_numpy(y)
40 | windowsize = 2048
41 | window = torch.hann_window(windowsize)
42 | S = torch.stft(y, windowsize, window=window, return_complex=True)
43 | 
44 | # discard phase information
45 | mag = S.abs()
46 | 
47 | # move to gpu memory for faster computation
48 | mag = mag.cuda()
49 | window = window.cuda()
50 | yhat = griffin_lim(mag, max_iter=100, alpha=0.3, window=window)
51 | 
52 | # check convergence
53 | mag_hat = torch.stft(yhat, windowsize, window=window, return_complex=True).abs()
54 | print(SC(mag_hat, mag))
55 | ```
56 | 
57 | Reconstruct from other spectral representations:
58 | 
59 | ```python
60 | from librosa.filters import mel
61 | from torch_specinv import L_BFGS
62 | 
63 | filter_banks = torch.from_numpy(mel(sr=sr, n_fft=windowsize)).cuda()
64 | 
65 | def trsfn(x):
66 |     S = torch.stft(x, windowsize, window=window, return_complex=True).abs()
67 |     mel_S = filter_banks @ S
68 |     return torch.log1p(mel_S)
69 | 
70 | y = y.cuda()
71 | mag = trsfn(y)
72 | yhat = L_BFGS(mag, trsfn, samples=[len(y)])
73 | ```
74 | 
75 | ## TODO
76 | - [ ] Speed comparison on GPU.
77 | - [x] Documentation.
78 | - [ ] Examples.
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
 1 | # Minimal makefile for Sphinx documentation
 2 | #
 3 | 
 4 | # You can set these variables from the command line, and also
 5 | # from the environment for the first two.
 6 | SPHINXOPTS    ?=
 7 | SPHINXBUILD   ?= sphinx-build
 8 | SOURCEDIR     = source
 9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/doc/make.bat:
--------------------------------------------------------------------------------
 1 | @ECHO OFF
 2 | 
 3 | pushd %~dp0
 4 | 
 5 | REM Command file for Sphinx documentation
 6 | 
 7 | if "%SPHINXBUILD%" == "" (
 8 | 	set SPHINXBUILD=sphinx-build
 9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 | 
13 | if "%1" == "" goto help
14 | 
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /doc/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | #import pytorch_sphinx_theme 16 | 17 | sys.path.insert(0, os.path.abspath('../..')) 18 | # sys.path.insert(0, os.path.abspath('../../torch_specinv')) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'torch_specinv' 24 | copyright = '2019, Chin Yun Yu' 25 | author = 'Chin Yun Yu' 26 | 27 | # The full version, including alpha/beta/rc tags 28 | release = '0.1' 29 | 30 | # -- General configuration --------------------------------------------------- 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = ['sphinx.ext.autodoc', 36 | 'sphinx.ext.doctest', 37 | 'sphinx.ext.intersphinx', 38 | 'sphinx.ext.todo', 39 | 'sphinx.ext.coverage', 40 | 'sphinx.ext.viewcode', 41 | 'sphinx.ext.githubpages', 42 | 'sphinx.ext.mathjax', 43 | 'sphinx.ext.napoleon'] 44 | 45 | # Add any paths that contain templates here, relative to this directory. 46 | templates_path = ['_templates'] 47 | 48 | # List of patterns, relative to source directory, that match files and 49 | # directories to ignore when looking for source files. 50 | # This pattern also affects html_static_path and html_extra_path. 51 | exclude_patterns = [] 52 | 53 | # -- Options for HTML output ------------------------------------------------- 54 | 55 | # The theme to use for HTML and HTML Help pages. See the documentation for 56 | # a list of builtin themes. 57 | # 58 | #html_theme = 'alabaster' 59 | #html_theme = 'pytorch_sphinx_theme' 60 | html_theme = 'sphinx_rtd_theme' 61 | # html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 62 | 63 | 64 | # Theme options are theme-specific and customize the look and feel of a theme 65 | # further. For a list of options available for each theme, see the 66 | # documentation. 67 | # 68 | 69 | 70 | html_theme_options = { 71 | 'collapse_navigation': False, 72 | 'display_version': True, 73 | 'logo_only': True, 74 | } 75 | 76 | 77 | # Add any paths that contain custom static files (such as style sheets) here, 78 | # relative to this directory. They are copied after the builtin static files, 79 | # so a file named "default.css" will overwrite the builtin "default.css". 
80 | html_static_path = ['_static']
81 | 
82 | 
83 | intersphinx_mapping = {
84 |     'python': ('https://docs.python.org/', None),
85 |     'torch': ('https://pytorch.org/docs/master/', None)
86 | }
87 | 
88 | autodoc_mock_imports = ['torch', 'tqdm']
89 | 
90 | # prevent contents.rst error on readthedocs
91 | master_doc = 'index'
--------------------------------------------------------------------------------
/doc/source/index.rst:
--------------------------------------------------------------------------------
 1 | .. torch_specinv documentation master file, created by
 2 |    sphinx-quickstart on Thu Oct 10 21:52:34 2019.
 3 |    You can adapt this file completely to your liking, but it should at least
 4 |    contain the root `toctree` directive.
 5 | 
 6 | :github_url: https://github.com/yoyololicon/spectrogram-inversion
 7 | 
 8 | PyTorch Spectrogram Inversion Documentation
 9 | ============================================
10 | 
11 | A major direction of deep learning in audio, especially for generative models, is to work with frequency-domain features,
12 | because directly modeling the raw time signal is hard.
13 | But this requires an extra step to convert the predicted spectrogram (magnitude-only in most situations) back to the time domain.
14 | 
15 | To spare researchers this post-processing step, this package provides several useful and classic spectrogram
16 | inversion algorithms. These algorithms were selected for their performance and high parallelizability, and they can even
17 | be integrated into your model training process.
18 | 
19 | We hope this tool can serve as a standard, enabling fair comparison of different audio generation models.
20 | 
21 | Installation
22 | ============
23 | 
24 | PyPi
25 | ~~~~
26 | 
27 | First `Install PyTorch <https://pytorch.org/get-started/locally/>`_ with the desired cpu/gpu support and version >= 1.6.0.
28 | Then install via pip::
29 | 
30 |     pip install torch_specinv
31 | 
32 | or::
33 | 
34 |     pip install git+https://github.com/yoyololicon/spectrogram-inversion
35 | 
36 | to get the latest version.
37 | 
38 | 
39 | 
40 | Getting Started
41 | ===============
42 | The following example estimates the time signal given only the magnitude information of an audio file.
43 | 
44 | .. code-block:: python
45 | 
46 |     import torch
47 |     import librosa
48 |     from torch_specinv import griffin_lim
49 |     from torch_specinv.metrics import sc as SC
50 | 
51 |     y, sr = librosa.load(librosa.util.example_audio_file())
52 |     y = torch.from_numpy(y)
53 |     windowsize = 2048
54 |     window = torch.hann_window(windowsize)
55 |     S = torch.stft(y, windowsize, window=window, return_complex=True)
56 | 
57 |     # discard phase information
58 |     mag = S.abs()
59 | 
60 |     # move to gpu memory for faster computation
61 |     mag = mag.cuda()
62 |     window = window.cuda()
63 |     yhat = griffin_lim(mag, max_iter=100, alpha=0.3, window=window)
64 | 
65 |     # check convergence
66 |     mag_hat = torch.stft(yhat, windowsize, window=window, return_complex=True).abs()
67 |     print(SC(mag_hat, mag))
68 | 
69 | Reconstruct from other spectral representations:
70 | 
71 | .. code-block:: python
72 | 
73 |     from librosa.filters import mel
74 |     from torch_specinv import L_BFGS
75 | 
76 |     filter_banks = torch.from_numpy(mel(sr=sr, n_fft=windowsize)).cuda()
77 | 
78 |     def trsfn(x):
79 |         S = torch.stft(x, windowsize, window=window, return_complex=True).abs()
80 |         mel_S = filter_banks @ S
81 |         return torch.log1p(mel_S)
82 | 
83 |     y = y.cuda()
84 |     mag = trsfn(y)
85 |     yhat = L_BFGS(mag, trsfn, samples=[len(y)])
86 | 
87 | .. 
toctree:: 88 | :glob: 89 | :maxdepth: 1 90 | :caption: Package Reference 91 | 92 | modules/methods 93 | modules/metrics 94 | 95 | Indices and Tables 96 | ================== 97 | 98 | * :ref:`genindex` 99 | * :ref:`modindex` -------------------------------------------------------------------------------- /doc/source/modules/methods.rst: -------------------------------------------------------------------------------- 1 | torch_specinv.methods 2 | ====================== 3 | 4 | .. automodule:: torch_specinv.methods 5 | :members: 6 | :special-members: 7 | :undoc-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /doc/source/modules/metrics.rst: -------------------------------------------------------------------------------- 1 | torch_specinv.metrics 2 | ====================== 3 | 4 | .. automodule:: torch_specinv.metrics 5 | :members: 6 | :special-members: 7 | :undoc-members: 8 | :show-inheritance: 9 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from librosa import display 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from functools import partial 6 | 7 | import torch 8 | from torch_specinv import * 9 | 10 | if __name__ == '__main__': 11 | 12 | nfft = 1024 13 | winsize = 1024 14 | hopsize = 128 15 | 16 | y, sr = librosa.load(librosa.util.example_audio_file(), duration=30) 17 | # librosa.output.write_wav('origin.wav', y, sr) 18 | y = torch.Tensor(y).cuda() 19 | window = torch.hann_window(winsize).cuda() 20 | 21 | 22 | def spectrogram(x, *args, p=1, **kwargs): 23 | return torch.stft(x, *args, **kwargs).pow(2).sum(2).add_(1e-7).pow(p / 2) 24 | 25 | 26 | arg_dict = { 27 | 'win_length': winsize, 28 | 'window': window, 29 | 'hop_length': hopsize, 30 | 'pad_mode': 'reflect', 31 | 'onesided': True, 32 | 'normalized': False, 33 | 'center': True 34 | } 35 | 36 | #spec = spectrogram(y, nfft, **arg_dict) 37 | func = partial(spectrogram, n_fft=nfft, **arg_dict) 38 | spec = func(y) 39 | # mag = spec.pow(0.5).cpu().numpy() 40 | # phase = np.random.uniform(-np.pi, np.pi, mag.shape) 41 | # _, init_x = istft(mag * np.exp(1j * phase), noverlap=1024 - 256) 42 | 43 | estimated = L_BFGS(spec, func, [len(y)], max_iter=50, lr=1, history_size=10, eva_iter=5) 44 | #estimated = griffin_lim(spec, max_iter=100, alpha=0.3, **arg_dict) 45 | #estimated = ADMM(spec, max_iter=100, rho=0.2, **arg_dict) 46 | # arg_dict['hop_length'] = 333 47 | # estimated = RTISI_LA(spec, maxiter=4, look_ahead=3, asymmetric_window=True, **arg_dict) 48 | #estimated = SPSI(spec, **arg_dict) 49 | # arg_dict.pop('window') 50 | # estimated = PGHI(spec, **arg_dict) 51 | estimated_spec = func(estimated) 52 | #estimated_spec = estimated.pow(2).sum(2).sqrt() 53 | display.specshow(librosa.amplitude_to_db(estimated_spec.cpu().numpy(), ref=np.max), y_axis='log') 54 | plt.show() 55 | 56 | #print(spectral_convergence(estimated_spec, spec)) 57 | 58 | #librosa.output.write_wav('test.wav', estimated.cpu().numpy(), sr) -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.6 8 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch>=1.6.0 2 | tqdm -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | import setuptools 3 | from torch_specinv import __version__, __email__, name, __maintainer__ 4 | 5 | with open("README.md", "r") as fh: 6 | long_description = fh.read() 7 | 8 | setuptools.setup( 9 | name=name, 10 | version=__version__, 11 | author=__maintainer__, 12 | author_email=__email__, 13 | description="A pytorch package for Spectrogram Inversion", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/yoyololicon/spectrogram-inversion", 17 | packages=["torch_specinv"], 18 | install_requires=['torch>=1.6.0', 'tqdm'], 19 | classifiers=[ 20 | "Programming Language :: Python :: 3", 21 | "License :: OSI Approved :: MIT License", 22 | "Operating System :: OS Independent", 23 | ], 24 | ) 25 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/spectrogram-inversion/82fa16b65928c33c94ef13f34bf42f6d57960be4/test/__init__.py -------------------------------------------------------------------------------- /test/consts.py: -------------------------------------------------------------------------------- 1 | nfft_list = [ 2 | 128, 256, 512 3 | ] 4 | -------------------------------------------------------------------------------- /test/test_admm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytest 4 | from torch_specinv.methods import ADMM 5 | 6 | from .consts import nfft_list 7 | 8 | 9 | @pytest.mark.parametrize("x_sizes", [(4410,), (2, 4410), (1, 4410)]) 10 | @pytest.mark.parametrize("device", ["cpu"]) 11 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 12 | @pytest.mark.parametrize("nfft", nfft_list) 13 | def test_empty_args(x_sizes, device, dtype, nfft): 14 | x = torch.randn(*x_sizes, device=device, dtype=dtype) 15 | spec = torch.stft(x, nfft, return_complex=True) 16 | y = ADMM(spec.abs(), max_iter=4) 17 | assert len(y.shape) == len(x.shape) 18 | if len(y.shape) > 1: 19 | assert y.shape[0] == x.shape[0] 20 | assert y.shape[1] <= x.shape[1] 21 | return 22 | 23 | 24 | @pytest.mark.parametrize("win_length, window", [(None, None), 25 | (300, None), 26 | (300, torch.hann_window(300))]) 27 | @pytest.mark.parametrize("hop_length", [None, 128]) 28 | @pytest.mark.parametrize("center", [True, False]) 29 | @pytest.mark.parametrize("normalized", [False, True]) 30 | @pytest.mark.parametrize("onesided", [False, True]) 31 | @pytest.mark.parametrize("pad_mode", ["reflect", "constant", "replicate", "circular"]) 32 | @pytest.mark.parametrize("return_complex", [True, False]) 33 | def test_stft_args( 34 | win_length, 35 | window, 36 | hop_length, 37 | center, 38 | normalized, 39 | onesided, 40 | pad_mode, 41 | return_complex): 42 | x = torch.randn(4410) 43 | n_fft = 512 44 | spec = torch.stft(x, n_fft, 45 | hop_length=hop_length, 46 | win_length=win_length, 47 | window=window, 48 | center=center, 49 | pad_mode=pad_mode, 50 | normalized=normalized, 51 | onesided=onesided, 52 | return_complex=True).abs() 53 | 54 | spec.requires_grad = True 55 | y = ADMM(spec, max_iter=2, 56 | 
hop_length=hop_length, 57 | win_length=win_length, 58 | window=window, 59 | center=center, 60 | pad_mode=pad_mode, 61 | normalized=normalized, 62 | onesided=onesided, 63 | return_complex=return_complex) 64 | 65 | loss = F.mse_loss(x[:y.shape[0]], y) 66 | loss.backward() 67 | assert hasattr(spec, "grad") 68 | return 69 | -------------------------------------------------------------------------------- /test/test_griffin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytest 4 | from torch_specinv.methods import griffin_lim 5 | 6 | from .consts import nfft_list 7 | 8 | 9 | @pytest.mark.parametrize("x_sizes", [(4410,), (2, 4410), (1, 4410)]) 10 | @pytest.mark.parametrize("device", ["cpu"]) 11 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 12 | @pytest.mark.parametrize("nfft", nfft_list) 13 | def test_empty_args(x_sizes, device, dtype, nfft): 14 | x = torch.randn(*x_sizes, device=device, dtype=dtype) 15 | spec = torch.stft(x, nfft, return_complex=True) 16 | y = griffin_lim(spec.abs(), max_iter=4) 17 | assert len(y.shape) == len(x.shape) 18 | if len(y.shape) > 1: 19 | assert y.shape[0] == x.shape[0] 20 | assert y.shape[1] <= x.shape[1] 21 | return 22 | 23 | 24 | @pytest.mark.parametrize("win_length, window", [(None, None), 25 | (300, None), 26 | (300, torch.hann_window(300))]) 27 | @pytest.mark.parametrize("hop_length", [None, 128]) 28 | @pytest.mark.parametrize("center", [True, False]) 29 | @pytest.mark.parametrize("normalized", [False, True]) 30 | @pytest.mark.parametrize("onesided", [False, True]) 31 | @pytest.mark.parametrize("pad_mode", ["reflect", "constant", "replicate", "circular"]) 32 | @pytest.mark.parametrize("return_complex", [True, False]) 33 | def test_stft_args( 34 | win_length, 35 | window, 36 | hop_length, 37 | center, 38 | normalized, 39 | onesided, 40 | pad_mode, 41 | return_complex): 42 | x = torch.randn(4410) 43 | n_fft = 512 44 | spec = torch.stft(x, n_fft, 45 | hop_length=hop_length, 46 | win_length=win_length, 47 | window=window, 48 | center=center, 49 | pad_mode=pad_mode, 50 | normalized=normalized, 51 | onesided=onesided, 52 | return_complex=True).abs() 53 | 54 | spec.requires_grad = True 55 | y = griffin_lim(spec, max_iter=2, 56 | hop_length=hop_length, 57 | win_length=win_length, 58 | window=window, 59 | center=center, 60 | pad_mode=pad_mode, 61 | normalized=normalized, 62 | onesided=onesided, 63 | return_complex=return_complex) 64 | 65 | loss = F.mse_loss(x[:y.shape[0]], y) 66 | loss.backward() 67 | assert hasattr(spec, "grad") 68 | return 69 | -------------------------------------------------------------------------------- /test/test_lbfgs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytest 4 | from torch_specinv.methods import L_BFGS 5 | 6 | from .consts import nfft_list 7 | 8 | 9 | @pytest.mark.parametrize("x_sizes", [(4410,), (2, 4410), (1, 4410)]) 10 | @pytest.mark.parametrize("device", ["cpu"]) 11 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 12 | @pytest.mark.parametrize("nfft", nfft_list) 13 | @pytest.mark.parametrize("metric", ['sc', 'snr', 'ser']) 14 | def test_args(x_sizes, device, dtype, nfft, metric): 15 | x = torch.randn(*x_sizes, device=device, dtype=dtype) 16 | 17 | def trsfn(x): 18 | return torch.stft(x, nfft, return_complex=True).abs() 19 | 20 | spec = trsfn(x) 21 | 22 | y = L_BFGS(spec, trsfn, 
samples=x.shape, max_iter=10, metric=metric, eva_iter=3) 23 | assert len(y.shape) == len(x.shape) 24 | return 25 | -------------------------------------------------------------------------------- /test/test_rtisila.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytest 4 | from torch_specinv.methods import RTISI_LA 5 | 6 | from .consts import nfft_list 7 | 8 | 9 | @pytest.mark.parametrize("x_sizes", [(4410,), (2, 4410), (1, 4410)]) 10 | @pytest.mark.parametrize("device", ["cpu"]) 11 | @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) 12 | @pytest.mark.parametrize("nfft", nfft_list) 13 | def test_empty_args(x_sizes, device, dtype, nfft): 14 | x = torch.randn(*x_sizes, device=device, dtype=dtype) 15 | spec = torch.stft(x, nfft, return_complex=True) 16 | y = RTISI_LA(spec.abs(), max_iter=4) 17 | assert len(y.shape) == len(x.shape) 18 | if len(y.shape) > 1: 19 | assert y.shape[0] == x.shape[0] 20 | assert y.shape[1] <= x.shape[1] 21 | return 22 | 23 | 24 | @pytest.mark.parametrize("look_ahead", [-1, 2]) 25 | @pytest.mark.parametrize("asymmetric_window", [True, False]) 26 | @pytest.mark.parametrize("win_length, window", [(None, None), 27 | (300, None), 28 | (300, torch.hann_window(300))]) 29 | @pytest.mark.parametrize("hop_length", [None, 128]) 30 | @pytest.mark.parametrize("center", [True, False]) 31 | @pytest.mark.parametrize("normalized", [False, True]) 32 | @pytest.mark.parametrize("onesided", [False, True]) 33 | @pytest.mark.parametrize("pad_mode", ["reflect", "constant", "replicate", "circular"]) 34 | @pytest.mark.parametrize("return_complex", [True, False]) 35 | def test_stft_args( 36 | look_ahead, 37 | asymmetric_window, 38 | win_length, 39 | window, 40 | hop_length, 41 | center, 42 | normalized, 43 | onesided, 44 | pad_mode, 45 | return_complex): 46 | x = torch.randn(4410) 47 | n_fft = 512 48 | spec = torch.stft(x, n_fft, 49 | hop_length=hop_length, 50 | win_length=win_length, 51 | window=window, 52 | center=center, 53 | pad_mode=pad_mode, 54 | normalized=normalized, 55 | onesided=onesided, 56 | return_complex=True).abs() 57 | 58 | spec.requires_grad = True 59 | y = RTISI_LA(spec, max_iter=2, look_ahead=look_ahead, asymmetric_window=asymmetric_window, 60 | hop_length=hop_length, 61 | win_length=win_length, 62 | window=window, 63 | center=center, 64 | pad_mode=pad_mode, 65 | normalized=normalized, 66 | onesided=onesided, 67 | return_complex=return_complex) 68 | 69 | loss = F.mse_loss(x[:y.shape[0]], y) 70 | loss.backward() 71 | assert hasattr(spec, "grad") 72 | return 73 | -------------------------------------------------------------------------------- /torch_specinv/__init__.py: -------------------------------------------------------------------------------- 1 | name = "torch_specinv" 2 | __version__ = '0.2.1' 3 | __maintainer__ = 'Chin-Yun Yu' 4 | __email__ = 'lolimaster.cs03@nctu.edu.tw' 5 | 6 | from .methods import L_BFGS, RTISI_LA, griffin_lim, ADMM, phase_init 7 | -------------------------------------------------------------------------------- /torch_specinv/methods.py: -------------------------------------------------------------------------------- 1 | from .metrics import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.fft as fft 5 | import torch.nn.functional as F 6 | from torch.optim.lbfgs import LBFGS 7 | from tqdm import tqdm 8 | from functools import partial 9 | from typing import Tuple 10 | import math 11 | 12 | pi2 = 2 * math.pi 13 | 14 | 
_func_mapper = {
 15 |     'SC': sc,
 16 |     'SNR': snr,
 17 |     'SER': ser
 18 | }
 19 | 
 20 | 
 21 | def _args_helper(spec, **stft_kwargs):
 22 |     """A helper function to get stft arguments from the provided kwargs.
 23 | 
 24 |     Args:
 25 |         spec: The magnitude spectrum of size (*, freq, time).
 26 |         **stft_kwargs: The keyword arguments that were used to compute ``spec``
 27 |             with :func:`torch.stft`. See :func:`torch.stft` for details.
 28 | 
 29 |     Returns:
 30 |         n_fft: FFT size of the spectrum.
 31 |         processed_kwargs: Dict storing the processed keyword arguments.
 32 | 
 33 |     """
 34 |     args_dict = {'win_length': None,
 35 |                  'window': None,
 36 |                  'hop_length': None,
 37 |                  'center': True,
 38 |                  'pad_mode': 'reflect',
 39 |                  'normalized': False,
 40 |                  'onesided': None,
 41 |                  'return_complex': None}
 42 |     for key in args_dict:
 43 |         try:
 44 |             args_dict[key] = stft_kwargs[key]
 45 |         except KeyError:
 46 |             pass
 47 |     win_length, window, hop_length, center, pad_mode, normalized, onesided, return_complex = tuple(
 48 |         args_dict.values())
 49 | 
 50 |     device = spec.device
 51 |     dtype = spec.dtype
 52 |     if dtype == torch.complex32:
 53 |         dtype = torch.float16
 54 |     elif dtype == torch.complex64:
 55 |         dtype = torch.float32
 56 |     elif dtype == torch.complex128:
 57 |         dtype = torch.float64
 58 | 
 59 |     if onesided is None:
 60 |         if window is not None and window.is_complex():
 61 |             onesided = False
 62 |         else:
 63 |             onesided = True
 64 | 
 65 |     if onesided:
 66 |         n_fft = (spec.shape[-2] - 1) * 2
 67 |     else:
 68 |         n_fft = spec.shape[-2]
 69 | 
 70 |     if not win_length:
 71 |         win_length = n_fft
 72 | 
 73 |     if not hop_length:
 74 |         hop_length = n_fft // 4
 75 | 
 76 |     if window is None:
 77 |         window = torch.ones(win_length, dtype=dtype, device=device)
 78 | 
 79 |     assert n_fft >= win_length
 80 |     if n_fft > win_length:
 81 |         window = F.pad(window, [(n_fft - win_length) //
 82 |                                 2, (n_fft - win_length + 1) // 2])
 83 |         win_length = n_fft
 84 | 
 85 |     args_dict['win_length'] = win_length
 86 |     args_dict['hop_length'] = hop_length
 87 |     args_dict['window'] = window
 88 |     args_dict['return_complex'] = True
 89 |     args_dict['onesided'] = onesided
 90 | 
 91 |     return n_fft, args_dict
 92 | 
 93 | 
 94 | def _get_ola_weight(window):
 95 |     ola_weight = torch.diag(window).unsqueeze(1)
 96 |     return ola_weight
 97 | 
 98 | 
 99 | def _spec_formatter(spec, **stft_kwargs):
100 |     shape = spec.shape
101 |     assert 4 > len(shape) > 1
102 |     if len(shape) == 2:
103 |         spec = spec.unsqueeze(0)
104 | 
105 |     if not spec.is_complex():
106 |         cmplx_spec = phase_init(spec, **stft_kwargs)
107 |         target_spec = spec
108 |     else:
109 |         cmplx_spec = spec
110 |         target_spec = spec.abs()
111 |     return cmplx_spec, target_spec
112 | 
113 | 
114 | def _ola(x, hop_length, weight, padding, norm_envelope=None):
115 |     """A helper function to do overlap-and-add.
116 | 
117 |     Args:
118 |         x: input tensor of size ``(batch, win_length, time)``.
119 |         hop_length: The distance between neighboring sliding window frames.
120 |         weight: The transposed-convolution weight built from the window by ``_get_ola_weight``.
121 |         norm_envelope: The normalization envelope applied to the synthesis window; computed on the fly if :obj:`None`.
122 | 
123 |     Returns:
124 |         A 1d tensor containing the overlap-and-add result.
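        The normalization envelope is returned as well, so subsequent calls can reuse it.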
125 | 
126 |     """
127 |     ola_x = F.conv_transpose1d(
128 |         x, weight, stride=hop_length, padding=padding).squeeze(1)
129 |     if norm_envelope is None:
130 |         norm_envelope = F.conv_transpose1d(torch.ones_like(
131 |             x[:1]), weight * weight, stride=hop_length, padding=padding).squeeze()
132 |     return ola_x / norm_envelope, norm_envelope
133 | 
134 | 
135 | def _istft(x, n_fft, ola_weight,
136 |            win_length, window, hop_length, center, normalized, onesided, pad_mode, return_complex,
137 |            norm_envelope=None):
138 |     """
139 |     A helper function to do istft.
140 |     """
141 |     if onesided:
142 |         x = fft.irfft(x, n=n_fft, dim=-2,
143 |                       norm='ortho' if normalized else 'backward')
144 |     else:
145 |         x = fft.ifft(x, n=n_fft, dim=-2,
146 |                      norm='ortho' if normalized else 'backward').real
147 | 
148 |     x, norm_envelope = _ola(x, hop_length, ola_weight, padding=n_fft // 2 if center else 0,
149 |                             norm_envelope=norm_envelope)
150 |     return x, norm_envelope
151 | 
152 | 
153 | def _training_loop(
154 |         closure,
155 |         status_dict,
156 |         target,
157 |         max_iter,
158 |         tol,
159 |         verbose,
160 |         eva_iter,
161 |         metric,
162 | ):
163 |     assert eva_iter > 0
164 |     assert max_iter > 0
165 |     assert tol >= 0
166 | 
167 |     metric = metric.upper()
168 |     assert metric in _func_mapper.keys()
169 | 
170 |     bar_dict = {}
171 |     bar_dict[metric] = 0
172 |     metric_func = _func_mapper[metric]
173 | 
174 |     criterion = F.mse_loss
175 |     init_loss = None
176 | 
177 |     with tqdm(total=max_iter, disable=not verbose) as pbar:
178 |         for i in range(max_iter):
179 |             output = closure(status_dict)
180 |             if i % eva_iter == eva_iter - 1:
181 |                 bar_dict[metric] = metric_func(output, target).item()
182 |                 l2_loss = criterion(output, target).item()
183 |                 pbar.set_postfix(**bar_dict, loss=l2_loss)
184 |                 pbar.update(eva_iter)
185 | 
186 |                 if not init_loss:
187 |                     init_loss = l2_loss
188 |                 elif (previous_loss - l2_loss) / init_loss < tol and previous_loss > l2_loss:
189 |                     break
190 |                 previous_loss = l2_loss
191 | 
192 | 
193 | def griffin_lim(spec,
194 |                 max_iter=200,
195 |                 tol=1e-6,
196 |                 alpha=0.99,
197 |                 verbose=True,
198 |                 eva_iter=10,
199 |                 metric='sc',
200 |                 **stft_kwargs):
201 |     r"""Reconstruct spectrogram phase using the well-known `Griffin-Lim`_ algorithm and its variant, `Fast Griffin-Lim`_.
202 | 
203 | 
204 |     .. _`Griffin-Lim`: https://pdfs.semanticscholar.org/14bc/876fae55faf5669beb01667a4f3bd324a4f1.pdf
205 |     .. _`Fast Griffin-Lim`: https://perraudin.info/publications/perraudin-note-002.pdf
206 | 
207 |     Args:
208 |         spec (Tensor): the input tensor of size :math:`(N \times T)`, either a magnitude or a complex spectrogram.
209 |             If a magnitude spectrogram is given, the phase will first be initialized using
210 |             :func:`torch_specinv.methods.phase_init`; otherwise the phase of the complex input is the starting point.
211 |         max_iter (int): maximum number of iterations before timing out.
212 |         tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``
213 |         alpha (float): speedup parameter used in `Fast Griffin-Lim`_; setting it to zero disables the acceleration. Default: ``0.99``
214 |         verbose (bool): whether to be verbose. Default: :obj:`True`
215 |         eva_iter (int): step size for evaluation. Every ``eva_iter`` iterations, the function given in ``metric`` is evaluated. Default: ``10``
216 |         metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence), ``'snr'`` or ``'ser'``. 
Default: ``'sc'``
217 |         **stft_kwargs: other arguments that are passed to :func:`torch.stft`
218 | 
219 |     Returns:
220 |         A 1d tensor converted from the given spectrogram
221 | 
222 |     """
223 |     assert alpha >= 0
224 | 
225 |     cmplx_spec, target_spec = _spec_formatter(spec, **stft_kwargs)
226 |     n_fft, processed_args = _args_helper(target_spec, **stft_kwargs)
227 |     ola_weight = _get_ola_weight(processed_args['window'])
228 | 
229 |     istft = partial(_istft, n_fft=n_fft, ola_weight=ola_weight,
230 |                     **processed_args)
231 | 
232 |     pre_spec = cmplx_spec.clone()
233 |     x, norm_envelope = istft(cmplx_spec)
234 | 
235 |     lr = alpha / (1 + alpha)
236 | 
237 |     def closure(status_dict):
238 |         x = status_dict['x']
239 |         pre_spec = status_dict['pre_spec']
240 | 
241 |         new_spec = torch.stft(x, n_fft, **processed_args)
242 |         output = new_spec.abs()
243 |         new_spec = new_spec - pre_spec * lr
244 |         status_dict['pre_spec'] = new_spec
245 | 
246 |         norm = new_spec.abs().add_(1e-16)
247 |         new_spec = new_spec * target_spec / norm
248 |         x, _ = istft(new_spec, norm_envelope=norm_envelope)
249 |         status_dict['x'] = x
250 |         return output
251 | 
252 |     stats = {
253 |         'x': x,
254 |         'pre_spec': pre_spec
255 |     }
256 | 
257 |     _training_loop(
258 |         closure,
259 |         stats,
260 |         target_spec,
261 |         max_iter,
262 |         tol,
263 |         verbose,
264 |         eva_iter,
265 |         metric
266 |     )
267 |     x = stats['x']
268 |     if not (spec.shape[0] == 1 and len(spec.shape) == 3):
269 |         x = x.squeeze(0)
270 |     return x
271 | 
272 | 
273 | def RTISI_LA(spec, look_ahead=-1, asymmetric_window=False, max_iter=25, alpha=0.99, verbose=1, **stft_kwargs):
274 |     r"""
275 |     Reconstruct spectrogram phase using `Real-Time Iterative Spectrogram Inversion with Look Ahead`_ (RTISI-LA).
276 | 
277 |     .. _`Real-Time Iterative Spectrogram Inversion with Look Ahead`:
278 |         https://lonce.org/home/Publications/publications/2007_RealtimeSignalReconstruction.pdf
279 | 
280 | 
281 |     Args:
282 |         spec (Tensor): the input tensor of size :math:`(N \times T)` (magnitude).
283 |         look_ahead (int): how many future frames will be considered. ``-1`` sets it to ``(win_length - 1) / hop_length``,
284 |             and ``0`` disables the look-ahead strategy, falling back to the original RTISI algorithm. Default: ``-1``
285 |         asymmetric_window (bool): whether to apply an asymmetric window to the newest frame on its first iteration.
286 |         max_iter (int): number of iterations for each step.
287 |         alpha (float): speedup parameter used in `Fast Griffin-Lim`_; setting it to zero disables the acceleration. Default: ``0.99``
288 |         verbose (bool): whether to be verbose. Default: :obj:`True`
289 |         **stft_kwargs: other arguments that are passed to :func:`torch.stft`.
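
    Example (a minimal sketch mirroring the unit tests; the input length,
    FFT size, and iteration count are only illustrative)::

        import torch
        x = torch.randn(4410)
        spec = torch.stft(x, 512, return_complex=True).abs()
        y = RTISI_LA(spec, max_iter=4)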
290 | 291 | Returns: 292 | A 1d tensor converted from the given spectrogram 293 | 294 | """ 295 | assert max_iter > 0 296 | assert alpha >= 0 297 | assert not spec.is_complex() 298 | 299 | shape = spec.shape 300 | assert 4 > len(shape) > 1 301 | target_spec = spec 302 | if len(shape) == 2: 303 | target_spec = target_spec.unsqueeze(0) 304 | 305 | n_fft, processed_args = _args_helper(target_spec, **stft_kwargs) 306 | ola_weight = _get_ola_weight(processed_args['window']) 307 | 308 | copyed_kwargs = stft_kwargs.copy() 309 | copyed_kwargs['center'] = False 310 | copyed_kwargs['return_complex'] = True 311 | 312 | win_length = processed_args['win_length'] 313 | hop_length = processed_args['hop_length'] 314 | onesided = processed_args['onesided'] 315 | normalized = processed_args['normalized'] 316 | 317 | window = processed_args['window'] 318 | synth_coeff = hop_length / (window @ window) 319 | 320 | # ola_weight = ola_weight * synth_coeff 321 | 322 | num_keep = (win_length - 1) // hop_length 323 | if look_ahead < 0: 324 | look_ahead = num_keep 325 | 326 | asym_window1 = target_spec.new_zeros(win_length) 327 | for i in range(num_keep): 328 | asym_window1[(i + 1) * hop_length:] += window.flip(0)[:- 329 | (i + 1) * hop_length] 330 | asym_window1 *= synth_coeff 331 | 332 | asym_window2 = target_spec.new_zeros(win_length) 333 | for i in range(num_keep + 1): 334 | asym_window2[i * 335 | hop_length:] += window.flip(0)[:-i * hop_length if i else None] 336 | asym_window2 *= synth_coeff 337 | 338 | steps = target_spec.shape[2] 339 | target_spec = F.pad(target_spec, [look_ahead, look_ahead]) 340 | 341 | if onesided: 342 | irfft = partial(fft.irfft, n=n_fft, dim=-2, 343 | norm='ortho' if normalized else 'backward') 344 | rfft = partial(fft.rfft, n=n_fft, dim=-2, 345 | norm='ortho' if normalized else 'backward') 346 | else: 347 | def irfft(x): return fft.ifft(x, n=n_fft, dim=-2, 348 | norm='ortho' if normalized else 'backward').real 349 | def rfft(x): return fft.fft(x, n=n_fft, dim=-2, 350 | norm='ortho' if normalized else 'backward') 351 | 352 | # initialize first frame with zero phase 353 | first_frame = target_spec[..., look_ahead, None] 354 | keeped_chunk = target_spec.new_zeros(target_spec.shape[0], n_fft, num_keep) 355 | update_chunk = target_spec.new_zeros( 356 | target_spec.shape[0], n_fft, look_ahead) 357 | update_chunk = torch.cat((update_chunk, 358 | irfft(first_frame + 0j)), 2) 359 | 360 | lr = alpha / (1 + alpha) 361 | output_xt_list = [] 362 | with tqdm(total=steps + look_ahead, disable=not verbose) as pbar: 363 | for i in range(steps + look_ahead): 364 | for j in range(max_iter): 365 | x, _ = _ola(torch.cat((keeped_chunk, 366 | update_chunk), 2), 367 | hop_length, 368 | ola_weight * synth_coeff, padding=0, norm_envelope=1) 369 | 370 | x = x[:, num_keep * hop_length:] 371 | if asymmetric_window: 372 | xt_winview = x.unfold( 373 | 1, win_length, hop_length).transpose(1, 2) 374 | xt_norm_wind = xt_winview[:, :, :-1] * window[:, None] 375 | if j: 376 | xt_asym_wind = xt_winview[:, 377 | :, -1:] * asym_window2[:, None] 378 | else: 379 | xt_asym_wind = xt_winview[:, 380 | :, -1:] * asym_window1[:, None] 381 | 382 | xt = torch.cat((xt_norm_wind, xt_asym_wind), 2) 383 | new_spec = rfft(xt) 384 | else: 385 | new_spec = torch.stft(x, n_fft=n_fft, **copyed_kwargs) 386 | 387 | if j: 388 | new_spec = new_spec - lr * pre_spec 389 | elif i: 390 | new_spec = torch.cat( 391 | (new_spec[:, :, :-1] - lr * pre_spec[:, :, 1:], new_spec[:, :, -1:]), 2) 392 | pre_spec = new_spec 393 | 394 | norm = new_spec.abs() + 
1e-16
395 |                 new_spec = new_spec * \
396 |                     target_spec[..., i:i + look_ahead + 1] / norm
397 | 
398 |                 update_chunk = irfft(new_spec)
399 | 
400 |             pbar.update()
401 |             output_xt_list.append(update_chunk[:, :, 0])
402 |             keeped_chunk = torch.cat(
403 |                 (keeped_chunk[:, :, 1:], update_chunk[:, :, :1]), 2)
404 |             update_chunk = F.pad(update_chunk[:, :, 1:], [0, 1])
405 | 
406 |     all_xt = torch.stack(output_xt_list[look_ahead if look_ahead else 0:], 2)
407 |     x, _ = _ola(all_xt, hop_length, ola_weight, padding=win_length //
408 |                 2 if processed_args['center'] else 0)
409 | 
410 |     if not (spec.shape[0] == 1 and len(spec.shape) == 3):
411 |         x = x.squeeze(0)
412 |     return x
413 | 
414 | 
415 | def ADMM(spec, max_iter=1000, tol=1e-6, rho=0.1, verbose=1, eva_iter=10, metric='sc', **stft_kwargs):
416 |     r"""
417 |     Reconstruct spectrogram phase using `Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`_.
418 | 
419 |     .. _`Griffin–Lim Like Phase Recovery via Alternating Direction Method of Multipliers`:
420 |         https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8552369
421 | 
422 |     Args:
423 |         spec (Tensor): the input tensor of size :math:`(N \times T)`, either a magnitude or a complex spectrogram.
424 |             If a magnitude spectrogram is given, the phase will first be initialized using
425 |             :func:`torch_specinv.methods.phase_init`; otherwise the phase of the complex input is the starting point.
426 |         max_iter (int): maximum number of iterations before timing out.
427 |         tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``
428 |         rho (float): non-negative speedup parameter. A small value is preferable when the input spectrogram is noisy (imperfect);
429 |             setting it to 1 behaves similarly to ``griffin_lim``. Default: ``0.1``
430 |         verbose (bool): whether to be verbose. Default: :obj:`True`
431 |         eva_iter (int): step size for evaluation. Every ``eva_iter`` iterations, the function given in ``metric`` is evaluated. Default: ``10``
432 |         metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence), ``'snr'`` or ``'ser'``. Default: ``'sc'``
433 |         **stft_kwargs: other arguments that are passed to :func:`torch.stft`.
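
    Example (a minimal sketch mirroring the unit tests; the small ``max_iter``
    is only there to keep the sketch fast)::

        import torch
        x = torch.randn(4410)
        spec = torch.stft(x, 512, return_complex=True).abs()
        y = ADMM(spec, max_iter=4)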
434 | 
435 | 
436 |     Returns:
437 |         A 1d tensor converted from the given spectrogram
438 | 
439 |     """
440 |     assert eva_iter > 0
441 |     assert max_iter > 0
442 |     assert tol >= 0
443 |     assert metric.upper() in _func_mapper.keys()
444 | 
445 |     cmplx_spec, target_spec = _spec_formatter(spec, **stft_kwargs)
446 |     n_fft, processed_args = _args_helper(target_spec, **stft_kwargs)
447 |     ola_weight = _get_ola_weight(processed_args['window'])
448 | 
449 |     istft = partial(_istft, n_fft=n_fft, ola_weight=ola_weight,
450 |                     **processed_args)
451 | 
452 |     X = cmplx_spec
453 |     x, norm_envelope = istft(X)
454 |     Z = X.clone()
455 |     Y = X.clone()
456 |     U = torch.zeros_like(X)
457 | 
458 |     def closure(status_dict):
459 |         X = status_dict['X']
460 |         Y = status_dict['Y']
461 |         U = status_dict['U']
462 |         x = status_dict['x']
463 | 
464 |         reconstructed = torch.stft(x, n_fft, **processed_args)
465 |         output = reconstructed.abs()
466 | 
467 |         Z = (rho * Y + reconstructed) / (1 + rho)
468 |         U = U + X - Z
469 | 
470 |         # Pc2
471 |         X = Z - U
472 |         norm = X.abs() + 1e-16
473 |         X = X * target_spec / norm
474 | 
475 |         Y = X + U
476 |         # Pc1
477 |         x, _ = istft(Y, norm_envelope=norm_envelope)
478 | 
479 |         status_dict['Y'] = Y
480 |         status_dict['X'] = X
481 |         status_dict['U'] = U
482 |         status_dict['x'] = x
483 |         return output
484 | 
485 |     stats = {
486 |         'Y': Y,
487 |         'U': U,
488 |         'X': X,
489 |         'x': x
490 |     }
491 | 
492 |     _training_loop(
493 |         closure,
494 |         stats,
495 |         target_spec,
496 |         max_iter,
497 |         tol,
498 |         verbose,
499 |         eva_iter,
500 |         metric
501 |     )
502 | 
503 |     x = stats['x']
504 |     if not (spec.shape[0] == 1 and len(spec.shape) == 3):
505 |         x = x.squeeze(0)
506 |     return x
507 | 
508 | 
509 | def L_BFGS(spec, transform_fn, samples=None, init_x0=None, outer_max_iter=1000, tol=1e-6, verbose=1, eva_iter=10, metric='sc',
510 |            **kwargs):
511 |     r"""
512 | 
513 |     Reconstruct spectrogram phase using `Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other
514 |     Envelope Representations`_, where I directly use the :class:`torch.optim.LBFGS` optimizer provided in PyTorch.
515 |     This method is not restricted to the traditional short-time Fourier transform; it works with any kind of
516 |     representation (e.g., a Mel-scaled spectrogram) as long as the transform function is differentiable.
517 | 
518 |     .. _`Inversion of Auditory Spectrograms, Traditional Spectrograms, and Other Envelope Representations`:
519 |         https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=6949659
520 | 
521 |     Args:
522 |         spec (Tensor): the input representation.
523 |         transform_fn: a function of the form ``spec = transform_fn(x)``, where ``x`` is a 1d tensor.
524 |         samples (tuple or list, optional): the shape of the time-domain signal, e.g. ``[num_samples]``. Default: :obj:`None`
525 |         init_x0 (Tensor, optional): a tensor used as the initial time-domain samples. If not provided, a random
526 |             tensor of shape ``samples`` will be used instead.
527 |         outer_max_iter (int): maximum number of iterations before timing out.
528 |         tol (float): tolerance of the stopping condition based on the L2 loss. Default: ``1e-6``.
529 |         verbose (bool): whether to be verbose. Default: :obj:`True`
530 |         eva_iter (int): step size for evaluation. Every ``eva_iter`` iterations, the function given in ``metric`` is evaluated. Default: ``10``
531 |         metric (str): evaluation function. Currently available functions: ``'sc'`` (spectral convergence), ``'snr'`` or ``'ser'``. Default: ``'sc'``
532 |         **kwargs: other arguments that are passed to :class:`torch.optim.LBFGS`.
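
    Example (a minimal sketch mirroring the unit tests; the transform and the
    shapes are only illustrative, and ``max_iter`` here is forwarded to
    :class:`torch.optim.LBFGS`)::

        import torch
        x = torch.randn(4410)

        def trsfn(x):
            return torch.stft(x, 512, return_complex=True).abs()

        spec = trsfn(x)
        y = L_BFGS(spec, trsfn, samples=x.shape, max_iter=10)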
533 | 
534 |     Returns:
535 |         A 1d tensor converted from the given representation
536 |     """
537 |     if init_x0 is None:
538 |         init_x0 = spec.new_empty(*samples).normal_(std=1e-6)
539 |     x = nn.Parameter(init_x0)
540 |     T = spec
541 | 
542 |     criterion = nn.MSELoss()
543 |     optimizer = LBFGS([x], **kwargs)
544 | 
545 |     def inner_closure():
546 |         optimizer.zero_grad()
547 |         V = transform_fn(x)
548 |         loss = criterion(V, T)
549 |         loss.backward()
550 |         return loss
551 | 
552 |     def outer_closure(status_dict):
553 |         optimizer.step(inner_closure)
554 |         with torch.no_grad():
555 |             V = transform_fn(x)
556 |         return V
557 | 
558 |     _training_loop(
559 |         outer_closure,
560 |         {},
561 |         T,
562 |         outer_max_iter,
563 |         tol,
564 |         verbose,
565 |         eva_iter,
566 |         metric
567 |     )
568 | 
569 |     return x.detach()
570 | 
571 | 
572 | def phase_init(spec, **stft_kwargs):
573 |     r"""
574 |     A phase initialization function that can be seen as a simplified version of `Single Pass Spectrogram Inversion`_.
575 | 
576 |     .. _`Single Pass Spectrogram Inversion`:
577 |         https://ieeexplore.ieee.org/document/7251907
578 | 
579 |     Args:
580 |         spec (Tensor): the input tensor of size :math:`(* \times N \times T)` (magnitude).
581 |         **stft_kwargs: other arguments that are passed to :func:`torch.stft`
582 | 
583 |     Returns:
584 |         The estimated complex-valued spectrogram, with the same size as the input.
585 |     """
586 |     assert not spec.is_complex()
587 |     shape = spec.shape
588 |     if len(spec.shape) == 2:
589 |         spec = spec.unsqueeze(0)
590 |     assert len(spec.shape) == 3
591 | 
592 |     n_fft, processed_args = _args_helper(spec, **stft_kwargs)
593 |     hop_length = processed_args['hop_length']
594 | 
595 |     phase = torch.zeros_like(spec)
596 | 
597 |     mask = (spec[:, 1:-1] > spec[:, 2:]) & (spec[:, 1:-1] > spec[:, :-2])
598 |     mask = F.pad(mask, [0, 0, 1, 1])
599 | 
600 |     b = torch.masked_select(spec, mask)
601 |     a = torch.masked_select(spec[:, :-1], mask[:, 1:])
602 |     r = torch.masked_select(spec[:, 1:], mask[:, :-1])
603 |     idx1, idx2, idx3 = torch.nonzero(mask).t().unbind()
604 |     p = 0.5 * (a - r) / (a - 2 * b + r)
605 |     omega = pi2 * (idx2.float() + p) / n_fft * hop_length
606 | 
607 |     phase[idx1, idx2, idx3] = omega
608 |     phase[idx1, idx2 - 1, idx3] = omega
609 |     phase[idx1, idx2 + 1, idx3] = omega
610 | 
611 |     phase = torch.cumsum(phase, 2)
612 |     angle = torch.exp(phase * 1j)
613 | 
614 |     spec = spec * angle
615 |     return spec.view(shape)
616 | 
--------------------------------------------------------------------------------
/torch_specinv/metrics.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | def sc(input, target):
 5 |     r"""
 6 |     The Spectral Convergence score is calculated as follows:
 7 | 
 8 |     .. math::
 9 |         \mathcal{C}(\mathbf{\hat{S}}, \mathbf{S})=\frac{\|\mathbf{\hat{S}}-\mathbf{S}\|_{Fro}}{\|\mathbf{S}\|_{Fro}}
10 | 
11 |     Returns:
12 |         scalar output in dB scale.
13 |     """
14 |     return 20 * ((input - target).norm().log10() - target.norm().log10())
15 | 
16 | 
17 | def snr(input, target):
18 |     r"""
19 |     The Signal-to-Noise Ratio (SNR) is calculated as follows:
20 | 
21 |     .. math::
22 |         SNR(\mathbf{\hat{S}}, \mathbf{S})=
23 |             10\log_{10}\frac{1}{\sum (\frac{\hat{s}_i}{\|\mathbf{\hat{S}}\|_{Fro}} - \frac{s_i}{\|\mathbf{S}\|_{Fro}})^2}
24 | 
25 |     Returns:
26 |         scalar output.
27 |     """
28 |     norm = target.norm()
29 |     return -10 * (input / norm - target / norm).pow(2).sum().log10()
30 | 
31 | 
32 | def ser(input, target):
33 |     r"""
34 |     The Signal-to-Error Ratio (SER) is calculated as follows:
35 | 
36 |     .. 
math:: 37 | SER(\mathbf{\hat{S}}, \mathbf{S})= 38 | 10\log_{10}\frac{\sum \hat{s}_i^2}{\sum (\hat{s}_i - s_i)^2} 39 | 40 | Returns: 41 | scalar output. 42 | """ 43 | return 10 * (input.pow(2).sum().log10() - (input - target).pow(2).sum().log10()) 44 | -------------------------------------------------------------------------------- /torch_specinv/utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyolicoris/spectrogram-inversion/82fa16b65928c33c94ef13f34bf42f6d57960be4/torch_specinv/utils.py --------------------------------------------------------------------------------