├── .github └── workflows │ ├── black.yml │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── CITATION.cff ├── LICENSE ├── MANIFEST.in ├── README.md ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── test_extension.py ├── test_grad.py └── test_vmap.py └── torchlpc ├── __init__.py ├── core.py ├── csrc ├── cuda │ ├── LICENSE.txt │ ├── linear_recurrence.cu │ └── lpc.cu └── scan_cpu.cpp ├── parallel_scan.py └── recurrence.py /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "main" ] 9 | pull_request: 10 | branches: [ "main" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ${{matrix.os}} 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.9", "3.10", "3.11", "3.12"] 20 | os: [ubuntu-latest, macos-latest] 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v3 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install libomp on macOS 29 | if: matrix.os == 'macos-latest' 30 | run: | 31 | brew install libomp 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install flake8 pytest 36 | pip install "numpy<2.0" numba 37 | - name: Install torch (mac) 38 | if: matrix.os == 'macos-latest' 39 | run: pip install torch 40 | - name: Install torch (ubuntu) 41 | if: matrix.os == 'ubuntu-latest' 42 | run: pip install torch --index-url https://download.pytorch.org/whl/cpu 43 | - name: Lint with flake8 44 | run: | 45 | # stop the build if there are Python syntax errors or undefined names 46 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 47 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 48 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 49 | - name: Build CPP extension with clang++ 50 | if: matrix.os == 'macos-latest' 51 | run: | 52 | export CXX=$(brew --prefix llvm@15)/bin/clang++ 53 | export LDFLAGS="-L/usr/local/opt/libomp/lib" 54 | export CPPFLAGS="-I/usr/local/opt/libomp/include" 55 | python setup.py build 56 | find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \; 57 | - name: Build CPP extension with g++ 58 | if: matrix.os == 'ubuntu-latest' 59 | run: | 60 | python setup.py build 61 | find build/ -name "_C*.so" -exec cp {} ./torchlpc/ \; 62 | - name: Test with pytest 63 | run: | 64 | pytest 65 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | push: 15 | tags: 16 | - v0.* 17 | - v1.* 18 | 19 | permissions: 20 | contents: read 21 | 22 | jobs: 23 | deploy: 24 | 25 | runs-on: ubuntu-latest 26 | 27 | steps: 28 | - uses: actions/checkout@v3 29 | - name: Set up Python 30 | uses: actions/setup-python@v3 31 | with: 32 | python-version: '3.x' 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install setuptools numpy 37 | pip install torch --index-url https://download.pytorch.org/whl/cpu 38 | - name: Build package 39 | run: python setup.py sdist 40 | - name: Publish package 41 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 42 | with: 43 | user: __token__ 44 | password: ${{ secrets.PYPI_API_TOKEN }} 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode/ 3 | *.nbc 4 | *.nbi 5 | build/ 6 | *.egg* -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Yu" 5 | given-names: "Chin-Yun" 6 | orcid: "https://orcid.org/0000-0003-3782-2713" 7 | title: "TorchLPC: fast, efficient, and differentiable time-varying LPC filtering in PyTorch" 8 | version: 0.3.1 9 | date-released: 2023-07-09 10 | url: "https://github.com/DiffAPF/torchlpc" 11 | keywords: 12 | - differentiable DSP 13 | - all-pole filters 14 | - linear prediction 15 | license: MIT 16 | preferred-citation: 17 | type: generic 18 | title: "Differentiable All-pole Filters for Time-varying Audio Systems" 19 | authors: 20 | - given-names: Chin-Yun 21 | family-names: Yu 22 | email: chin-yun.yu@qmul.ac.uk 23 | affiliation: Queen Mary University of London 24 | orcid: 'https://orcid.org/0000-0003-3782-2713' 25 | - given-names: Christopher 26 | family-names: Mitcheltree 27 | email: c.mitcheltree@qmul.ac.uk 28 | affiliation: Queen Mary University of London 29 | - given-names: Alistair 30 | family-names: Carson 31 | email: alistair.carson@ed.ac.uk 32 | affiliation: University of Edinburgh 33 | - given-names: Stefan 34 | family-names: Bilbao 35 | email: sbilbao@ed.ac.uk 36 | affiliation: University of Edinburgh 37 | - given-names: Joshua D. 38 | family-names: Reiss 39 | email: joshua.reiss@qmul.ac.uk 40 | affiliation: Queen Mary University of London 41 | - given-names: György 42 | family-names: Fazekas 43 | email: george.fazekas@qmul.ac.uk 44 | affiliation: Queen Mary University of London 45 | status: preprint 46 | month: 4 47 | year: 2024 48 | identifiers: 49 | - type: other 50 | value: "arXiv:2404.07970" 51 | description: The ArXiv preprint of the paper 52 | url: "https://diffapf.github.io/web/" 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Chin-Yun Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include torchlpc *.py 2 | recursive-include torchlpc *.h 3 | recursive-include torchlpc *.cpp 4 | recursive-include torchlpc *.c 5 | recursive-include torchlpc *.cu 6 | recursive-include tests *.py 7 | recursive-exclude * __pycache__ 8 | recursive-exclude * *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchLPC 2 | [![PyPI version](https://badge.fury.io/py/torchlpc.svg)](https://badge.fury.io/py/torchlpc) 3 | 4 | `torchlpc` provides a PyTorch implementation of the Linear Predictive Coding (LPC) filter, also known as an all-pole filter. 5 | It's fast, differentiable, and supports batched inputs with time-varying filter coefficients. 6 | 7 | Given an input signal $`\mathbf{x} \in \mathbb{R}^T`$ and time-varying LPC coefficients $`\mathbf{A} \in \mathbb{R}^{T \times N}`$ with an order of $`N`$, the LPC filter is defined as: 8 | 9 | $$ 10 | y_t = x_t - \sum_{i=1}^N A_{t,i} y_{t-i}. 11 | $$ 12 | 13 | ## Usage 14 | 15 | ```python 16 | 17 | import torch 18 | from torchlpc import sample_wise_lpc 19 | 20 | # Create a batch of 10 signals, each with 100 time steps 21 | x = torch.randn(10, 100) 22 | 23 | # Create a batch of 10 sets of LPC coefficients, each with 100 time steps and an order of 3 24 | A = torch.randn(10, 100, 3) 25 | 26 | # Apply LPC filtering 27 | y = sample_wise_lpc(x, A) 28 | 29 | # Optionally, you can provide initial values for the output signal (default is 0) 30 | zi = torch.randn(10, 3) 31 | y = sample_wise_lpc(x, A, zi=zi) 32 | 33 | # Return the final delay values, similar to `scipy.signal.lfilter` 34 | y, zf = sample_wise_lpc(x, A, zi=zi, return_zf=True) 35 | ``` 36 | 37 | 38 | ## Installation 39 | 40 | ```bash 41 | pip install torchlpc 42 | ``` 43 | 44 | or from source: 45 | 46 | ```bash 47 | pip install git+https://github.com/DiffAPF/torchlpc.git 48 | ``` 49 | 50 | If you want to run it on an NVIDIA GPU, make sure you have the CUDA toolkit installed, with a version compatible with your PyTorch installation. 51 | 52 | ### macOS 53 | 54 | To compile with OpenMP support on macOS, you need to install `libomp` via Homebrew. 55 | Also, use `llvm@15` as the C++ compiler to ensure compatibility with OpenMP. 56 | 57 | ```bash 58 | brew install libomp 59 | export CXX=$(brew --prefix llvm@15)/bin/clang++ 60 | export LDFLAGS="-L/usr/local/opt/libomp/lib" 61 | export CPPFLAGS="-I/usr/local/opt/libomp/include" 62 | ``` 63 | 64 | After performing the above steps, you can install `torchlpc` as usual. 65 | 66 | ## Derivation of the gradients of the LPC filter 67 | 68 | The details of the derivation can be found in our preprints[^1][^2].
69 | We show that, given the instantaneous gradient $\frac{\partial \mathcal{L}}{\partial y_t}$ where $\mathcal{L}$ is the loss function, the gradients of the LPC filter with respect to the input signal $\mathbf{x}$ and the filter coefficients $\mathbf{A}$ can also be expressed through a time-varying filter: 70 | 71 | ```math 72 | \frac{\partial \mathcal{L}}{\partial x_t} 73 | = \frac{\partial \mathcal{L}}{\partial y_t} 74 | - \sum_{i=1}^{N} A_{t+i,i} \frac{\partial \mathcal{L}}{\partial x_{t+i}} 75 | ``` 76 | 77 | $$ 78 | \frac{\partial \mathcal{L}}{\partial \mathbf{A}} 79 | = -\begin{bmatrix} 80 | \frac{\partial \mathcal{L}}{\partial x_1} & 0 & \dots & 0 \\ 81 | 0 & \frac{\partial \mathcal{L}}{\partial x_2} & \dots & 0 \\ 82 | \vdots & \vdots & \ddots & \vdots \\ 83 | 0 & 0 & \dots & \frac{\partial \mathcal{L}}{\partial x_T} 84 | \end{bmatrix} 85 | \begin{bmatrix} 86 | y_0 & y_{-1} & \dots & y_{-N + 1} \\ 87 | y_1 & y_0 & \dots & y_{-N + 2} \\ 88 | \vdots & \vdots & \ddots & \vdots \\ 89 | y_{T-1} & y_{T - 2} & \dots & y_{T - N} 90 | \end{bmatrix}. 91 | $$ 92 | 93 | ### Gradients for the initial condition $`y_t|_{t \leq 0}`$ 94 | 95 | The initial conditions provide an entry point at $t=1$ for filtering, as we cannot evaluate $t=-\infty$. 96 | Let us assume $`A_{t, :}|_{t \leq 0} = 0`$ so $`y_t|_{t \leq 0} = x_t|_{t \leq 0}`$, which also means $`\frac{\partial \mathcal{L}}{\partial y_t}|_{t \leq 0} = \frac{\partial \mathcal{L}}{\partial x_t}|_{t \leq 0}`$. 97 | Thus, the initial condition gradients are 98 | 99 | $$ 100 | \frac{\partial \mathcal{L}}{\partial y_t} 101 | = \frac{\partial \mathcal{L}}{\partial x_t} 102 | = -\sum_{i=1-t}^{N} A_{t+i,i} \frac{\partial \mathcal{L}}{\partial x_{t+i}} \quad \text{for } -N < t \leq 0. 103 | $$ 104 | 105 | In practice, we pad $N$ zeros to the beginning of $\frac{\partial \mathcal{L}}{\partial \mathbf{y}}$ and $N \times N$ zeros to the beginning of $\mathbf{A}$ before evaluating $\frac{\partial \mathcal{L}}{\partial \mathbf{x}}$. 106 | The first $N$ outputs are the gradients with respect to $`y_t|_{t \leq 0}`$ and the rest are with respect to $`x_t|_{t > 0}`$. 107 | 108 | ### Time-invariant filtering 109 | 110 | In the time-invariant setting, $`A_{t, i} = A_{1, i} \; \forall t \in [1, T]`$ and the filter simplifies to 111 | 112 | ```math 113 | y_t = x_t - \sum_{i=1}^N a_i y_{t-i}, \quad \mathbf{a} = A_{1,:}. 114 | ``` 115 | 116 | The gradients $`\frac{\partial \mathcal{L}}{\partial \mathbf{x}}`$ are obtained by filtering $`\frac{\partial \mathcal{L}}{\partial \mathbf{y}}`$ with $\mathbf{a}$ backwards in time, the same as in the time-varying case. 117 | $\frac{\partial \mathcal{L}}{\partial \mathbf{a}}$ is simply a vector-matrix multiplication: 118 | 119 | $$ 120 | \frac{\partial \mathcal{L}}{\partial \mathbf{a}^T} = 121 | -\frac{\partial \mathcal{L}}{\partial \mathbf{x}^T} 122 | \begin{bmatrix} 123 | y_0 & y_{-1} & \dots & y_{-N + 1} \\ 124 | y_1 & y_0 & \dots & y_{-N + 2} \\ 125 | \vdots & \vdots & \ddots & \vdots \\ 126 | y_{T-1} & y_{T - 2} & \dots & y_{T - N} 127 | \end{bmatrix}. 128 | $$ 129 | 130 | This algorithm is more efficient than the one in [^3] because it needs only one pass of filtering to obtain both gradients, while the latter needs two. 131 | 132 | [^1]: [Differentiable All-pole Filters for Time-varying Audio Systems](https://arxiv.org/abs/2404.07970). 133 | [^2]: [Differentiable Time-Varying Linear Prediction in the Context of End-to-End Analysis-by-Synthesis](https://arxiv.org/abs/2406.05128). 134 | [^3]: [Singing Voice Synthesis Using Differentiable LPC and Glottal-Flow-Inspired Wavetables](https://arxiv.org/abs/2306.17252).
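As a quick sanity check of the definition and the gradients above, the all-pole recursion can be compared against a naive per-sample loop and verified with finite differences. The following sketch is an editor's illustration rather than part of the package: `naive_sample_wise_lpc` is a hypothetical helper, and it assumes the convention used internally by `torchlpc.core` where `zi[:, 0]` is the most recent past output $`y_0`$.

```python
import torch
from torch.autograd import gradcheck
from torchlpc import sample_wise_lpc


def naive_sample_wise_lpc(x, A, zi):
    # Direct loop over y_t = x_t - sum_{i=1}^{N} A_{t,i} * y_{t-i},
    # with zi[:, 0] = y_0, zi[:, 1] = y_{-1}, ..., zi[:, N-1] = y_{-N+1}.
    B, T, N = A.shape
    y = torch.cat([zi.flip(1), x], dim=1)  # prepend y_{-N+1}, ..., y_0
    for t in range(T):
        for i in range(N):
            y[:, N + t] -= A[:, t, i] * y[:, N + t - i - 1]
    return y[:, N:]


B, T, N = 2, 50, 3
x = torch.randn(B, T, dtype=torch.double)
A = torch.randn(B, T, N, dtype=torch.double) * 0.1  # small coefficients keep the filter stable
zi = torch.randn(B, N, dtype=torch.double)

assert torch.allclose(naive_sample_wise_lpc(x, A, zi), sample_wise_lpc(x, A, zi))

# The analytical gradients derived above can be checked numerically.
for tensor in (x, A, zi):
    tensor.requires_grad_(True)
assert gradcheck(sample_wise_lpc, (x, A, zi))
```

For the time-invariant case, one convenient pattern (borrowed from the package's test suite, shown here only as a sketch) is to keep a single coefficient vector per sequence and broadcast it along the time axis; autograd then accumulates the per-time-step coefficient gradients back into that vector automatically:

```python
import torch
from torchlpc import sample_wise_lpc

B, T, N = 4, 100, 2
x = torch.randn(B, T)
a = 0.1 * torch.randn(B, N)          # one coefficient set per sequence
a.requires_grad_(True)
A = a[:, None, :].expand(-1, T, -1)  # view it as time-varying coefficients
y = sample_wise_lpc(x, A)
y.sum().backward()                   # a.grad has shape (B, N)
```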
135 | 136 | ## TODO 137 | 138 | - [x] Use PyTorch C++ extension for faster computation. 139 | - [x] Use native CUDA kernels for GPU computation. 140 | - [ ] Support Metal for MacOS. 141 | - [ ] Add examples. 142 | 143 | ## Related Projects 144 | 145 | - [torchcomp](https://github.com/DiffAPF/torchcomp): differentiable compressors that use `torchlpc` for differentiable backpropagation. 146 | - [jaxpole](https://github.com/rodrigodzf/jaxpole): equivalent implementation in JAX by @rodrigodzf. 147 | 148 | ## Citation 149 | 150 | If you find this repository useful in your research, please cite our work with the following BibTex entries: 151 | 152 | ```bibtex 153 | @inproceedings{ycy2024diffapf, 154 | title={Differentiable All-pole Filters for Time-varying Audio Systems}, 155 | author={Chin-Yun Yu and Christopher Mitcheltree and Alistair Carson and Stefan Bilbao and Joshua D. Reiss and György Fazekas}, 156 | booktitle={International Conference on Digital Audio Effects (DAFx)}, 157 | year={2024}, 158 | pages={345--352}, 159 | } 160 | 161 | @inproceedings{ycy2024golf, 162 | title = {Differentiable Time-Varying Linear Prediction in the Context of End-to-End Analysis-by-Synthesis}, 163 | author = {Chin-Yun Yu and György Fazekas}, 164 | year = {2024}, 165 | booktitle = {Proc. Interspeech}, 166 | pages = {1820--1824}, 167 | doi = {10.21437/Interspeech.2024-1187}, 168 | } 169 | ``` 170 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | numba 3 | torch>=2.0 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | import glob 4 | import torch 5 | from torch.utils.cpp_extension import ( 6 | CppExtension, 7 | CUDAExtension, 8 | BuildExtension, 9 | CUDA_HOME, 10 | ) 11 | 12 | library_name = "torchlpc" 13 | VERSION = "0.7.1" 14 | MAINTAINER = "Chin-Yun Yu" 15 | EMAIL = "chin-yun.yu@qmul.ac.uk" 16 | 17 | 18 | with open("README.md", "r") as fh: 19 | long_description = fh.read() 20 | 21 | 22 | # if torch.__version__ >= "2.6.0": 23 | # py_limited_api = True 24 | # else: 25 | py_limited_api = False 26 | 27 | 28 | def get_extensions(): 29 | use_cuda = torch.cuda.is_available() and CUDA_HOME is not None 30 | use_openmp = torch.backends.openmp.is_available() 31 | extension = CUDAExtension if use_cuda else CppExtension 32 | 33 | extra_link_args = [] 34 | extra_compile_args = {} 35 | if use_openmp: 36 | extra_compile_args["cxx"] = ["-fopenmp"] 37 | extra_link_args.append("-fopenmp") 38 | 39 | this_dir = os.path.abspath(os.path.dirname(__file__)) 40 | extensions_dir = os.path.join(this_dir, library_name, "csrc") 41 | sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) 42 | 43 | extensions_cuda_dir = os.path.join(extensions_dir, "cuda") 44 | cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) 45 | 46 | if use_cuda: 47 | sources += cuda_sources 48 | 49 | ext_modules = [ 50 | extension( 51 | f"{library_name}._C", 52 | sources, 53 | extra_compile_args=extra_compile_args, 54 | extra_link_args=extra_link_args, 55 | py_limited_api=py_limited_api, 56 | ) 57 | ] 58 | 59 | return ext_modules 60 | 61 | 62 | setuptools.setup( 63 | name=library_name, 64 | version=VERSION, 65 | author=MAINTAINER, 66 | author_email=EMAIL, 67 | description="Fast, efficient, and differentiable time-varying LPC filtering 
in PyTorch.", 68 | long_description=long_description, 69 | long_description_content_type="text/markdown", 70 | url="https://github.com/DiffAPF/torchlpc", 71 | packages=["torchlpc"], 72 | install_requires=["torch>=2.0", "numpy", "numba"], 73 | classifiers=[ 74 | "Programming Language :: Python :: 3", 75 | "Operating System :: MacOS :: MacOS X", 76 | "Operating System :: POSIX", 77 | ], 78 | license="MIT", 79 | ext_modules=get_extensions(), 80 | cmdclass={"build_ext": BuildExtension}, 81 | options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {}, 82 | ) 83 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DiffAPF/torchlpc/8f5b2719e603365904c6d1e4a913843cd5ce1416/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_extension.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pytest 4 | from torchlpc.core import lpc_np, lpc_cuda 5 | 6 | 7 | from .test_grad import create_test_inputs 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "samples", 12 | [64, 4097], 13 | ) 14 | @pytest.mark.parametrize( 15 | "cmplx", 16 | [True, False], 17 | ) 18 | @pytest.mark.parametrize( 19 | "device", 20 | [ 21 | "cpu", 22 | pytest.param( 23 | "cuda", 24 | marks=pytest.mark.skipif( 25 | not torch.cuda.is_available(), reason="CUDA not available" 26 | ), 27 | ), 28 | ], 29 | ) 30 | def test_scan_equiv(samples: int, cmplx: bool, device: str): 31 | batch_size = 4 32 | x = torch.randn( 33 | batch_size, 34 | samples, 35 | dtype=torch.float32 if not cmplx else torch.complex64, 36 | device=device, 37 | ) 38 | if cmplx: 39 | A = torch.rand( 40 | batch_size, samples, dtype=x.dtype, device=device 41 | ).sqrt() * torch.exp( 42 | 2j 43 | * torch.rand(batch_size, samples, dtype=x.dtype, device=device) 44 | * torch.pi 45 | ) 46 | else: 47 | A = torch.rand_like(x) * 1.8 - 0.9 48 | zi = torch.randn(batch_size, dtype=x.dtype, device=device) 49 | 50 | if device == "cuda": 51 | numba_y = lpc_cuda(x, -A.unsqueeze(2), zi.unsqueeze(1)) 52 | else: 53 | numba_y = torch.from_numpy( 54 | lpc_np( 55 | x.cpu().numpy(), 56 | -A.cpu().unsqueeze(2).numpy(), 57 | zi.cpu().unsqueeze(1).numpy(), 58 | ) 59 | ) 60 | ext_y = torch.ops.torchlpc.scan(x, A, zi) 61 | 62 | assert torch.allclose(numba_y, ext_y, atol=5e-7), torch.max( 63 | torch.abs(numba_y - ext_y) 64 | ).item() 65 | 66 | 67 | @pytest.mark.parametrize("samples", [1021, 4097]) 68 | @pytest.mark.parametrize( 69 | "cmplx", 70 | [True, False], 71 | ) 72 | @pytest.mark.parametrize( 73 | "device", 74 | [ 75 | "cpu", 76 | pytest.param( 77 | "cuda", 78 | marks=pytest.mark.skipif( 79 | not torch.cuda.is_available(), reason="CUDA not available" 80 | ), 81 | ), 82 | ], 83 | ) 84 | def test_lpc_equiv(samples: int, cmplx: bool, device: str): 85 | batch_size = 4 86 | x, A, zi = tuple( 87 | x.to(device) for x in create_test_inputs(batch_size, samples, cmplx) 88 | ) 89 | if device == "cuda": 90 | numba_y = lpc_cuda(x, A, zi) 91 | else: 92 | numba_y = torch.from_numpy(lpc_np(x.numpy(), A.numpy(), zi.numpy())) 93 | ext_y = torch.ops.torchlpc.lpc(x, A, zi) 94 | 95 | assert torch.allclose(numba_y, ext_y) 96 | -------------------------------------------------------------------------------- /tests/test_grad.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.autograd.gradcheck import gradcheck, gradgradcheck 4 | from torchlpc.core import LPC 5 | from torchlpc.recurrence import Recurrence 6 | 7 | 8 | def get_random_biquads(cmplx=False): 9 | if cmplx: 10 | mag = torch.rand(2, dtype=torch.double) 11 | phase = torch.rand(2, dtype=torch.double) * 2 * torch.pi 12 | roots = mag * torch.exp(1j * phase) 13 | return torch.tensor( 14 | [-roots[0] - roots[1], roots[0] * roots[1]], dtype=torch.complex128 15 | ) 16 | mag = torch.rand(1, dtype=torch.double) 17 | phase = torch.rand(1, dtype=torch.double) * torch.pi 18 | return torch.tensor([-mag * torch.cos(phase) * 2, mag**2], dtype=torch.double) 19 | 20 | 21 | def create_test_inputs(batch_size, samples, cmplx=False): 22 | start_coeffs = get_random_biquads(cmplx) 23 | end_coeffs = get_random_biquads(cmplx) 24 | dtype = torch.complex128 if cmplx else torch.double 25 | 26 | A = ( 27 | torch.stack( 28 | [ 29 | torch.linspace(start_coeffs[i], end_coeffs[i], samples, dtype=dtype) 30 | for i in range(2) 31 | ] 32 | ) 33 | .T.unsqueeze(0) 34 | .repeat(batch_size, 1, 1) 35 | ) 36 | x = torch.randn(batch_size, samples, dtype=dtype) 37 | zi = torch.randn(batch_size, 2, dtype=dtype) 38 | return x, A, zi 39 | 40 | 41 | @pytest.mark.parametrize( 42 | "x_requires_grad", 43 | [True], 44 | ) 45 | @pytest.mark.parametrize( 46 | "a_requires_grad", 47 | [True, False], 48 | ) 49 | @pytest.mark.parametrize( 50 | "zi_requires_grad", 51 | [True, False], 52 | ) 53 | @pytest.mark.parametrize( 54 | "samples", 55 | [32], 56 | ) 57 | @pytest.mark.parametrize( 58 | "cmplx", 59 | [True, False], 60 | ) 61 | @pytest.mark.parametrize( 62 | "device", 63 | [ 64 | "cpu", 65 | pytest.param( 66 | "cuda", 67 | marks=pytest.mark.skipif( 68 | not torch.cuda.is_available(), reason="CUDA not available" 69 | ), 70 | ), 71 | ], 72 | ) 73 | def test_low_order( 74 | x_requires_grad: bool, 75 | a_requires_grad: bool, 76 | zi_requires_grad: bool, 77 | samples: int, 78 | cmplx: bool, 79 | device: str, 80 | ): 81 | batch_size = 4 82 | x, A, zi = tuple( 83 | x.to(device) for x in create_test_inputs(batch_size, samples, cmplx) 84 | ) 85 | A.requires_grad = a_requires_grad 86 | x.requires_grad = x_requires_grad 87 | zi.requires_grad = zi_requires_grad 88 | 89 | assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True) 90 | assert gradgradcheck(LPC.apply, (x, A, zi)) 91 | 92 | 93 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") 94 | def test_float64_vs_32_cuda(): 95 | batch_size = 4 96 | samples = 32 97 | x, A, zi = create_test_inputs(batch_size, samples) 98 | x = x.cuda() 99 | A = A.cuda() 100 | zi = zi.cuda() 101 | 102 | x32 = x.float() 103 | A32 = A.float() 104 | zi32 = zi.float() 105 | 106 | y64 = LPC.apply(x, A, zi) 107 | y32 = LPC.apply(x32, A32, zi32) 108 | 109 | assert torch.allclose(y64, y32.double(), atol=1e-6), torch.max( 110 | torch.abs(y64 - y32.double()) 111 | ) 112 | 113 | 114 | @pytest.mark.parametrize( 115 | "x_requires_grad", 116 | [True], 117 | ) 118 | @pytest.mark.parametrize( 119 | "a_requires_grad", 120 | [True, False], 121 | ) 122 | @pytest.mark.parametrize( 123 | "zi_requires_grad", 124 | [True, False], 125 | ) 126 | @pytest.mark.parametrize( 127 | "cmplx", 128 | [True, False], 129 | ) 130 | @pytest.mark.parametrize( 131 | "device", 132 | [ 133 | "cpu", 134 | pytest.param( 135 | "cuda", 136 | marks=pytest.mark.skipif( 137 | not torch.cuda.is_available(), reason="CUDA not 
available" 138 | ), 139 | ), 140 | ], 141 | ) 142 | def test_parallel_scan( 143 | x_requires_grad: bool, 144 | a_requires_grad: bool, 145 | zi_requires_grad: bool, 146 | cmplx: bool, 147 | device: str, 148 | ): 149 | batch_size = 2 150 | samples = 123 151 | dtype = torch.complex128 if cmplx else torch.double 152 | x = torch.randn(batch_size, samples, dtype=dtype, device=device) 153 | if cmplx: 154 | A = torch.rand( 155 | batch_size, samples, dtype=torch.double, device=device 156 | ).sqrt() * torch.exp( 157 | 1j 158 | * torch.rand(batch_size, samples, dtype=torch.double, device=device) 159 | * 2 160 | * torch.pi 161 | ) 162 | else: 163 | A = torch.rand(batch_size, samples, dtype=dtype, device=device) * 2 - 1 164 | zi = torch.randn(batch_size, dtype=dtype, device=device) 165 | 166 | A.requires_grad = a_requires_grad 167 | x.requires_grad = x_requires_grad 168 | zi.requires_grad = zi_requires_grad 169 | 170 | assert gradcheck(Recurrence.apply, (A, x, zi), check_forward_ad=True) 171 | assert gradgradcheck(Recurrence.apply, (A, x, zi)) 172 | -------------------------------------------------------------------------------- /tests/test_vmap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.func import jacfwd 4 | import pytest 5 | from torchlpc.core import LPC 6 | from torchlpc.recurrence import Recurrence 7 | 8 | 9 | from .test_grad import create_test_inputs 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "device", 14 | [ 15 | "cpu", 16 | pytest.param( 17 | "cuda", 18 | marks=pytest.mark.skipif( 19 | not torch.cuda.is_available(), reason="CUDA not available" 20 | ), 21 | ), 22 | ], 23 | ) 24 | def test_vmap(device: str): 25 | batch_size = 4 26 | samples = 40 27 | x, A, zi = tuple( 28 | x.to(device) for x in create_test_inputs(batch_size, samples, False) 29 | ) 30 | y = torch.randn_like(x) 31 | 32 | A = A[:, 0, :].clone() 33 | 34 | A.requires_grad = True 35 | zi.requires_grad = True 36 | x.requires_grad = True 37 | 38 | args = (x, A, zi) 39 | 40 | def func(x, A, zi): 41 | return F.mse_loss(LPC.apply(x, A[:, None, :].expand(-1, samples, -1), zi), y) 42 | 43 | jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args) 44 | 45 | loss = func(*args) 46 | loss.backward() 47 | for jac, arg in zip(jacs, args): 48 | assert torch.allclose(jac, arg.grad) 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "device", 53 | [ 54 | "cpu", 55 | pytest.param( 56 | "cuda", 57 | marks=pytest.mark.skipif( 58 | not torch.cuda.is_available(), reason="CUDA not available" 59 | ), 60 | ), 61 | ], 62 | ) 63 | def test_parallel_scan_vmap(device: str): 64 | batch_size = 3 65 | samples = 255 66 | x = torch.randn(batch_size, samples, dtype=torch.double, device=device) 67 | A = torch.rand(batch_size, samples, dtype=torch.double, device=device) * 2 - 1 68 | zi = torch.randn(batch_size, dtype=torch.double, device=device) 69 | y = torch.randn(batch_size, samples, dtype=torch.double, device=device) 70 | 71 | A.requires_grad = True 72 | x.requires_grad = True 73 | zi.requires_grad = True 74 | 75 | args = (x, A, zi) 76 | 77 | def func(x, A, zi): 78 | return F.mse_loss(Recurrence.apply(A, x, zi), y) 79 | 80 | jacs = jacfwd(func, argnums=tuple(range(len(args))))(*args) 81 | 82 | loss = func(*args) 83 | loss.backward() 84 | for jac, arg in zip(jacs, args): 85 | assert torch.allclose(jac, arg.grad) 86 | -------------------------------------------------------------------------------- /torchlpc/__init__.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Union, Tuple 3 | from pathlib import Path 4 | import warnings 5 | 6 | # so_files = list(Path(__file__).parent.glob("_C*.so")) 7 | # # assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}" 8 | # if len(so_files) == 1: 9 | # torch.ops.load_library(so_files[0]) 10 | # EXTENSION_LOADED = True 11 | # elif len(so_files) > 1: 12 | # raise ValueError(f"Expected one _C*.so file, found {len(so_files)}") 13 | # else: 14 | # warnings.warn("No _C*.so file found. Custom extension not loaded.") 15 | # EXTENSION_LOADED = False 16 | 17 | try: 18 | from . import _C 19 | 20 | EXTENSION_LOADED = True 21 | except ImportError: 22 | EXTENSION_LOADED = False 23 | warnings.warn("Custom extension not loaded. Falling back to Numba implementation.") 24 | 25 | from .core import LPC 26 | 27 | # from .parallel_scan import WARPSIZE 28 | from .recurrence import Recurrence 29 | 30 | __all__ = ["sample_wise_lpc"] 31 | 32 | 33 | def sample_wise_lpc( 34 | x: torch.Tensor, 35 | a: torch.Tensor, 36 | zi: Optional[torch.Tensor] = None, 37 | return_zf: bool = False, 38 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 39 | """Compute LPC filtering sample-wise. 40 | 41 | Args: 42 | x (torch.Tensor): Input signal. 43 | a (torch.Tensor): LPC coefficients. 44 | zi (torch.Tensor): Initial conditions. 45 | return_zf (bool): If True, return the final filter delay values. Defaults to False. 46 | 47 | Shape: 48 | - x: :math:`(B, T)` 49 | - a: :math:`(B, T, order)` 50 | - zi: :math:`(B, order)` 51 | 52 | Returns: 53 | Filtered signal with the same shape as x if `return_zf` is False. 54 | If `return_zf` is True, returns a tuple of the filtered signal and the final delay values. 55 | """ 56 | assert x.shape[0] == a.shape[0] 57 | assert x.shape[1] == a.shape[1] 58 | assert x.ndim == 2 59 | assert a.ndim == 3 60 | 61 | B, T, order = a.shape 62 | if zi is None: 63 | zi = a.new_zeros(B, order) 64 | else: 65 | assert zi.shape == (B, order) 66 | 67 | # if order == 1 and x.is_cuda and B * WARPSIZE < T: 68 | # return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)) 69 | if order == 1: 70 | y = Recurrence.apply(-a.squeeze(2), x, zi.squeeze(1)) 71 | else: 72 | y = LPC.apply(x, a, zi) 73 | 74 | if return_zf: 75 | return y, y[:, -order:].flip(1) 76 | return y 77 | -------------------------------------------------------------------------------- /torchlpc/core.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | from typing import Any, Tuple, Optional, Callable, List 7 | from numba import jit, njit, prange, cuda, float32, float64, complex64, complex128 8 | 9 | from . 
import EXTENSION_LOADED 10 | 11 | lpc_cuda_kernel_float32: Callable = None 12 | lpc_cuda_kernel_float64: Callable = None 13 | lpc_cuda_kernel_complex64: Callable = None 14 | lpc_cuda_kernel_complex128: Callable = None 15 | 16 | 17 | for t in ["float32", "float64"]: 18 | exec( 19 | f"""@cuda.jit 20 | def lpc_cuda_kernel_{t}(padded_y, A, B, T, order) -> None: 21 | sm = cuda.shared.array(shape=0, dtype={t}) 22 | batch_idx = cuda.blockIdx.x 23 | tid = cuda.threadIdx.x 24 | 25 | i = tid 26 | b = batch_idx 27 | 28 | if b >= B or i >= order: 29 | return 30 | 31 | circular_idx = 0 32 | sm[i] = padded_y[b, i] 33 | 34 | for t in range(T): 35 | circular_idx = t % order 36 | a = -A[b, t, i] 37 | if i > circular_idx - 1: 38 | s = sm[circular_idx - i - 1 + order] 39 | else: 40 | s = sm[circular_idx - i - 1] 41 | 42 | v = a * s 43 | 44 | if i == (order - 1): 45 | sm[circular_idx] = v 46 | v = padded_y[b, t + order] 47 | cuda.syncthreads() 48 | cuda.atomic.add(sm, circular_idx, v) 49 | cuda.syncthreads() 50 | 51 | if i == (order - 1): 52 | padded_y[b, t + order] = sm[circular_idx]""" 53 | ) 54 | 55 | # separate kernel for complex type as atomic.add does not support complex types 56 | for t, dt in zip(["complex64", "complex128"], ["float32", "float64"]): 57 | exec( 58 | f"""@cuda.jit 59 | def lpc_cuda_kernel_{t}(padded_y, A, B, T, order) -> None: 60 | sm = cuda.shared.array(shape=0, dtype={dt}) 61 | batch_idx = cuda.blockIdx.x 62 | tid = cuda.threadIdx.x 63 | 64 | i = tid 65 | b = batch_idx 66 | 67 | if b >= B or i >= order: 68 | return 69 | 70 | sm_real = sm[:order] 71 | sm_imag = sm[order:2*order] 72 | 73 | circular_idx = 0 74 | sm_real[i] = padded_y.real[b, i] 75 | sm_imag[i] = padded_y.imag[b, i] 76 | 77 | for t in range(T): 78 | circular_idx = t % order 79 | a = -A[b, t, i] 80 | if i > circular_idx - 1: 81 | s_real = sm_real[circular_idx - i - 1 + order] 82 | s_imag = sm_imag[circular_idx - i - 1 + order] 83 | else: 84 | s_real = sm_real[circular_idx - i - 1] 85 | s_imag = sm_imag[circular_idx - i - 1] 86 | 87 | v_real = a.real * s_real - a.imag * s_imag 88 | v_imag = a.real * s_imag + a.imag * s_real 89 | 90 | if i == (order - 1): 91 | sm_real[circular_idx] = v_real 92 | sm_imag[circular_idx] = v_imag 93 | v_real = padded_y.real[b, t + order] 94 | v_imag = padded_y.imag[b, t + order] 95 | cuda.syncthreads() 96 | 97 | cuda.atomic.add(sm_real, circular_idx, v_real) 98 | cuda.atomic.add(sm_imag, circular_idx, v_imag) 99 | cuda.syncthreads() 100 | 101 | if i == (order - 1): 102 | padded_y[b, t + order] = sm_real[circular_idx] + 1j * sm_imag[circular_idx]""" 103 | ) 104 | 105 | 106 | def lpc_cuda(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor: 107 | B, T, order = A.shape 108 | assert order <= 1024 109 | padded_y = torch.empty((B, T + order), dtype=x.dtype, device=x.device) 110 | padded_y[:, :order] = zi.flip(1) 111 | padded_y[:, order:] = x 112 | 113 | threads_per_block = order 114 | blocks_per_grid = B 115 | stream = cuda.stream() 116 | 117 | if x.dtype == torch.float32: 118 | runner = lpc_cuda_kernel_float32[ 119 | blocks_per_grid, threads_per_block, stream, 4 * order 120 | ] 121 | elif x.dtype == torch.float64: 122 | runner = lpc_cuda_kernel_float64[ 123 | blocks_per_grid, threads_per_block, stream, 8 * order 124 | ] 125 | elif x.dtype == torch.complex64: 126 | runner = lpc_cuda_kernel_complex64[ 127 | blocks_per_grid, threads_per_block, stream, 8 * order 128 | ] 129 | elif x.dtype == torch.complex128: 130 | runner = lpc_cuda_kernel_complex128[ 131 | blocks_per_grid, 
threads_per_block, stream, 16 * order 132 | ] 133 | else: 134 | raise NotImplementedError(f"Unsupported dtype: {x.dtype}") 135 | 136 | runner(cuda.as_cuda_array(padded_y), cuda.as_cuda_array(A), B, T, order) 137 | 138 | return padded_y[:, order:].contiguous() 139 | 140 | 141 | @njit(parallel=True) 142 | def lpc_np(x: np.ndarray, A: np.ndarray, zi: np.ndarray) -> np.ndarray: 143 | B, T = x.shape 144 | order = zi.shape[1] 145 | padded_y = np.empty((B, T + order), dtype=x.dtype) 146 | padded_y[:, :order] = zi[:, ::-1] 147 | padded_y[:, order:] = x 148 | 149 | for b in prange(B): 150 | for t in range(T): 151 | ref = padded_y[b, t + order] 152 | for i in prange(order): 153 | ref -= A[b, t, i] * padded_y[b, t + order - i - 1] 154 | padded_y[b, t + order] = ref 155 | 156 | return padded_y[:, order:] 157 | 158 | 159 | class LPC(Function): 160 | @staticmethod 161 | def forward(x: torch.Tensor, A: torch.Tensor, zi: torch.Tensor) -> torch.Tensor: 162 | if EXTENSION_LOADED: 163 | y = torch.ops.torchlpc.lpc(x, A, zi) 164 | else: 165 | warnings.warn( 166 | "Cannot find custom extension. Falling back to Numba implementation which will be deprecated in v1.0." 167 | ) 168 | if x.is_cuda: 169 | y = lpc_cuda(x.detach(), A.detach(), zi.detach()) 170 | else: 171 | y = lpc_np( 172 | x.detach().cpu().numpy(), 173 | A.detach().cpu().numpy(), 174 | zi.detach().cpu().numpy(), 175 | ) 176 | y = torch.from_numpy(y).to(x.device, x.dtype) 177 | return y 178 | 179 | @staticmethod 180 | def setup_context(ctx: Any, inputs: List[Any], output: Any) -> Any: 181 | _, A, zi = inputs 182 | y = output 183 | ctx.save_for_backward(A, zi, y) 184 | ctx.save_for_forward(A, zi, y) 185 | 186 | @staticmethod 187 | def backward( 188 | ctx: Any, grad_y: torch.Tensor 189 | ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: 190 | A, zi, y = ctx.saved_tensors 191 | grad_x = grad_A = grad_zi = None 192 | B, T, order = A.shape 193 | 194 | flipped_A = A.flip(2) 195 | padded_flipped_A = F.pad(flipped_A.transpose(1, 2), (0, order + 1)) 196 | shifted_A = ( 197 | padded_flipped_A.reshape(B, T + order + 1, order)[:, :-1, :] 198 | .reshape(B, order, T + order) 199 | .transpose(1, 2) 200 | .flip(2) 201 | ) 202 | 203 | if not ctx.needs_input_grad[2]: 204 | shifted_A = shifted_A[:, order:, :] 205 | padded_grad_y = grad_y 206 | else: 207 | padded_grad_y = F.pad(grad_y.unsqueeze(1), (order, 0)).squeeze(1) 208 | 209 | flipped_grad_x = LPC.apply( 210 | padded_grad_y.flip(1), 211 | shifted_A.flip(1).conj_physical(), 212 | torch.zeros_like(zi), 213 | ) 214 | 215 | if ctx.needs_input_grad[2]: 216 | grad_zi = flipped_grad_x[:, -order:] 217 | flipped_grad_x = flipped_grad_x[:, :-order] 218 | 219 | if ctx.needs_input_grad[0]: 220 | grad_x = flipped_grad_x.flip(1) 221 | 222 | if ctx.needs_input_grad[1]: 223 | valid_y = y[:, :-1] 224 | padded_y = torch.cat([zi.flip(1), valid_y], dim=1) 225 | 226 | unfolded_y = padded_y.unfold(1, order, 1).flip(2) 227 | grad_A = unfolded_y.conj_physical() * -flipped_grad_x.flip(1).unsqueeze(2) 228 | 229 | return grad_x, grad_A, grad_zi 230 | 231 | @staticmethod 232 | def jvp( 233 | ctx: Any, grad_x: torch.Tensor, grad_A: torch.Tensor, grad_zi: torch.Tensor 234 | ) -> torch.Tensor: 235 | A, zi, y = ctx.saved_tensors 236 | *_, order = A.shape 237 | 238 | fwd_zi = grad_zi if grad_zi is not None else torch.zeros_like(zi) 239 | fwd_x = grad_x if grad_x is not None else torch.zeros_like(y) 240 | 241 | if grad_A is not None: 242 | unfolded_y = ( 243 | torch.cat([zi.flip(1), y[:, :-1]], dim=1).unfold(1, 
order, 1).flip(2) 244 | ) 245 | fwd_A = -torch.sum(unfolded_y * grad_A, dim=2) 246 | fwd_x = fwd_x + fwd_A 247 | 248 | return LPC.apply(fwd_x, A, fwd_zi) 249 | 250 | @staticmethod 251 | def vmap(info, in_dims, *args): 252 | def maybe_expand_bdim_at_front(x, x_bdim): 253 | if x_bdim is None: 254 | return x.expand(info.batch_size, *x.shape) 255 | return x.movedim(x_bdim, 0) 256 | 257 | x, A, zi = tuple( 258 | map( 259 | lambda x: x.reshape(-1, *x.shape[2:]), 260 | map(maybe_expand_bdim_at_front, args, in_dims), 261 | ) 262 | ) 263 | 264 | y = LPC.apply(x, A, zi) 265 | return y.reshape(info.batch_size, -1, *y.shape[1:]), 0 266 | -------------------------------------------------------------------------------- /torchlpc/csrc/cuda/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) <2017> 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | -------------------------------------------------------------------------------- /torchlpc/csrc/cuda/linear_recurrence.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define CEIL_DIV(x, y) ((x + y - 1) / y) 9 | 10 | #define gpuErrChk(ans) \ 11 | { \ 12 | gpuAssert((ans), __FILE__, __LINE__); \ 13 | } 14 | void gpuAssert(cudaError_t code, const char *file, int line) { 15 | if (code != cudaSuccess) { 16 | fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file, 17 | line); 18 | } 19 | } 20 | 21 | __device__ int2 divide_work(int n_jobs, int n_workers, int worker_idx) { 22 | // Each worker will do a continuous slice of either n_jobs / n_workers 23 | // or ceil_div(n_jobs, n_workers). 
The return value is an int2 representing 24 | // a half open interval of jobs for the worker to perform (perform jobs 25 | // i for a <= i < b) 26 | 27 | int cd = CEIL_DIV(n_jobs, n_workers); 28 | int d = n_jobs / n_workers; 29 | 30 | int doing_cd = n_jobs % n_workers; 31 | 32 | int2 retval; 33 | if (worker_idx < doing_cd) { 34 | retval.x = worker_idx * cd; 35 | retval.y = retval.x + cd; 36 | } else { 37 | retval.x = doing_cd * cd + (worker_idx - doing_cd) * d; 38 | retval.y = retval.x + d; 39 | } 40 | 41 | return retval; 42 | } 43 | 44 | __device__ int2 compute_warp_start_stop(int block_idx, int warp_idx, 45 | int n_blocks, int n_steps) { 46 | int2 block_ss = divide_work(n_steps, n_blocks, block_idx); 47 | int block_start = block_ss.x; 48 | int block_stop = block_ss.y; 49 | int block_jobs = block_stop - block_start; 50 | 51 | int2 warp_ss = divide_work(block_jobs, 32, warp_idx); 52 | int warp_start = block_start + warp_ss.x; 53 | int warp_stop = block_start + warp_ss.y; 54 | 55 | int2 retval; 56 | retval.x = warp_start; 57 | retval.y = warp_stop; 58 | return retval; 59 | } 60 | 61 | // decay storage, h_storage: 62 | // each a n_dims x 33 x n_blocks matrix on GPU with 33rd column for block 63 | // reduction 64 | template 65 | __global__ void reduction_kernel(const scalar_t *decays, 66 | const scalar_t *impulses, 67 | const scalar_t *initial_state, 68 | scalar_t *_decay_storage, scalar_t *_h_storage, 69 | int n_dims, int n_steps) { 70 | int warp = threadIdx.x / 32; 71 | int lane = threadIdx.x % 32; 72 | 73 | scalar_t *decay_storage = &_decay_storage[blockIdx.x * 33 * n_dims]; 74 | scalar_t *h_storage = &_h_storage[blockIdx.x * 33 * n_dims]; 75 | 76 | int2 start_stop = 77 | compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); 78 | int warp_start = start_stop.x; 79 | int warp_stop = start_stop.y; 80 | 81 | /* 82 | * Reduce within warps. 83 | * After this loop exits, the storage arrays should contain the reduction 84 | * from warp_start to warp_stop (including initial state) at index 85 | * (feature_idx, warp, block). 86 | */ 87 | for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { 88 | scalar_t cum_decay = static_cast(1.0); 89 | scalar_t h = static_cast(0.0); 90 | if (blockIdx.x == 0 && lane == 0 && initial_state != NULL) { 91 | h = initial_state[i]; 92 | } 93 | 94 | for (int t = warp_start; t < warp_stop; t++) { 95 | cum_decay *= decays[i * n_steps + t]; 96 | h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; 97 | } 98 | 99 | // TODO: store into shared memory, work in shared memory sized blocks 100 | // store into global memory 101 | decay_storage[i + lane * n_dims] = cum_decay; 102 | h_storage[i + lane * n_dims] = h; 103 | } 104 | 105 | __syncthreads(); 106 | 107 | /* 108 | * Reduce over warps. 109 | * After this loop exits, the storage arrays should contain the reduction 110 | * from block_start to block_finish (including initial state) at index 111 | * (feature_idx, 32, block). 112 | */ 113 | // TODO: parallel reduction (or scan). 
Need to worry about changing the warp 114 | // reduction values (as I use them again later) 115 | for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { 116 | scalar_t cum_decay = static_cast(1.0); 117 | scalar_t h = static_cast(0.0); 118 | for (int t = 0; t < 32; t++) { 119 | cum_decay *= decay_storage[i + t * n_dims]; 120 | h = decay_storage[i + t * n_dims] * h + h_storage[i + t * n_dims]; 121 | } 122 | decay_storage[i + 32 * n_dims] = cum_decay; 123 | h_storage[i + 32 * n_dims] = h; 124 | } 125 | } 126 | 127 | template 128 | __global__ void block_scan_kernel(scalar_t *decay_storage, scalar_t *h_storage, 129 | int n_dims, int n_blocks) { 130 | /* 131 | * Scan over blocks. 132 | * After this loop exits, the storage arrays should contain the cumulative 133 | * sum from block_idx 0 to i (inclusive) at index (feature_idx, 32, i) This 134 | * means (feature_idx, 32, 2) contains the reduction of blocks 0, 1, and 2. 135 | */ 136 | // TODO: parallel scan (tricky because number of blocks isn't necessarily 137 | // smaller than number of warps that can fit in a single block) 138 | for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n_dims; 139 | i += blockDim.x * gridDim.x) { 140 | for (int t = 1; t < n_blocks; t++) { 141 | int cur_idx = i + 32 * n_dims + t * 33 * n_dims; 142 | int prev_idx = i + 32 * n_dims + (t - 1) * 33 * n_dims; 143 | 144 | // TODO: remove unneccessary reads from global memory (prev_idx 145 | // accesses) 146 | h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + 147 | h_storage[cur_idx]; 148 | decay_storage[cur_idx] *= decay_storage[prev_idx]; 149 | } 150 | } 151 | } 152 | 153 | template 154 | __global__ void warp_scan_kernel(const scalar_t *decays, 155 | const scalar_t *impulses, 156 | const scalar_t *initial_state, scalar_t *out, 157 | scalar_t *decay_storage, scalar_t *h_storage, 158 | int n_dims, int n_steps) { 159 | int warp = threadIdx.x / 32; 160 | int lane = threadIdx.x % 32; 161 | 162 | // Note: Due to the index ordering of the storage arrays, the following 163 | // indices are equivalent: 164 | // 165 | // i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims 166 | // i + 32 * n_dims + (blockIdx.x - 1) * 33 * n_dims 167 | // 168 | // when t is 0. This means something that looks like negative indexing 169 | // (t-1) can be used to safely access the stored value for the previous 170 | // warp (even if the previous warp belonged to the previous block). 171 | 172 | /* 173 | * Scan over warps. 174 | * After this loop executes, the storage arrays should contain the 175 | * cumulative sum from the beginning of sequence (including initial 176 | * condition) up to and including the indexed warp and block. 
177 | */ 178 | // TODO: parallel scan 179 | for (int i = threadIdx.x; i < n_dims; i += blockDim.x) { 180 | for (int t = 0; t < 32; t++) { 181 | if (t == 0 && blockIdx.x == 0) { 182 | // the reduction over warp 0 (including initial condition) is 183 | // correct val for scan, so there's no work to do 184 | continue; 185 | } 186 | 187 | int cur_idx = i + t * n_dims + blockIdx.x * 33 * n_dims; 188 | int prev_idx = i + (t - 1) * n_dims + blockIdx.x * 33 * n_dims; 189 | h_storage[cur_idx] = decay_storage[cur_idx] * h_storage[prev_idx] + 190 | h_storage[cur_idx]; 191 | decay_storage[cur_idx] *= decay_storage[prev_idx]; 192 | } 193 | } 194 | 195 | __syncthreads(); 196 | 197 | int2 start_stop = 198 | compute_warp_start_stop(blockIdx.x, lane, gridDim.x, n_steps); 199 | int warp_start = start_stop.x; 200 | int warp_stop = start_stop.y; 201 | 202 | /* 203 | * Scan within warps. 204 | * This loop writes to the output array. Each warp reads in it's initial 205 | * state (either from the "initial_state" or the storage arrays) and then 206 | * writes to output for indices warp_start up to warp_stop. 207 | */ 208 | for (int i = warp; i < n_dims; i += CEIL_DIV(blockDim.x, 32)) { 209 | scalar_t h = static_cast(0.0); 210 | if (blockIdx.x == 0 && lane == 0) { 211 | if (initial_state != NULL) { 212 | h = initial_state[i]; 213 | } 214 | } else { 215 | h = h_storage[i + (lane - 1) * n_dims + blockIdx.x * 33 * n_dims]; 216 | } 217 | 218 | for (int t = warp_start; t < warp_stop; t++) { 219 | h = decays[i * n_steps + t] * h + impulses[i * n_steps + t]; 220 | out[i * n_steps + t] = h; 221 | } 222 | } 223 | } 224 | 225 | /* 226 | * This is the main method for the prefix sum kernels. 227 | * decays, impulses, out: 228 | * each a n_dims x n_steps column major matrix located on GPU 229 | * initial_state: 230 | * array of size n_dims located on GPU 231 | */ 232 | template 233 | void compute_linear_recurrence(const scalar_t *decays, const scalar_t *impulses, 234 | const scalar_t *initial_state, scalar_t *out, 235 | int n_dims, int n_steps) { 236 | // we want at least 32 elements per block, but no reason to run 237 | // with more than the maximum number of concurrent blocks 238 | // NOTE: 128 is decided empirically. 239 | int n_blocks = min(CEIL_DIV(n_steps, 32), 128); 240 | 241 | // TODO: make user pass in working memory? This allows integration 242 | // with CNMeM (used by Theano) 243 | int reduction_mem_sz = 2 * n_blocks * 33 * n_dims * sizeof(scalar_t); 244 | scalar_t *d_reduction_mem; 245 | gpuErrChk(cudaMalloc(&d_reduction_mem, reduction_mem_sz)); 246 | scalar_t *d_decay_storage = &d_reduction_mem[0 * n_blocks * 33 * n_dims]; 247 | scalar_t *d_h_storage = &d_reduction_mem[1 * n_blocks * 33 * n_dims]; 248 | 249 | // TODO: run kernels on non-default stream? 
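    // Editor's note (explanatory comments only; behaviour is unchanged): the
    // recurrence h[t] = decays[t] * h[t-1] + impulses[t] is evaluated below as a
    // three-phase parallel scan over (decay, impulse) pairs:
    //   1. reduction_kernel reduces each warp's contiguous slice of time steps to
    //      a single (cumulative decay, partial state) pair, then folds the 32 warp
    //      results of every block into column 32 of the storage arrays.
    //   2. block_scan_kernel chains those per-block results sequentially, so entry
    //      (feature, 32, b) holds the scan over blocks 0..b.
    //   3. warp_scan_kernel propagates the prefixes back down to each warp, which
    //      then replays its slice sequentially and writes the final outputs.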
250 | reduction_kernel<<>>(decays, impulses, initial_state, 251 | d_decay_storage, d_h_storage, n_dims, 252 | n_steps); 253 | 254 | block_scan_kernel<<>>(d_decay_storage, d_h_storage, n_dims, 255 | n_blocks); 256 | 257 | warp_scan_kernel<<>>(decays, impulses, initial_state, out, 258 | d_decay_storage, d_h_storage, n_dims, 259 | n_steps); 260 | 261 | gpuErrChk(cudaFree(d_reduction_mem)); 262 | } 263 | 264 | at::Tensor scan_cuda_wrapper(const at::Tensor &input, const at::Tensor &weights, 265 | const at::Tensor &initials) { 266 | TORCH_CHECK(input.is_floating_point() || input.is_complex(), 267 | "Input must be floating point or complex"); 268 | TORCH_CHECK(initials.scalar_type() == input.scalar_type(), 269 | "Initials must have the same scalar type as input"); 270 | TORCH_CHECK(weights.scalar_type() == input.scalar_type(), 271 | "Weights must have the same scalar type as input"); 272 | 273 | auto input_contiguous = input.contiguous(); 274 | auto weights_contiguous = weights.contiguous(); 275 | auto output = at::empty_like(input_contiguous); 276 | 277 | const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); 278 | 279 | AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES( 280 | input.scalar_type(), "compute_linear_recurrence", [&] { 281 | compute_linear_recurrence( 282 | weights_contiguous.const_data_ptr(), 283 | input_contiguous.const_data_ptr(), 284 | initials.const_data_ptr(), 285 | output.mutable_data_ptr(), input_contiguous.size(0), 286 | input_contiguous.size(1)); 287 | }); 288 | return output.contiguous(); 289 | } 290 | 291 | TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("scan", &scan_cuda_wrapper); } 292 | -------------------------------------------------------------------------------- /torchlpc/csrc/cuda/lpc.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | // CUDA kernel for LPC computation 9 | template 10 | __global__ void lpc_cuda_kernel(scalar_t* padded_y, // [B, T + order] 11 | const scalar_t* A, // [B, T, order] 12 | int64_t B, int64_t T, int64_t order) { 13 | extern __shared__ char smem[]; 14 | scalar_t* sm = reinterpret_cast(smem); 15 | 16 | int b = blockIdx.x; 17 | int i = threadIdx.x; 18 | 19 | if (b >= B || i >= order) return; 20 | 21 | // Initialize shared memory with the first 'order' elements 22 | sm[i] = padded_y[b * (T + order) + i]; 23 | __syncthreads(); 24 | 25 | int circular_idx = 0; 26 | for (int t = 0; t < T; ++t) { 27 | circular_idx = t % order; 28 | scalar_t a = -A[((b * T + t) * order) + i]; 29 | 30 | // Compute s as in the Python code 31 | int idx_offset = circular_idx - i - 1; 32 | if (i > circular_idx - 1) { 33 | idx_offset += order; 34 | } 35 | scalar_t s = sm[(idx_offset + order) % order]; 36 | 37 | scalar_t v = a * s; 38 | 39 | if (i == order - 1) { 40 | sm[circular_idx] = v; 41 | v = padded_y[b * (T + order) + t + order]; 42 | } 43 | __syncthreads(); 44 | 45 | // Atomic add to shared memory 46 | atomicAdd(&sm[circular_idx], v); 47 | __syncthreads(); 48 | 49 | if (i == order - 1) { 50 | padded_y[b * (T + order) + t + order] = sm[circular_idx]; 51 | } 52 | __syncthreads(); 53 | } 54 | } 55 | // CUDA kernel for complex LPC computation 56 | template 57 | __global__ void lpc_cuda_kernel_complex( 58 | scalar_t* padded_y_real, // [B, T + order] 59 | scalar_t* padded_y_imag, // [B, T + order] 60 | const scalar_t* A_real, // [B, T, order] 61 | const scalar_t* A_imag, // [B, T, order] 62 | int64_t B, int64_t T, int64_t order) { 63 | extern 
__shared__ char smem[];
64 |     scalar_t* sm_real = reinterpret_cast<scalar_t*>(smem);
65 |     scalar_t* sm_imag = sm_real + order;
66 | 
67 |     int b = blockIdx.x;
68 |     int i = threadIdx.x;
69 | 
70 |     if (b >= B || i >= order) return;
71 | 
72 |     // Initialize shared memory with the first 'order' elements
73 |     sm_real[i] = padded_y_real[b * (T + order) + i];
74 |     sm_imag[i] = padded_y_imag[b * (T + order) + i];
75 |     __syncthreads();
76 | 
77 |     int circular_idx = 0;
78 |     for (int t = 0; t < T; ++t) {
79 |         circular_idx = t % order;
80 |         scalar_t a_real = -A_real[((b * T + t) * order) + i];
81 |         scalar_t a_imag = -A_imag[((b * T + t) * order) + i];
82 | 
83 |         int idx_offset = circular_idx - i - 1;
84 |         if (i > circular_idx - 1) {
85 |             idx_offset += order;
86 |         }
87 |         int s_idx = (idx_offset + order) % order;
88 |         scalar_t s_real = sm_real[s_idx];
89 |         scalar_t s_imag = sm_imag[s_idx];
90 | 
91 |         // Complex multiply: v = a * s
92 |         scalar_t v_real = a_real * s_real - a_imag * s_imag;
93 |         scalar_t v_imag = a_real * s_imag + a_imag * s_real;
94 | 
95 |         if (i == order - 1) {
96 |             sm_real[circular_idx] = v_real;
97 |             sm_imag[circular_idx] = v_imag;
98 |             v_real = padded_y_real[b * (T + order) + t + order];
99 |             v_imag = padded_y_imag[b * (T + order) + t + order];
100 |         }
101 |         __syncthreads();
102 | 
103 |         atomicAdd(&sm_real[circular_idx], v_real);
104 |         atomicAdd(&sm_imag[circular_idx], v_imag);
105 |         __syncthreads();
106 | 
107 |         if (i == order - 1) {
108 |             padded_y_real[b * (T + order) + t + order] = sm_real[circular_idx];
109 |             padded_y_imag[b * (T + order) + t + order] = sm_imag[circular_idx];
110 |         }
111 |         __syncthreads();
112 |     }
113 | }
114 | 
115 | at::Tensor lpc_cuda_wrapper(const at::Tensor& x, const at::Tensor& a,
116 |                             const at::Tensor& zi) {
117 |     TORCH_CHECK(x.is_floating_point() || x.is_complex(),
118 |                 "Input must be floating point or complex");
119 |     TORCH_CHECK(a.scalar_type() == x.scalar_type(),
120 |                 "Coefficients must have the same scalar type as input");
121 |     TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
122 |                 "Initial conditions must have the same scalar type as input");
123 | 
124 |     TORCH_CHECK(x.dim() == 2, "Input must be 2D");
125 |     TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
126 |     TORCH_CHECK(x.size(0) == zi.size(0),
127 |                 "Batch size of input and initial conditions must match");
128 | 
129 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
130 | 
131 |     auto a_contiguous = a.contiguous();
132 | 
133 |     at::Tensor out;
134 |     auto order = a_contiguous.size(2);
135 |     assert(order <= 1024 && "LPC order must be less than or equal to 1024");
136 |     auto threads_per_block = order;
137 | 
138 |     if (x.is_floating_point()) {
139 |         out = at::cat({zi.flip(1), x}, 1).contiguous();
140 |         AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "lpc_cuda", [&] {
141 |             auto padded_y = out.mutable_data_ptr<scalar_t>();
142 |             auto A = a_contiguous.const_data_ptr<scalar_t>();
143 |             auto B = x.size(0);
144 |             auto T = x.size(1);
145 | 
146 |             lpc_cuda_kernel<scalar_t>
147 |                 <<<B, threads_per_block, order * sizeof(scalar_t),
148 |                    at::cuda::getCurrentCUDAStream()>>>(padded_y, A, B, T, order);
149 |         });
150 |     } else {
151 |         auto out_real =
152 |             at::cat({at::real(zi).flip(1), at::real(x)}, 1).contiguous();
153 |         auto out_imag =
154 |             at::cat({at::imag(zi).flip(1), at::imag(x)}, 1).contiguous();
155 |         auto a_real = at::real(a_contiguous).contiguous();
156 |         auto a_imag = at::imag(a_contiguous).contiguous();
157 |         AT_DISPATCH_FLOATING_TYPES(
158 |             out_real.scalar_type(), "lpc_cuda_complex", [&] {
159 |                 auto padded_y_real = out_real.mutable_data_ptr<scalar_t>();
160 |                 auto padded_y_imag = out_imag.mutable_data_ptr<scalar_t>();
161 |                 auto A_real = a_real.const_data_ptr<scalar_t>();
162 |                 auto A_imag = a_imag.const_data_ptr<scalar_t>();
163 |                 auto B = x.size(0);
164 |                 auto T = x.size(1);
165 | 
166 |                 lpc_cuda_kernel_complex<scalar_t>
167 |                     <<<B, threads_per_block, 2 * order * sizeof(scalar_t),
168 |                        at::cuda::getCurrentCUDAStream()>>>(
169 |                         padded_y_real, padded_y_imag, A_real, A_imag, B, T,
170 |                         order);
171 |             });
172 |         out = at::view_as_complex(at::stack({out_real, out_imag}, -1));
173 |     }
174 |     return out.slice(1, order, out.size(1)).contiguous();
175 | }
176 | 
177 | TORCH_LIBRARY_IMPL(torchlpc, CUDA, m) { m.impl("lpc", &lpc_cuda_wrapper); }
--------------------------------------------------------------------------------
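Reference note (not part of the repository sources): the CUDA kernels above evaluate the time-varying all-pole recurrence y[b, t] = x[b, t] - sum_i A[b, t, i] * y[b, t - i - 1] in place, on a padded buffer whose first `order` columns hold the flipped initial conditions. A naive NumPy sketch of the same computation, with illustrative names of our own, is:

import numpy as np

def lpc_reference(x, A, zi):
    """Loop version of the recurrence computed by lpc_cuda_kernel(_complex).

    x: (B, T) input, A: (B, T, order) coefficients, zi: (B, order) initial conditions.
    """
    B, T = x.shape
    order = A.shape[2]
    # Same buffer layout as the kernels: flipped initial conditions, then the signal.
    padded_y = np.concatenate([zi[:, ::-1], x], axis=1)
    for b in range(B):
        for t in range(T):
            y = padded_y[b, t + order]
            for i in range(order):
                y -= A[b, t, i] * padded_y[b, t + order - i - 1]
            padded_y[b, t + order] = y
    return padded_y[:, order:]

Such a loop is convenient for spot-checking the CUDA wrapper against the CPU implementation in scan_cpu.cpp below, which uses the same padding convention.
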
/torchlpc/csrc/scan_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/Parallel.h>
2 | #include <Python.h>
3 | #include <torch/script.h>
4 | 
5 | #include <algorithm>
6 | #include <numeric>
7 | #include <utility>
8 | 
9 | extern "C" {
10 | /* Creates a dummy empty _C module that can be imported from Python.
11 |    The import from Python will load the .so associated with this extension
12 |    built from this file, so that all the TORCH_LIBRARY calls below are run.*/
13 | PyObject *PyInit__C(void) {
14 |     static struct PyModuleDef module_def = {
15 |         PyModuleDef_HEAD_INIT,
16 |         "_C", /* name of module */
17 |         NULL, /* module documentation, may be NULL */
18 |         -1,   /* size of per-interpreter state of the module,
19 |                  or -1 if the module keeps state in global variables. */
20 |         NULL, /* methods */
21 |     };
22 |     return PyModule_Create(&module_def);
23 | }
24 | }
25 | 
26 | template <typename scalar_t>
27 | void scan_cpu(const at::Tensor &input, const at::Tensor &weights,
28 |               const at::Tensor &initials, const at::Tensor &output) {
29 |     TORCH_CHECK(input.dim() == 2, "Input must be 2D");
30 |     TORCH_CHECK(initials.dim() == 1, "Initials must be 1D");
31 |     TORCH_CHECK(weights.sizes() == input.sizes(),
32 |                 "Weights must have the same size as input");
33 |     TORCH_CHECK(output.sizes() == input.sizes(),
34 |                 "Output must have the same size as input");
35 |     TORCH_CHECK(initials.size(0) == input.size(0),
36 |                 "The first dimension of initials must be the same as the first "
37 |                 "dimension of input");
38 |     TORCH_INTERNAL_ASSERT(input.device().is_cpu(), "Input must be on CPU");
39 |     TORCH_INTERNAL_ASSERT(initials.device().is_cpu(),
40 |                           "Initials must be on CPU");
41 |     TORCH_INTERNAL_ASSERT(weights.device().is_cpu(), "Weights must be on CPU");
42 |     TORCH_INTERNAL_ASSERT(output.device().is_cpu(), "Output must be on CPU");
43 |     TORCH_INTERNAL_ASSERT(output.is_contiguous(), "Output must be contiguous");
44 | 
45 |     auto input_contiguous = input.contiguous();
46 |     auto weights_contiguous = weights.contiguous();
47 |     auto initials_contiguous = initials.contiguous();
48 | 
49 |     auto n_batch = input.size(0);
50 |     auto T = input.size(1);
51 |     auto total_size = input.numel();
52 | 
53 |     std::pair<scalar_t, scalar_t> buffer[total_size];
54 | 
55 |     const scalar_t *input_ptr = input_contiguous.const_data_ptr<scalar_t>();
56 |     const scalar_t *initials_ptr =
57 |         initials_contiguous.const_data_ptr<scalar_t>();
58 |     const scalar_t *weights_ptr = weights_contiguous.const_data_ptr<scalar_t>();
59 |     scalar_t *output_ptr = output.mutable_data_ptr<scalar_t>();
60 | 
61 |     std::transform(weights_ptr, weights_ptr + total_size, input_ptr, buffer,
62 |                    [](const scalar_t &a, const scalar_t &b) {
63 |                        return std::make_pair(a, b);
64 |                    });
65 | 
66 |     at::parallel_for(0, n_batch, 1, [&](int64_t start, int64_t end) {
67 |         for (auto b = start; b < end; b++) {
68 |             std::inclusive_scan(
69 |                 buffer + b * T, buffer + (b + 1) * T, buffer + b * T,
70 |                 [](const std::pair<scalar_t, scalar_t> &a,
71 |                    const std::pair<scalar_t, scalar_t> &b) {
72 |                     return std::make_pair(a.first * b.first,
73 |                                           a.second * b.first + b.second);
74 |                 },
75 |                 std::make_pair((scalar_t)1.0, initials_ptr[b]));
76 |         }
77 |     });
78 | 
79 |     std::transform(
80 |         buffer, buffer + total_size, output_ptr,
81 |         [](const std::pair<scalar_t, scalar_t> &a) { return a.second; });
82 | }
83 | 
84 | template <typename scalar_t>
85 | void lpc_cpu_core(const torch::Tensor &a, const torch::Tensor &padded_out) {
86 |     // Ensure input dimensions are correct
87 |     TORCH_CHECK(a.dim() == 3, "a must be 3-dimensional");
88 |     TORCH_CHECK(padded_out.dim() == 2, "out must be 2-dimensional");
89 |     TORCH_CHECK(padded_out.size(0) == a.size(0),
90 |                 "Batch size of out and x must match");
91 |     TORCH_CHECK(padded_out.size(1) == (a.size(1) + a.size(2)),
92 |                 "Time dimension of out must match x and a");
93 |     TORCH_INTERNAL_ASSERT(a.device().is_cpu(), "a must be on CPU");
94 |     TORCH_INTERNAL_ASSERT(padded_out.device().is_cpu(),
95 |                           "Output must be on CPU");
96 |     TORCH_INTERNAL_ASSERT(padded_out.is_contiguous(),
97 |                           "Output must be contiguous");
98 | 
99 |     // Get the dimensions
100 |     const auto B = a.size(0);
101 |     const auto T = a.size(1);
102 |     const auto order = a.size(2);
103 | 
104 |     auto a_contiguous = a.contiguous();
105 | 
106 |     const scalar_t *a_ptr = a_contiguous.const_data_ptr<scalar_t>();
107 |     scalar_t *out_ptr = padded_out.mutable_data_ptr<scalar_t>();
108 | 
109 |     at::parallel_for(0, B, 1, [&](int64_t start, int64_t end) {
110 |         for (auto b = start; b < end; b++) {
111 |             auto out_offset = b * (T + order) + order;
112 |             auto a_offset = b * T * order;
113 |             for (int64_t t = 0; t < T; t++) {
114 |                 scalar_t y = out_ptr[out_offset + t];
115 |                 for (int64_t i = 0; i < order; i++) {
116 |                     y -= a_ptr[a_offset + t * order + i] *
117 |                          out_ptr[out_offset + t - i - 1];
118 |                 }
119 |                 out_ptr[out_offset + t] = y;
120 |             }
121 |         }
122 |     });
123 | }
124 | 
125 | at::Tensor scan_cpu_wrapper(const at::Tensor &input, const at::Tensor &weights,
126 |                             const at::Tensor &initials) {
127 |     TORCH_CHECK(input.is_floating_point() || input.is_complex(),
128 |                 "Input must be floating point or complex");
129 |     TORCH_CHECK(initials.scalar_type() == input.scalar_type(),
130 |                 "Initials must have the same scalar type as input");
131 |     TORCH_CHECK(weights.scalar_type() == input.scalar_type(),
132 |                 "Weights must have the same scalar type as input");
133 | 
134 |     auto output = at::empty_like(input);
135 | 
136 |     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
137 |         input.scalar_type(), "scan_cpu",
138 |         [&] { scan_cpu<scalar_t>(input, weights, initials, output); });
139 |     return output;
140 | }
141 | 
142 | at::Tensor lpc_cpu(const at::Tensor &x, const at::Tensor &a,
143 |                    const at::Tensor &zi) {
144 |     TORCH_CHECK(x.is_floating_point() || x.is_complex(),
145 |                 "Input must be floating point or complex");
146 |     TORCH_CHECK(a.scalar_type() == x.scalar_type(),
147 |                 "Coefficients must have the same scalar type as input");
148 |     TORCH_CHECK(zi.scalar_type() == x.scalar_type(),
149 |                 "Initial conditions must have the same scalar type as input");
150 | 
151 |     TORCH_CHECK(x.dim() == 2, "Input must be 2D");
152 |     TORCH_CHECK(zi.dim() == 2, "Initial conditions must be 2D");
153 |     TORCH_CHECK(x.size(0) == zi.size(0),
154 |                 "Batch size of input and initial conditions must match");
155 | 
156 |     auto out = at::cat({zi.flip(1), x}, 1).contiguous();
157 | 
158 |     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
159 |         x.scalar_type(), "lpc_cpu", [&] { lpc_cpu_core<scalar_t>(a, out); });
160 |     return out.slice(1, zi.size(1), out.size(1)).contiguous();
161 | }
162 | 
163 | TORCH_LIBRARY(torchlpc, m) {
164 |     m.def("torchlpc::scan(Tensor a, Tensor b, Tensor c) -> Tensor");
165 |     m.def("torchlpc::lpc(Tensor a, Tensor b, Tensor c) -> Tensor");
166 | }
167 | 
168 | TORCH_LIBRARY_IMPL(torchlpc, CPU, m) {
169 |     m.impl("scan", &scan_cpu_wrapper);
170 |     m.impl("lpc", &lpc_cpu);
171 | }
172 | 
--------------------------------------------------------------------------------
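Reference note (not part of the repository sources): `scan_cpu` parallelizes the first-order recurrence h[t] = w[t] * h[t - 1] + x[t] by scanning (weight, value) pairs with the associative operator (w1, x1) o (w2, x2) = (w1 * w2, x1 * w2 + x2), seeded with (1, h[-1]); that is exactly the lambda and initial pair passed to std::inclusive_scan above. A small self-contained check of the identity, with names of our own:

import numpy as np

def combine(p, q):
    # (w1, x1) o (w2, x2) = (w1 * w2, x1 * w2 + x2), mirroring the inclusive_scan lambda
    return p[0] * q[0], p[1] * q[0] + q[1]

rng = np.random.default_rng(0)
w = rng.uniform(-0.9, 0.9, 32)
x = rng.normal(size=32)
h0 = rng.normal()

# Sequential reference: h[t] = w[t] * h[t - 1] + x[t]
ref, h = [], h0
for t in range(32):
    h = w[t] * h + x[t]
    ref.append(h)

# Inclusive scan with the associative operator, seeded with (1, h0) as in scan_cpu
out, acc = [], (1.0, h0)
for t in range(32):
    acc = combine(acc, (w[t], x[t]))
    out.append(acc[1])

assert np.allclose(ref, out)

Because the operator is associative, the time axis can be chunked and scanned in parallel, which is what the Numba kernels in parallel_scan.py below do on the GPU; the CPU path here instead parallelizes over the batch dimension with at::parallel_for.
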
Tensor"); 166 | } 167 | 168 | TORCH_LIBRARY_IMPL(torchlpc, CPU, m) { 169 | m.impl("scan", &scan_cpu_wrapper); 170 | m.impl("lpc", &lpc_cpu); 171 | } 172 | -------------------------------------------------------------------------------- /torchlpc/parallel_scan.py: -------------------------------------------------------------------------------- 1 | from numba import cuda 2 | 3 | WARPSIZE = 32 4 | 5 | # implementation was translated from https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.cu 6 | 7 | 8 | @cuda.jit(device=True) 9 | def divide_work(n_jobs, n_workers, worker_idx) -> tuple: 10 | cd = (n_jobs + n_workers - 1) // n_workers 11 | d, doing_cd = divmod(n_jobs, n_workers) 12 | if worker_idx < doing_cd: 13 | x = cd * worker_idx 14 | y = x + cd 15 | else: 16 | x = cd * doing_cd + d * (worker_idx - doing_cd) 17 | y = x + d 18 | return x, y 19 | 20 | 21 | @cuda.jit(device=True) 22 | def compute_warp_start_stop(blockIdx, warp_idx, n_blocks, n_steps): 23 | block_start, block_stop = divide_work(n_steps, n_blocks, blockIdx) 24 | block_jobs = block_stop - block_start 25 | 26 | warp_start, warp_stop = divide_work(block_jobs, WARPSIZE, warp_idx) 27 | warp_start += block_start 28 | warp_stop += block_start 29 | 30 | return warp_start, warp_stop 31 | 32 | 33 | @cuda.jit 34 | def reduction_kernel( 35 | decay, impulses, initial_state, decay_storage, h_storage, n_dims, n_steps 36 | ): 37 | warp, lane = divmod(cuda.threadIdx.x, WARPSIZE) 38 | 39 | storage_offset = cuda.blockIdx.x * (WARPSIZE + 1) 40 | 41 | warp_start, warp_stop = compute_warp_start_stop( 42 | cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps 43 | ) 44 | 45 | # reduce within warp 46 | for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE): 47 | cum_decay = 1.0 48 | h = 0.0 49 | if (cuda.blockIdx.x == 0) and (lane == 0): 50 | h = initial_state[i] 51 | 52 | for t in range(warp_start, warp_stop): 53 | cum_decay *= decay[i, t] 54 | h = decay[i, t] * h + impulses[i, t] 55 | 56 | decay_storage[lane + storage_offset, i] = cum_decay 57 | h_storage[lane + storage_offset, i] = h 58 | 59 | cuda.syncthreads() 60 | 61 | # reduce within block 62 | for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x): 63 | cum_decay = 1.0 64 | h = 0.0 65 | for t in range(storage_offset, storage_offset + WARPSIZE): 66 | cum_decay *= decay_storage[t, i] 67 | h = decay_storage[t, i] * h + h_storage[t, i] 68 | 69 | decay_storage[WARPSIZE + storage_offset, i] = cum_decay 70 | h_storage[WARPSIZE + storage_offset, i] = h 71 | 72 | 73 | @cuda.jit 74 | def block_scan_kernel(decay_storage, h_storage, n_dims, n_blocks): 75 | for i in range( 76 | cuda.grid(1), 77 | n_dims, 78 | cuda.gridsize(1), 79 | ): 80 | for t in range(1, n_blocks): 81 | cur_idx = t * (WARPSIZE + 1) + WARPSIZE 82 | prev_idx = (t - 1) * (WARPSIZE + 1) + WARPSIZE 83 | h_storage[cur_idx, i] += h_storage[prev_idx, i] * decay_storage[cur_idx, i] 84 | decay_storage[cur_idx, i] *= decay_storage[prev_idx, i] 85 | 86 | 87 | @cuda.jit 88 | def warp_scan_kernel( 89 | decay, impulses, initial_state, out, decay_storage, h_storage, n_dims, n_steps 90 | ): 91 | warp, lane = divmod(cuda.threadIdx.x, WARPSIZE) 92 | 93 | for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x): 94 | offset = cuda.blockIdx.x * (WARPSIZE + 1) 95 | for cur_idx in range(offset, offset + WARPSIZE): 96 | if cur_idx == 0: 97 | continue 98 | prev_idx = cur_idx - 1 99 | h_storage[cur_idx, i] = ( 100 | h_storage[prev_idx, i] * decay_storage[cur_idx, i] 101 | + h_storage[cur_idx, 
101 |                 + h_storage[cur_idx, i]
102 |             )
103 |             decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]
104 | 
105 |     cuda.syncthreads()
106 | 
107 |     warp_start, warp_stop = compute_warp_start_stop(
108 |         cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
109 |     )
110 | 
111 |     # scan within warp
112 |     for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
113 |         if (cuda.blockIdx.x == 0) and (lane == 0):
114 |             h = initial_state[i]
115 |         else:
116 |             h = h_storage[lane - 1 + cuda.blockIdx.x * (WARPSIZE + 1), i]
117 | 
118 |         for t in range(warp_start, warp_stop):
119 |             h = decay[i, t] * h + impulses[i, t]
120 |             out[i, t] = h
121 | 
122 | 
123 | def compute_linear_recurrence(
124 |     decays, impulses, init_states, out, n_dims: int, n_steps: int
125 | ):
126 |     n_blocks = min((n_steps + WARPSIZE - 1) // WARPSIZE, 128)
127 | 
128 |     reduction_mem_shape = (n_blocks * (WARPSIZE + 1), n_dims)
129 |     decay_storage = cuda.device_array(reduction_mem_shape, dtype=decays.dtype)
130 |     h_storage = cuda.device_array(reduction_mem_shape, dtype=impulses.dtype)
131 | 
132 |     reduction_kernel[n_blocks, 512](
133 |         decays, impulses, init_states, decay_storage, h_storage, n_dims, n_steps
134 |     )
135 | 
136 |     block_scan_kernel[n_blocks, 512](decay_storage, h_storage, n_dims, n_blocks)
137 | 
138 |     warp_scan_kernel[n_blocks, 512](
139 |         decays, impulses, init_states, out, decay_storage, h_storage, n_dims, n_steps
140 |     )
141 | 
142 | 
143 | if __name__ == "__main__":
144 |     import numpy as np
145 | 
146 |     n_dims = 16
147 |     n_steps = 20480
148 |     decays = np.full((n_dims, n_steps), 0.9, dtype=np.float32)
149 |     impulses = np.full((n_dims, n_steps), 0.0, dtype=np.float32)
150 |     impulses[:, 0] = 1.0
151 |     init_states = np.full(n_dims, 0.0, dtype=np.float32)
152 | 
153 |     decays = cuda.to_device(decays)
154 |     impulses = cuda.to_device(impulses)
155 |     init_states = cuda.to_device(init_states)
156 |     out = cuda.device_array((n_dims, n_steps), dtype=np.float32)
157 | 
158 |     compute_linear_recurrence(decays, impulses, init_states, out, n_dims, n_steps)
159 | 
160 |     print(out.copy_to_host())
161 | 
--------------------------------------------------------------------------------
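Reference note (not part of the repository sources): one way to sanity-check `compute_linear_recurrence` is against a sequential NumPy loop, as sketched below. This requires a CUDA-capable GPU and Numba; the reference loop, sizes, and tolerances are our own choices.

import numpy as np
from numba import cuda

from torchlpc.parallel_scan import compute_linear_recurrence

rng = np.random.default_rng(0)
n_dims, n_steps = 8, 1000
decays = rng.uniform(0.5, 0.99, (n_dims, n_steps)).astype(np.float32)
impulses = rng.normal(size=(n_dims, n_steps)).astype(np.float32)
init_states = rng.normal(size=n_dims).astype(np.float32)

# Sequential reference: h[:, t] = decays[:, t] * h[:, t - 1] + impulses[:, t]
ref = np.empty_like(impulses)
h = init_states.copy()
for t in range(n_steps):
    h = decays[:, t] * h + impulses[:, t]
    ref[:, t] = h

out = cuda.device_array((n_dims, n_steps), dtype=np.float32)
compute_linear_recurrence(
    cuda.to_device(decays),
    cuda.to_device(impulses),
    cuda.to_device(init_states),
    out,
    n_dims,
    n_steps,
)
np.testing.assert_allclose(out.copy_to_host(), ref, rtol=1e-4, atol=1e-4)
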
/torchlpc/recurrence.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Function
4 | from numba import cuda
5 | from typing import Tuple, Optional, Any, List
6 | 
7 | from .parallel_scan import compute_linear_recurrence, WARPSIZE
8 | from .core import lpc_cuda, lpc_np
9 | from . import EXTENSION_LOADED
10 | 
11 | if EXTENSION_LOADED:
12 |     lpc_cuda_runner = torch.ops.torchlpc.lpc
13 |     lpc_cpu_runner = torch.ops.torchlpc.lpc
14 | 
15 |     scan_cuda_runner = torch.ops.torchlpc.scan
16 |     scan_cpu_runner = torch.ops.torchlpc.scan
17 | else:
18 |     lpc_cuda_runner = lpc_cuda
19 |     lpc_cpu_runner = lambda x, A, zi: torch.from_numpy(
20 |         lpc_np(x.detach().numpy(), A.detach().numpy(), zi.detach().numpy())
21 |     )
22 | 
23 |     scan_cuda_runner = lambda impulse, decay, initial_state: (
24 |         lambda out: (
25 |             out,
26 |             compute_linear_recurrence(
27 |                 cuda.as_cuda_array(decay.detach()),
28 |                 cuda.as_cuda_array(impulse.detach()),
29 |                 cuda.as_cuda_array(initial_state.detach()),
30 |                 cuda.as_cuda_array(out),
31 |                 decay.shape[0],
32 |                 decay.shape[1],
33 |             ),
34 |         )
35 |     )(torch.empty_like(impulse))[0]
36 |     scan_cpu_runner = lambda impulse, decay, initial_state: torch.from_numpy(
37 |         lpc_np(
38 |             impulse.detach().numpy(),
39 |             -decay.unsqueeze(2).detach().numpy(),
40 |             initial_state.unsqueeze(1).detach().numpy(),
41 |         )
42 |     )
43 | 
44 | 
45 | def _cuda_recurrence(
46 |     impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
47 | ) -> torch.Tensor:
48 |     n_dims, n_steps = decay.shape
49 |     if n_dims * WARPSIZE < n_steps:
50 |         runner = scan_cuda_runner
51 |     else:
52 |         runner = lambda impulse, decay, initial_state: lpc_cuda_runner(
53 |             impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
54 |         )
55 |     return runner(impulse, decay, initial_state)
56 | 
57 | 
58 | def _cpu_recurrence(
59 |     impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
60 | ) -> torch.Tensor:
61 |     num_threads = torch.get_num_threads()
62 |     n_dims, _ = decay.shape
63 |     # This is just a rough estimation of the computational cost
64 |     if EXTENSION_LOADED and min(n_dims, num_threads) < num_threads / 3:
65 |         runner = scan_cpu_runner
66 |     else:
67 |         runner = lambda impulse, decay, initial_state: lpc_cpu_runner(
68 |             impulse, -decay.unsqueeze(2), initial_state.unsqueeze(1)
69 |         )
70 |     return runner(impulse, decay, initial_state)
71 | 
72 | 
73 | class Recurrence(Function):
74 |     @staticmethod
75 |     def forward(
76 |         decay: torch.Tensor,
77 |         impulse: torch.Tensor,
78 |         initial_state: torch.Tensor,
79 |     ) -> torch.Tensor:
80 |         if decay.is_cuda:
81 |             out = _cuda_recurrence(impulse, decay, initial_state)
82 |         else:
83 |             out = _cpu_recurrence(impulse, decay, initial_state)
84 |         return out
85 | 
86 |     @staticmethod
87 |     def setup_context(ctx: Any, inputs: List[Any], output: Any) -> Any:
88 |         decay, _, initial_state = inputs
89 |         ctx.save_for_backward(decay, initial_state, output)
90 |         ctx.save_for_forward(decay, initial_state, output)
91 | 
92 |     @staticmethod
93 |     def backward(
94 |         ctx: Any, grad_out: torch.Tensor
95 |     ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
96 |         decay, initial_state, out = ctx.saved_tensors
97 |         grad_decay = grad_impulse = grad_initial_state = None
98 |         n_dims, _ = decay.shape
99 | 
100 |         padded_decay = F.pad(decay.unsqueeze(1), (0, 1)).squeeze(1)
101 |         if ctx.needs_input_grad[2]:
102 |             padded_grad_out = F.pad(grad_out.unsqueeze(1), (1, 0)).squeeze(1)
103 |         else:
104 |             padded_grad_out = grad_out
105 |             padded_decay = padded_decay[:, 1:]
106 | 
107 |         init = padded_grad_out.new_zeros(n_dims)
108 |         flipped_grad_impulse = Recurrence.apply(
109 |             padded_decay.flip(1).conj_physical(),
110 |             padded_grad_out.flip(1),
111 |             init,
112 |         )
113 | 
114 |         if ctx.needs_input_grad[2]:
115 |             grad_initial_state = flipped_grad_impulse[:, -1]
116 |             flipped_grad_impulse = flipped_grad_impulse[:, :-1]
117 | 
118 |         if ctx.needs_input_grad[1]:
119 |             grad_impulse = flipped_grad_impulse.flip(1)
120 | 
121 |         if ctx.needs_input_grad[0]:
122 |             valid_out = out[:, :-1]
123 |             padded_out = torch.cat([initial_state.unsqueeze(1), valid_out], dim=1)
124 |             grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)
125 | 
126 |         return grad_decay, grad_impulse, grad_initial_state
127 | 
128 |     @staticmethod
129 |     def jvp(
130 |         ctx: Any,
131 |         grad_decay: torch.Tensor,
132 |         grad_impulse: torch.Tensor,
133 |         grad_initial_state: torch.Tensor,
134 |     ) -> torch.Tensor:
135 |         decay, initial_state, out = ctx.saved_tensors
136 | 
137 |         fwd_initial_state = (
138 |             grad_initial_state
139 |             if grad_initial_state is not None
140 |             else torch.zeros_like(initial_state)
141 |         )
142 |         fwd_impulse = (
143 |             grad_impulse if grad_impulse is not None else torch.zeros_like(out)
144 |         )
145 | 
146 |         if grad_decay is not None:
147 |             concat_out = torch.cat([initial_state.unsqueeze(1), out[:, :-1]], dim=1)
148 |             fwd_decay = concat_out * grad_decay
149 |             fwd_impulse = fwd_impulse + fwd_decay
150 | 
151 |         return Recurrence.apply(decay, fwd_impulse, fwd_initial_state)
152 | 
153 |     @staticmethod
154 |     def vmap(info, in_dims, *args):
155 |         def maybe_expand_bdim_at_front(x, x_bdim):
156 |             if x_bdim is None:
157 |                 return x.expand(info.batch_size, *x.shape)
158 |             return x.movedim(x_bdim, 0)
159 | 
160 |         decay, impulse, initial_state = tuple(
161 |             map(
162 |                 lambda x: x.reshape(-1, *x.shape[2:]),
163 |                 map(maybe_expand_bdim_at_front, args, in_dims),
164 |             )
165 |         )
166 | 
167 |         out = Recurrence.apply(decay, impulse, initial_state)
168 |         return out.reshape(info.batch_size, -1, *out.shape[1:]), 0
169 | 
170 | 
171 | RecurrenceCUDA = Recurrence
172 | 
--------------------------------------------------------------------------------
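Usage note (not part of the repository sources): `Recurrence.apply(decay, impulse, initial_state)` expects shapes (B, T), (B, T), and (B,), and returns out where out[:, t] = decay[:, t] * out[:, t - 1] + impulse[:, t] and the recursion starts from initial_state; on CPU it falls back to the Numba/NumPy path when the compiled extension is not available. A minimal double-precision sketch that also exercises the custom backward defined above:

import torch

from torchlpc.recurrence import Recurrence

B, T = 2, 16
decay = (torch.rand(B, T, dtype=torch.double) * 1.8 - 0.9).requires_grad_()
impulse = torch.randn(B, T, dtype=torch.double, requires_grad=True)
zi = torch.randn(B, dtype=torch.double, requires_grad=True)

out = Recurrence.apply(decay, impulse, zi)
assert out.shape == (B, T)

# The hand-written backward/jvp should agree with numerical gradients.
torch.autograd.gradcheck(Recurrence.apply, (decay, impulse, zi))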