├── .github └── workflows │ └── publish.yaml ├── .gitignore ├── AUTHORS ├── LICENSE ├── README.md ├── causal_conv1d ├── __init__.py ├── causal_conv1d_interface.py ├── causal_conv1d_varlen.py └── cpp_functions.py ├── csrc ├── causal_conv1d.cpp ├── causal_conv1d.h ├── causal_conv1d_bwd.cu ├── causal_conv1d_common.h ├── causal_conv1d_fwd.cu ├── causal_conv1d_update.cu └── static_switch.h ├── rocm_patch └── rocm6_0.patch ├── setup.py └── tests └── test_causal_conv1d.py /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | # This workflow will: 2 | # - Create a new Github release 3 | # - Build wheels for supported architectures 4 | # - Deploy the wheels to the Github release 5 | # - Release the static code to PyPi 6 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 7 | 8 | name: Build wheels and deploy 9 | 10 | on: 11 | create: 12 | tags: 13 | - v* 14 | 15 | jobs: 16 | 17 | setup_release: 18 | name: Create Release 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Get the tag version 22 | id: extract_branch 23 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} 24 | shell: bash 25 | 26 | - name: Create Release 27 | id: create_release 28 | uses: actions/create-release@v1 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | with: 32 | tag_name: ${{ steps.extract_branch.outputs.branch }} 33 | release_name: ${{ steps.extract_branch.outputs.branch }} 34 | 35 | build_wheels: 36 | name: Build Wheel 37 | needs: setup_release 38 | runs-on: ${{ matrix.os }} 39 | 40 | strategy: 41 | fail-fast: false 42 | matrix: 43 | # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the 44 | # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. 45 | os: [ubuntu-20.04] 46 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 47 | torch-version: ['2.1.2', '2.2.2', '2.3.1', '2.4.0', '2.5.1', '2.6.0.dev20241001'] 48 | cuda-version: ['11.8.0', '12.3.2'] 49 | # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. 50 | # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 51 | # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) 52 | # when building without C++11 ABI and using it on nvcr images. 53 | cxx11_abi: ['FALSE', 'TRUE'] 54 | exclude: 55 | # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix 56 | # Pytorch < 2.2 does not support Python 3.12 57 | - torch-version: '2.1.2' 58 | python-version: '3.12' 59 | # Pytorch < 2.5 does not support Python 3.13 60 | - torch-version: '2.1.2' 61 | python-version: '3.13' 62 | - torch-version: '2.2.2' 63 | python-version: '3.13' 64 | - torch-version: '2.3.1' 65 | python-version: '3.13' 66 | - torch-version: '2.4.0' 67 | python-version: '3.13' 68 | 69 | steps: 70 | - name: Checkout 71 | uses: actions/checkout@v4 72 | 73 | - name: Set up Python 74 | uses: actions/setup-python@v5 75 | with: 76 | python-version: ${{ matrix.python-version }} 77 | 78 | - name: Set CUDA and PyTorch versions 79 | run: | 80 | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV 81 | echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." 
$2'})" >> $GITHUB_ENV 82 | echo "WHEEL_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV 83 | echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV 84 | 85 | - name: Free up disk space 86 | if: ${{ runner.os == 'Linux' }} 87 | # https://github.com/easimon/maximize-build-space/blob/master/action.yml 88 | # https://github.com/easimon/maximize-build-space/tree/test-report 89 | run: | 90 | sudo rm -rf /usr/share/dotnet 91 | sudo rm -rf /opt/ghc 92 | sudo rm -rf /opt/hostedtoolcache/CodeQL 93 | 94 | - name: Set up swap space 95 | if: runner.os == 'Linux' 96 | uses: pierotofy/set-swap-space@v1.0 97 | with: 98 | swap-size-gb: 10 99 | 100 | - name: Install CUDA ${{ matrix.cuda-version }} 101 | if: ${{ matrix.cuda-version != 'cpu' }} 102 | uses: Jimver/cuda-toolkit@v0.2.19 103 | id: cuda-toolkit 104 | with: 105 | cuda: ${{ matrix.cuda-version }} 106 | linux-local-args: '["--toolkit"]' 107 | # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 108 | # method: ${{ (matrix.cuda-version == '11.8.0' || matrix.cuda-version == '12.1.0') && 'network' || 'local' }} 109 | method: 'network' 110 | sub-packages: '["nvcc"]' 111 | 112 | - name: Install PyTorch ${{ matrix.torch-version }}+cu${{ matrix.cuda-version }} 113 | run: | 114 | pip install --upgrade pip 115 | # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools 116 | pip install setuptools==68.0.0 117 | # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error 118 | # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable 119 | pip install typing-extensions==4.12.2 120 | # We want to figure out the CUDA version to download pytorch 121 | # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 122 | # This code is ugly, maybe there's a better way to do this. 123 | export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ 124 | minv = {'2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118, '2.5': 118, '2.6': 118}[env['MATRIX_TORCH_VERSION']]; \ 125 | maxv = {'2.1': 121, '2.2': 121, '2.3': 121, '2.4': 124, '2.5': 124, '2.6': 124}[env['MATRIX_TORCH_VERSION']]; \ 126 | print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ 127 | ) 128 | if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then 129 | # pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} 130 | # Can't use --no-deps because we need cudnn etc. 
131 | # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 132 | pip install jinja2 133 | pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl 134 | pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl 135 | else 136 | pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} 137 | fi 138 | nvcc --version 139 | python --version 140 | python -c "import torch; print('PyTorch:', torch.__version__)" 141 | python -c "import torch; print('CUDA:', torch.version.cuda)" 142 | python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" 143 | shell: 144 | bash 145 | 146 | - name: Build wheel 147 | run: | 148 | # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 149 | # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 150 | # However this still fails so I'm using a newer version of setuptools 151 | pip install setuptools==68.0.0 152 | pip install ninja packaging wheel 153 | export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH 154 | export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH 155 | # Limit MAX_JOBS otherwise the github runner goes OOM 156 | MAX_JOBS=2 CAUSAL_CONV1D_FORCE_BUILD="TRUE" CAUSAL_CONV1D_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist 157 | tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} 158 | wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") 159 | ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} 160 | echo "wheel_name=${wheel_name}" >> $GITHUB_ENV 161 | 162 | - name: Log Built Wheels 163 | run: | 164 | ls dist 165 | 166 | - name: Get the tag version 167 | id: extract_branch 168 | run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} 169 | 170 | - name: Get Release with tag 171 | id: get_current_release 172 | uses: joutvhu/get-release@v1 173 | with: 174 | tag_name: ${{ steps.extract_branch.outputs.branch }} 175 | env: 176 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 177 | 178 | - name: Upload Release Asset 179 | id: upload_release_asset 180 | uses: actions/upload-release-asset@v1 181 | env: 182 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 183 | with: 184 | upload_url: ${{ steps.get_current_release.outputs.upload_url }} 185 | asset_path: ./dist/${{env.wheel_name}} 186 | asset_name: ${{env.wheel_name}} 187 | asset_content_type: application/* 188 | 189 | publish_package: 190 | name: Publish package 191 | needs: [build_wheels] 192 | 193 | runs-on: ubuntu-latest 194 | 195 | steps: 196 | - uses: actions/checkout@v4 197 | 198 | - uses: actions/setup-python@v5 199 | with: 200 | python-version: '3.10' 201 | 202 | - name: Install dependencies 203 | run: | 204 | pip install ninja packaging setuptools wheel twine 205 | # We don't want to download anything CUDA-related here 206 | pip install torch --index-url https://download.pytorch.org/whl/cpu 207 | 208 | - name: Build core package 209 | env: 210 | CAUSAL_CONV1D_SKIP_CUDA_BUILD: "TRUE" 211 | run: | 212 | python setup.py sdist --dist-dir=dist 213 | 214 | - 
name: Deploy 215 | env: 216 | TWINE_USERNAME: "__token__" 217 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 218 | run: | 219 | python -m twine upload dist/* 220 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | *.egg-info/ 3 | build/ 4 | **.so 5 | *.hip 6 | *_hip.* -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | Tri Dao, tri@tridao.me 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Causal depthwise conv1d in CUDA with a PyTorch interface 2 | 3 | Features: 4 | - Support fp32, fp16, bf16. 5 | - Kernel size 2, 3, 4. 6 | 7 | ## How to use 8 | 9 | ```python 10 | from causal_conv1d import causal_conv1d_fn 11 | ``` 12 | 13 | ```python 14 | def causal_conv1d_fn(x, weight, bias=None, activation=None): 15 | """ 16 | x: (batch, dim, seqlen) 17 | weight: (dim, width) 18 | bias: (dim,) 19 | activation: either None or "silu" or "swish" 20 | 21 | out: (batch, dim, seqlen) 22 | """ 23 | ``` 24 | 25 | Equivalent to: 26 | ```python 27 | import torch.nn.functional as F 28 | 29 | F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)[..., :seqlen] 30 | ``` 31 | 32 | ## Additional Prerequisites for AMD cards 33 | 34 | ### Patching ROCm 35 | 36 | If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards. 
37 | 38 | 1. Locate your ROCm installation directory. This is typically found at `/opt/rocm/`, but may vary depending on your installation. 39 | 40 | 2. Apply the Patch. Run with `sudo` in case you encounter permission issues. 41 | ```bash 42 | patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch 43 | ``` 44 | -------------------------------------------------------------------------------- /causal_conv1d/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.5.0.post8" 2 | 3 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update 4 | -------------------------------------------------------------------------------- /causal_conv1d/causal_conv1d_interface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Tri Dao. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from causal_conv1d.cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function 7 | 8 | 9 | class CausalConv1dFn(torch.autograd.Function): 10 | @staticmethod 11 | def forward( 12 | ctx, 13 | x, 14 | weight, 15 | bias=None, 16 | seq_idx=None, 17 | initial_states=None, 18 | return_final_states=False, 19 | final_states_out=None, 20 | activation=None, 21 | ): 22 | if activation not in [None, "silu", "swish"]: 23 | raise NotImplementedError("activation must be None, silu, or swish") 24 | if x.stride(2) != 1 and x.stride(1) != 1: 25 | x = x.contiguous() 26 | bias = bias.contiguous() if bias is not None else None 27 | if seq_idx is not None: 28 | assert ( 29 | initial_states is None 30 | ), "initial_states must be None if seq_idx is not None" 31 | assert ( 32 | not return_final_states 33 | ), "If seq_idx is not None, we don't return final_states_out" 34 | seq_idx = seq_idx.contiguous() if seq_idx is not None else None 35 | if initial_states is not None and ( 36 | initial_states.stride(2) != 1 and initial_states.stride(1) != 1 37 | ): 38 | initial_states = initial_states.contiguous() 39 | if return_final_states: 40 | assert ( 41 | x.stride(1) == 1 42 | ), "Only channel-last layout support returning final_states_out" 43 | if final_states_out is not None: 44 | assert ( 45 | final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 46 | ) 47 | else: 48 | batch, dim, seqlen = x.shape 49 | width = weight.shape[1] 50 | final_states_out = torch.empty( 51 | batch, width - 1, dim, device=x.device, dtype=x.dtype 52 | ).transpose(1, 2) 53 | else: 54 | final_states_out = None 55 | ctx.activation = activation in ["silu", "swish"] 56 | out = causal_conv1d_fwd_function( 57 | x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation 58 | ) 59 | ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) 60 | ctx.return_final_states = return_final_states 61 | ctx.return_dinitial_states = ( 62 | initial_states is not None and initial_states.requires_grad 63 | ) 64 | return out if not return_final_states else (out, final_states_out) 65 | 66 | @staticmethod 67 | def backward(ctx, dout, *args): 68 | x, weight, bias, seq_idx, initial_states = ctx.saved_tensors 69 | dfinal_states = args[0] if ctx.return_final_states else None 70 | if dout.stride(2) != 1 and dout.stride(1) != 1: 71 | dout = dout.contiguous() 72 | # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the 73 | # backward of conv1d with the backward of chunk). 
74 | # Here we just pass in None and dx will be allocated in the C++ code. 75 | dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( 76 | x, 77 | weight, 78 | bias, 79 | dout, 80 | seq_idx, 81 | initial_states, 82 | dfinal_states, 83 | None, 84 | ctx.return_dinitial_states, 85 | ctx.activation, 86 | ) 87 | return ( 88 | dx, 89 | dweight, 90 | dbias if bias is not None else None, 91 | None, 92 | dinitial_states if initial_states is not None else None, 93 | None, 94 | None, 95 | None, 96 | ) 97 | 98 | 99 | def causal_conv1d_fn( 100 | x, 101 | weight, 102 | bias=None, 103 | seq_idx=None, 104 | initial_states=None, 105 | return_final_states=False, 106 | final_states_out=None, 107 | activation=None, 108 | ): 109 | """ 110 | x: (batch, dim, seqlen) 111 | weight: (dim, width) 112 | bias: (dim,) 113 | seq_idx: (batch, seqlen) 114 | initial_states: (batch, dim, width - 1) 115 | final_states_out: (batch, dim, width - 1), to be written to 116 | activation: either None or "silu" or "swish" 117 | 118 | out: (batch, dim, seqlen) 119 | """ 120 | return CausalConv1dFn.apply( 121 | x, 122 | weight, 123 | bias, 124 | seq_idx, 125 | initial_states, 126 | return_final_states, 127 | final_states_out, 128 | activation, 129 | ) 130 | 131 | 132 | def causal_conv1d_ref( 133 | x, 134 | weight, 135 | bias=None, 136 | initial_states=None, 137 | return_final_states=False, 138 | final_states_out=None, 139 | activation=None, 140 | ): 141 | """ 142 | x: (batch, dim, seqlen) 143 | weight: (dim, width) 144 | bias: (dim,) 145 | initial_states: (batch, dim, width - 1) 146 | final_states_out: (batch, dim, width - 1) 147 | 148 | out: (batch, dim, seqlen) 149 | """ 150 | if activation not in [None, "silu", "swish"]: 151 | raise NotImplementedError("activation must be None, silu, or swish") 152 | dtype_in = x.dtype 153 | x = x.to(weight.dtype) 154 | seqlen = x.shape[-1] 155 | dim, width = weight.shape 156 | if initial_states is None: 157 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) 158 | else: 159 | x = torch.cat([initial_states, x], dim=-1) 160 | out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) 161 | out = out[..., :seqlen] 162 | if return_final_states: 163 | final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( 164 | dtype_in 165 | ) # (batch, dim, width - 1) 166 | if final_states_out is not None: 167 | final_states_out.copy_(final_states) 168 | else: 169 | final_states_out = final_states 170 | out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) 171 | return out if not return_final_states else (out, final_states_out) 172 | 173 | 174 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): 175 | """ 176 | x: (batch, dim) or (batch, dim, seqlen) 177 | conv_state: (batch, dim, state_len), where state_len >= width - 1 178 | weight: (dim, width) 179 | bias: (dim,) 180 | cache_seqlens: (batch,), dtype int32. 181 | If not None, the conv_state is treated as a circular buffer. 182 | The conv_state will be updated by copying x to the conv_state starting at the index 183 | @cache_seqlens % state_len. 184 | conv_state_indices: (batch,), dtype int32 185 | If None, the conv_state is a larger tensor along the batch dim, 186 | and we are selecting the batch coords specified by conv_state_indices. 187 | Useful for a continuous batching scenario. 
188 | 189 | out: (batch, dim) or (batch, dim, seqlen) 190 | """ 191 | if activation not in [None, "silu", "swish"]: 192 | raise NotImplementedError("activation must be None, silu, or swish") 193 | activation = activation in ["silu", "swish"] 194 | unsqueeze = x.dim() == 2 195 | if unsqueeze: 196 | x = x.unsqueeze(-1) 197 | out = causal_conv1d_update_function( 198 | x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices 199 | ) 200 | if unsqueeze: 201 | out = out.squeeze(-1) 202 | return out 203 | 204 | 205 | def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): 206 | """ 207 | x: (batch, dim) or (batch, dim, seqlen) 208 | conv_state: (batch, dim, state_len), where state_len >= width - 1 209 | weight: (dim, width) 210 | bias: (dim,) 211 | cache_seqlens: (batch,), dtype int32. 212 | If not None, the conv_state is treated as a circular buffer. 213 | The conv_state will be updated by copying x to the conv_state starting at the index 214 | @cache_seqlens % state_len before performing the convolution. 215 | 216 | out: (batch, dim) or (batch, dim, seqlen) 217 | """ 218 | if activation not in [None, "silu", "swish"]: 219 | raise NotImplementedError("activation must be None, silu, or swish") 220 | dtype_in = x.dtype 221 | unsqueeze = x.dim() == 2 222 | if unsqueeze: 223 | x = x.unsqueeze(-1) 224 | batch, dim, seqlen = x.shape 225 | width = weight.shape[1] 226 | state_len = conv_state.shape[-1] 227 | assert conv_state.shape == (batch, dim, state_len) 228 | assert weight.shape == (dim, width) 229 | if cache_seqlens is None: 230 | x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) 231 | conv_state.copy_(x_new[:, :, -state_len:]) 232 | else: 233 | width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) 234 | width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) 235 | x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) 236 | copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) 237 | copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) 238 | conv_state.scatter_(2, copy_idx, x) 239 | out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] 240 | if unsqueeze: 241 | out = out.squeeze(-1) 242 | return (out if activation is None else F.silu(out)).to(dtype=dtype_in) 243 | -------------------------------------------------------------------------------- /causal_conv1d/causal_conv1d_varlen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | import triton 5 | import triton.language as tl 6 | 7 | 8 | @triton.jit 9 | def _causal_conv1d_varlen_states( 10 | X, 11 | CU_SEQLENS, 12 | STATES, 13 | state_len, 14 | dim, 15 | stride_x_seqlen, stride_x_dim, 16 | stride_states_batch, stride_states_seqlen, stride_states_dim, 17 | BLOCK_M: tl.constexpr, 18 | BLOCK_N: tl.constexpr 19 | ): 20 | batch_idx = tl.program_id(2) 21 | STATES += batch_idx * stride_states_batch 22 | end_idx = tl.load(CU_SEQLENS + batch_idx + 1) 23 | start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) 24 | rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) 25 | cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) 26 | x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, 
:] * stride_x_dim, 27 | mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), 28 | other=0) 29 | rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) 30 | tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, 31 | x, 32 | mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) 33 | 34 | 35 | def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: 36 | """ 37 | Forward pass only, does not support backward pass. 38 | Parameters: 39 | x: (total_tokens, dim) 40 | cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. 41 | state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. 42 | If some of those elements belong to a different sequence, the value of the states will be zero. 43 | Return: 44 | states: (batch, dim, state_len) 45 | """ 46 | _, dim = x.shape 47 | batch = cu_seqlens.shape[0] - 1 48 | cu_seqlens = cu_seqlens.contiguous() 49 | states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) 50 | BLOCK_M = min(triton.next_power_of_2(state_len), 16) 51 | BLOCK_N = min(triton.next_power_of_2(dim), 256) 52 | grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) 53 | with torch.cuda.device(x.device.index): 54 | _causal_conv1d_varlen_states[grid]( 55 | x, 56 | cu_seqlens, 57 | states, 58 | state_len, 59 | dim, 60 | x.stride(0), x.stride(1), 61 | states.stride(0), states.stride(2), states.stride(1), 62 | BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N 63 | ) 64 | return states 65 | 66 | 67 | def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: 68 | """ 69 | Forward pass only, does not support backward pass. 70 | Parameters: 71 | x: (total_tokens, dim) 72 | cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. 73 | state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. 74 | If some of those elements belong to a different sequence, the value of the states will be zero. 75 | Return: 76 | states: (batch, dim, state_len) 77 | """ 78 | _, dim = x.shape 79 | batch = cu_seqlens.shape[0] - 1 80 | cu_seqlens = cu_seqlens.contiguous() 81 | states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) 82 | for i in range(batch): 83 | end_idx = cu_seqlens[i + 1] 84 | start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) 85 | states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T 86 | return states 87 | -------------------------------------------------------------------------------- /causal_conv1d/cpp_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Tri Dao. 
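# This module registers the raw causal_conv1d_cuda extension kernels as torch.library custom ops
# (declaring the tensors they mutate via mutates_args) and exposes thin helpers
# (causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function)
# that allocate the output and gradient buffers before dispatching to the C++ code.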
2 | 3 | import torch 4 | 5 | import causal_conv1d_cuda 6 | 7 | 8 | LIBRARY_NAME = "DaoAILab" 9 | 10 | 11 | @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_fwd_cpp", mutates_args={"out", "final_states_out"}) 12 | def _causal_conv1d_fwd_cpp( 13 | x: torch.Tensor, 14 | weight: torch.Tensor, 15 | bias: torch.Tensor | None, 16 | seq_idx: torch.Tensor | None, 17 | initial_states: torch.Tensor | None, 18 | out: torch.Tensor, 19 | final_states_out: torch.Tensor | None, 20 | silu_activation: bool, 21 | ) -> None: 22 | causal_conv1d_cuda.causal_conv1d_fwd( 23 | x, 24 | weight, 25 | bias, 26 | seq_idx, 27 | initial_states, 28 | out, 29 | final_states_out, 30 | silu_activation, 31 | ) 32 | 33 | 34 | @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_bwd_cpp", mutates_args={ 35 | "dfinal_states", 36 | "dx", 37 | "dweight", 38 | "dbias", 39 | "dinitial_states", 40 | }) 41 | def _causal_conv1d_bwd_cpp( 42 | x: torch.Tensor, 43 | weight: torch.Tensor, 44 | bias: torch.Tensor | None, 45 | dout: torch.Tensor, 46 | seq_idx: torch.Tensor | None, 47 | initial_states: torch.Tensor | None, 48 | dfinal_states: torch.Tensor | None, 49 | dx: torch.Tensor, 50 | dweight: torch.Tensor, 51 | dbias: torch.Tensor | None, 52 | dinitial_states: torch.Tensor, 53 | silu_activation: bool, 54 | ) -> None: 55 | causal_conv1d_cuda.causal_conv1d_bwd( 56 | x, 57 | weight, 58 | bias, 59 | dout, 60 | seq_idx, 61 | initial_states, 62 | dfinal_states, 63 | dx, 64 | dweight, 65 | dbias, 66 | dinitial_states, 67 | silu_activation, 68 | ) 69 | 70 | 71 | @torch.library.custom_op(f"{LIBRARY_NAME}::_causal_conv1d_update_cpp", mutates_args={"out", "conv_state"}) 72 | def _causal_conv1d_update_cpp( 73 | x: torch.Tensor, 74 | conv_state: torch.Tensor, 75 | weight: torch.Tensor, 76 | bias: torch.Tensor | None, 77 | out: torch.Tensor, 78 | silu_activation: bool, 79 | cache_seqlens: torch.Tensor | None, 80 | conv_state_indices: torch.Tensor | None, 81 | ) -> None: 82 | causal_conv1d_cuda.causal_conv1d_update( 83 | x, 84 | conv_state, 85 | weight, 86 | bias, 87 | out, 88 | silu_activation, 89 | cache_seqlens, 90 | conv_state_indices 91 | ) 92 | 93 | 94 | def causal_conv1d_fwd_function( 95 | x: torch.Tensor, 96 | weight: torch.Tensor, 97 | bias: torch.Tensor | None, 98 | seq_idx: torch.Tensor | None, 99 | initial_states: torch.Tensor | None, 100 | final_states_out: torch.Tensor | None, 101 | silu_activation: bool, 102 | ) -> torch.Tensor: 103 | out = torch.empty_like(x) 104 | _causal_conv1d_fwd_cpp( 105 | x=x, 106 | weight=weight, 107 | bias=bias, 108 | seq_idx=seq_idx, 109 | initial_states=initial_states, 110 | out=out, 111 | final_states_out=final_states_out, 112 | silu_activation=silu_activation, 113 | ) 114 | return out 115 | 116 | 117 | def causal_conv1d_bwd_function( 118 | x: torch.Tensor, 119 | weight: torch.Tensor, 120 | bias: torch.Tensor | None, 121 | dout: torch.Tensor, 122 | seq_idx: torch.Tensor | None, 123 | initial_states: torch.Tensor | None, 124 | dfinal_states: torch.Tensor | None, 125 | dx: torch.Tensor | None, 126 | return_dinitial_states: torch.Tensor, 127 | silu_activation: bool, 128 | ) -> tuple[torch.Tensor | None]: 129 | batch_size, dim = x.size()[:2] 130 | width = weight.size(-1) 131 | 132 | if dx is None: 133 | dx = torch.empty_like(x) 134 | dweight = torch.zeros_like(weight, dtype=torch.float32) 135 | dbias = torch.zeros_like(bias, dtype=torch.float32) 136 | dinitial_states = None 137 | if return_dinitial_states: 138 | dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, 
dtype=x.dtype).transpose(1, 2) 139 | 140 | _causal_conv1d_bwd_cpp( 141 | x=x, 142 | weight=weight, 143 | bias=bias, 144 | dout=dout, 145 | seq_idx=seq_idx, 146 | initial_states=initial_states, 147 | dfinal_states=dfinal_states, 148 | dx=dx, 149 | dweight=dweight, 150 | dbias=dbias, 151 | dinitial_states=dinitial_states, 152 | silu_activation=silu_activation, 153 | ) 154 | 155 | dweight = dweight.type_as(weight) 156 | dbias = dbias.type_as(bias) 157 | return dx, dweight, dbias, dinitial_states 158 | 159 | 160 | def causal_conv1d_update_function( 161 | x: torch.Tensor, 162 | conv_state: torch.Tensor, 163 | weight: torch.Tensor, 164 | bias: torch.Tensor | None, 165 | silu_activation: bool, 166 | cache_seqlens: torch.Tensor | None, 167 | conv_state_indices: torch.Tensor | None, 168 | ) -> torch.Tensor: 169 | out = torch.empty_like(x) 170 | _causal_conv1d_update_cpp( 171 | x=x, 172 | conv_state=conv_state, 173 | weight=weight, 174 | bias=bias, 175 | out=out, 176 | silu_activation=silu_activation, 177 | cache_seqlens=cache_seqlens, 178 | conv_state_indices=conv_state_indices, 179 | ) 180 | return out 181 | -------------------------------------------------------------------------------- /csrc/causal_conv1d.cpp: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "causal_conv1d.h" 11 | 12 | #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") 13 | 14 | #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ 15 | if (ITYPE == at::ScalarType::Half) { \ 16 | using input_t = at::Half; \ 17 | __VA_ARGS__(); \ 18 | } else if (ITYPE == at::ScalarType::BFloat16) { \ 19 | using input_t = at::BFloat16; \ 20 | __VA_ARGS__(); \ 21 | } else if (ITYPE == at::ScalarType::Float) { \ 22 | using input_t = float; \ 23 | __VA_ARGS__(); \ 24 | } else { \ 25 | AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ 26 | } 27 | 28 | #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) 
\ 29 | if (WTYPE == at::ScalarType::Half) { \ 30 | using weight_t = at::Half; \ 31 | __VA_ARGS__(); \ 32 | } else if (WTYPE == at::ScalarType::BFloat16) { \ 33 | using weight_t = at::BFloat16; \ 34 | __VA_ARGS__(); \ 35 | } else if (WTYPE == at::ScalarType::Float) { \ 36 | using weight_t = float; \ 37 | __VA_ARGS__(); \ 38 | } else { \ 39 | AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ 40 | } 41 | 42 | template 43 | void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 44 | template 45 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 46 | 47 | template 48 | void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 49 | template 50 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 51 | 52 | template 53 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 54 | 55 | void set_conv_params_fwd(ConvParamsBase ¶ms, 56 | // sizes 57 | const size_t batch, 58 | const size_t dim, 59 | const size_t seqlen, 60 | const size_t width, 61 | // device pointers 62 | const at::Tensor x, 63 | const at::Tensor weight, 64 | const at::Tensor out, 65 | void* bias_ptr, 66 | bool silu_activation) { 67 | 68 | // Reset the parameters 69 | memset(¶ms, 0, sizeof(params)); 70 | 71 | params.batch = batch; 72 | params.dim = dim; 73 | params.seqlen = seqlen; 74 | params.width = width; 75 | 76 | params.silu_activation = silu_activation; 77 | 78 | // Set the pointers and strides. 79 | params.x_ptr = x.data_ptr(); 80 | params.weight_ptr = weight.data_ptr(); 81 | params.bias_ptr = bias_ptr; 82 | params.out_ptr = out.data_ptr(); 83 | // All stride are in elements, not bytes. 84 | params.x_batch_stride = x.stride(0); 85 | params.x_c_stride = x.stride(1); 86 | params.x_l_stride = x.stride(-1); 87 | params.weight_c_stride = weight.stride(0); 88 | params.weight_width_stride = weight.stride(1); 89 | params.out_batch_stride = out.stride(0); 90 | params.out_c_stride = out.stride(1); 91 | params.out_l_stride = out.stride(-1); 92 | } 93 | 94 | 95 | void set_conv_params_bwd(ConvParamsBwd ¶ms, 96 | // sizes 97 | const size_t batch, 98 | const size_t dim, 99 | const size_t seqlen, 100 | const size_t width, 101 | // device pointers 102 | const at::Tensor x, 103 | const at::Tensor weight, 104 | void* bias_ptr, 105 | const at::Tensor dout, 106 | const at::Tensor dx, 107 | const at::Tensor dweight, 108 | void* dbias_ptr, 109 | bool silu_activation) { 110 | // Pass in "dout" instead of "out", we're not gonna use "out" at all. 111 | set_conv_params_fwd(params, batch, dim, seqlen, width, 112 | x, weight, dout, bias_ptr, silu_activation); 113 | 114 | // Set the pointers and strides. 115 | params.dout_ptr = dout.data_ptr(); 116 | params.dx_ptr = dx.data_ptr(); 117 | params.dweight_ptr = dweight.data_ptr(); 118 | params.dbias_ptr = dbias_ptr; 119 | // All stride are in elements, not bytes. 
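// (Element strides are what the kernels expect, since they index typed input_t*/weight_t*
// pointers rather than raw byte offsets.)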
120 | params.dout_batch_stride = dout.stride(0); 121 | params.dout_c_stride = dout.stride(1); 122 | params.dout_l_stride = dout.stride(2); 123 | params.dweight_c_stride = dweight.stride(0); 124 | params.dweight_width_stride = dweight.stride(1); 125 | params.dx_batch_stride = dx.stride(0); 126 | params.dx_c_stride = dx.stride(1); 127 | params.dx_l_stride = dx.stride(2); 128 | } 129 | 130 | void 131 | causal_conv1d_fwd(const at::Tensor &x, 132 | const at::Tensor &weight, 133 | const c10::optional &bias_, 134 | const c10::optional &seq_idx_, 135 | const c10::optional &initial_states_, 136 | at::Tensor &out, 137 | c10::optional &final_states_out_, 138 | bool silu_activation) { 139 | auto input_type = x.scalar_type(); 140 | auto weight_type = weight.scalar_type(); 141 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 142 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 143 | 144 | TORCH_CHECK(x.is_cuda()); 145 | TORCH_CHECK(weight.is_cuda()); 146 | 147 | const auto sizes = x.sizes(); 148 | const int batch_size = sizes[0]; 149 | const int dim = sizes[1]; 150 | const int seqlen = sizes[2]; 151 | const int width = weight.size(-1); 152 | 153 | CHECK_SHAPE(x, batch_size, dim, seqlen); 154 | CHECK_SHAPE(weight, dim, width); 155 | 156 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 157 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 158 | 159 | if (is_channel_last) { 160 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); 161 | TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); 162 | } 163 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 164 | 165 | if (bias_.has_value()) { 166 | auto bias = bias_.value(); 167 | TORCH_CHECK(bias.scalar_type() == weight_type); 168 | TORCH_CHECK(bias.is_cuda()); 169 | TORCH_CHECK(bias.stride(-1) == 1); 170 | CHECK_SHAPE(bias, dim); 171 | } 172 | 173 | if (seq_idx_.has_value()) { 174 | TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); 175 | auto seq_idx = seq_idx_.value(); 176 | TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); 177 | TORCH_CHECK(seq_idx.is_cuda()); 178 | TORCH_CHECK(seq_idx.is_contiguous()); 179 | CHECK_SHAPE(seq_idx, batch_size, seqlen); 180 | } 181 | 182 | ConvParamsBase params; 183 | set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, 184 | bias_.has_value() ? 
bias_.value().data_ptr() : nullptr, 185 | silu_activation); 186 | 187 | if (seq_idx_.has_value()) { 188 | params.seq_idx_ptr = seq_idx_.value().data_ptr(); 189 | } else { 190 | params.seq_idx_ptr = nullptr; 191 | } 192 | 193 | if (initial_states_.has_value()) { 194 | TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); 195 | auto initial_states = initial_states_.value(); 196 | TORCH_CHECK(initial_states.scalar_type() == input_type); 197 | TORCH_CHECK(initial_states.is_cuda()); 198 | CHECK_SHAPE(initial_states, batch_size, dim, width - 1); 199 | TORCH_CHECK(initial_states.stride(1) == 1); 200 | params.initial_states_ptr = initial_states.data_ptr(); 201 | params.initial_states_batch_stride = initial_states.stride(0); 202 | params.initial_states_c_stride = initial_states.stride(1); 203 | params.initial_states_l_stride = initial_states.stride(2); 204 | } else { 205 | params.initial_states_ptr = nullptr; 206 | } 207 | 208 | if (final_states_out_.has_value()) { 209 | TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); 210 | auto final_states = final_states_out_.value(); 211 | TORCH_CHECK(final_states.scalar_type() == input_type); 212 | TORCH_CHECK(final_states.is_cuda()); 213 | CHECK_SHAPE(final_states, batch_size, dim, width - 1); 214 | TORCH_CHECK(final_states.stride(1) == 1); 215 | params.final_states_ptr = final_states.data_ptr(); 216 | params.final_states_batch_stride = final_states.stride(0); 217 | params.final_states_c_stride = final_states.stride(1); 218 | params.final_states_l_stride = final_states.stride(2); 219 | } else { 220 | params.final_states_ptr = nullptr; 221 | } 222 | 223 | // Otherwise the kernel will be launched from cuda:0 device 224 | at::cuda::CUDAGuard device_guard{x.device()}; 225 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 226 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { 227 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] { 228 | if (!is_channel_last) { 229 | causal_conv1d_fwd_cuda(params, stream); 230 | } else { 231 | causal_conv1d_channellast_fwd_cuda(params, stream); 232 | } 233 | }); 234 | }); 235 | } 236 | 237 | void 238 | causal_conv1d_bwd(const at::Tensor &x, 239 | const at::Tensor &weight, 240 | const c10::optional &bias_, 241 | at::Tensor &dout, 242 | const c10::optional &seq_idx_, 243 | const c10::optional &initial_states_, 244 | const c10::optional &dfinal_states_, 245 | at::Tensor &dx, 246 | at::Tensor &dweight, 247 | c10::optional &dbias_, 248 | c10::optional &dinitial_states_, 249 | bool silu_activation) { 250 | auto input_type = x.scalar_type(); 251 | auto weight_type = weight.scalar_type(); 252 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 253 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 254 | 255 | TORCH_CHECK(x.is_cuda()); 256 | TORCH_CHECK(weight.is_cuda()); 257 | TORCH_CHECK(dout.is_cuda()); 258 | TORCH_CHECK(bias_.has_value() == dbias_.has_value()); 259 | 260 | const auto sizes = x.sizes(); 261 | const int batch_size = sizes[0]; 262 | const int dim = sizes[1]; 263 | const int seqlen = sizes[2]; 264 | const int width = weight.size(-1); 265 | 266 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 267 | 268 | CHECK_SHAPE(x, batch_size, dim, seqlen); 269 | 
CHECK_SHAPE(weight, dim, width); 270 | CHECK_SHAPE(dout, batch_size, dim, seqlen); 271 | 272 | TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); 273 | const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; 274 | if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); } 275 | if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); } 276 | 277 | if (is_channel_last) { 278 | TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); 279 | TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); 280 | TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8"); 281 | } 282 | 283 | if (bias_.has_value()) { 284 | auto bias = bias_.value(); 285 | TORCH_CHECK(bias.scalar_type() == weight_type); 286 | TORCH_CHECK(bias.is_cuda()); 287 | TORCH_CHECK(bias.stride(-1) == 1); 288 | CHECK_SHAPE(bias, dim); 289 | } 290 | 291 | if (seq_idx_.has_value()) { 292 | TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout"); 293 | auto seq_idx = seq_idx_.value(); 294 | TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); 295 | TORCH_CHECK(seq_idx.is_cuda()); 296 | TORCH_CHECK(seq_idx.is_contiguous()); 297 | CHECK_SHAPE(seq_idx, batch_size, seqlen); 298 | } 299 | 300 | TORCH_CHECK(dx.scalar_type() == input_type); 301 | TORCH_CHECK(dx.is_cuda()); 302 | CHECK_SHAPE(dx, batch_size, dim, seqlen); 303 | if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); } 304 | if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); } 305 | 306 | // Otherwise the kernel will be launched from cuda:0 device 307 | at::cuda::CUDAGuard device_guard{x.device()}; 308 | 309 | ConvParamsBwd params; 310 | set_conv_params_bwd(params, batch_size, dim, seqlen, width, 311 | x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr, 312 | dout, dx, dweight, bias_.has_value() ? 
dbias_.value().data_ptr() : nullptr, 313 | silu_activation); 314 | 315 | if (seq_idx_.has_value()) { 316 | params.seq_idx_ptr = seq_idx_.value().data_ptr(); 317 | } else { 318 | params.seq_idx_ptr = nullptr; 319 | } 320 | 321 | if (initial_states_.has_value()) { 322 | TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); 323 | auto initial_states = initial_states_.value(); 324 | TORCH_CHECK(initial_states.scalar_type() == input_type); 325 | TORCH_CHECK(initial_states.is_cuda()); 326 | CHECK_SHAPE(initial_states, batch_size, dim, width - 1); 327 | TORCH_CHECK(initial_states.stride(1) == 1); 328 | params.initial_states_ptr = initial_states.data_ptr(); 329 | params.initial_states_batch_stride = initial_states.stride(0); 330 | params.initial_states_c_stride = initial_states.stride(1); 331 | params.initial_states_l_stride = initial_states.stride(2); 332 | } else { 333 | params.initial_states_ptr = nullptr; 334 | } 335 | 336 | if (dfinal_states_.has_value()) { 337 | TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout"); 338 | auto dfinal_states = dfinal_states_.value(); 339 | TORCH_CHECK(dfinal_states.scalar_type() == input_type); 340 | TORCH_CHECK(dfinal_states.is_cuda()); 341 | CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1); 342 | params.dfinal_states_ptr = dfinal_states.data_ptr(); 343 | params.dfinal_states_batch_stride = dfinal_states.stride(0); 344 | params.dfinal_states_c_stride = dfinal_states.stride(1); 345 | params.dfinal_states_l_stride = dfinal_states.stride(2); 346 | } else { 347 | params.dfinal_states_ptr = nullptr; 348 | } 349 | 350 | if (dinitial_states_.has_value()) { 351 | at::Tensor dinitial_states = dinitial_states_.value(); 352 | TORCH_CHECK(dinitial_states.stride(1) == 1); 353 | params.dinitial_states_ptr = dinitial_states.data_ptr(); 354 | params.dinitial_states_batch_stride = dinitial_states.stride(0); 355 | params.dinitial_states_c_stride = dinitial_states.stride(1); 356 | params.dinitial_states_l_stride = dinitial_states.stride(2); 357 | } else { 358 | params.dinitial_states_ptr = nullptr; 359 | } 360 | 361 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 362 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] { 363 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] { 364 | if (!is_channel_last) { 365 | causal_conv1d_bwd_cuda(params, stream); 366 | } else { 367 | causal_conv1d_channellast_bwd_cuda(params, stream); 368 | } 369 | }); 370 | }); 371 | } 372 | 373 | void 374 | causal_conv1d_update(const at::Tensor &x, 375 | const at::Tensor &conv_state, 376 | const at::Tensor &weight, 377 | const c10::optional &bias_, 378 | at::Tensor &out, 379 | bool silu_activation, 380 | const c10::optional &cache_seqlens_, 381 | const c10::optional &conv_state_indices_ 382 | ) { 383 | auto input_type = x.scalar_type(); 384 | auto weight_type = weight.scalar_type(); 385 | TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); 386 | TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); 387 | TORCH_CHECK(conv_state.scalar_type() == input_type); 388 | 389 | TORCH_CHECK(x.is_cuda()); 390 | TORCH_CHECK(conv_state.is_cuda()); 391 | TORCH_CHECK(weight.is_cuda()); 392 | 393 | const auto sizes = x.sizes(); 394 | const int batch_size = sizes[0]; 395 | const int dim = sizes[1]; 396 | const int 
seqlen = sizes[2]; 397 | const int width = weight.size(-1); 398 | const int conv_state_len = conv_state.size(2); 399 | TORCH_CHECK(conv_state_len >= width - 1); 400 | 401 | CHECK_SHAPE(x, batch_size, dim, seqlen); 402 | CHECK_SHAPE(weight, dim, width); 403 | 404 | TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); 405 | 406 | if (bias_.has_value()) { 407 | auto bias = bias_.value(); 408 | TORCH_CHECK(bias.scalar_type() == weight_type); 409 | TORCH_CHECK(bias.is_cuda()); 410 | TORCH_CHECK(bias.stride(-1) == 1); 411 | CHECK_SHAPE(bias, dim); 412 | } 413 | 414 | ConvParamsBase params; 415 | set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, 416 | bias_.has_value() ? bias_.value().data_ptr() : nullptr, 417 | silu_activation); 418 | params.conv_state_ptr = conv_state.data_ptr(); 419 | params.conv_state_len = conv_state_len; 420 | // All stride are in elements, not bytes. 421 | params.conv_state_batch_stride = conv_state.stride(0); 422 | params.conv_state_c_stride = conv_state.stride(1); 423 | params.conv_state_l_stride = conv_state.stride(2); 424 | 425 | if (conv_state_indices_.has_value()) { 426 | auto conv_state_indices = conv_state_indices_.value(); 427 | TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) 428 | TORCH_CHECK(conv_state_indices.is_cuda()); 429 | TORCH_CHECK(conv_state_indices.stride(0) == 1) 430 | CHECK_SHAPE(conv_state_indices, batch_size); 431 | 432 | int conv_state_entries = conv_state.size(0); 433 | CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); 434 | 435 | params.conv_state_indices_ptr = conv_state_indices.data_ptr(); 436 | } else { 437 | CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); 438 | params.conv_state_indices_ptr = nullptr; 439 | } 440 | 441 | if (cache_seqlens_.has_value()) { 442 | auto cache_seqlens = cache_seqlens_.value(); 443 | TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); 444 | TORCH_CHECK(cache_seqlens.is_cuda()); 445 | TORCH_CHECK(cache_seqlens.stride(-1) == 1); 446 | CHECK_SHAPE(cache_seqlens, batch_size); 447 | params.cache_seqlens = cache_seqlens.data_ptr(); 448 | } else { 449 | params.cache_seqlens = nullptr; 450 | } 451 | 452 | // Otherwise the kernel will be launched from cuda:0 device 453 | at::cuda::CUDAGuard device_guard{x.device()}; 454 | auto stream = at::cuda::getCurrentCUDAStream().stream(); 455 | DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { 456 | DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] { 457 | causal_conv1d_update_cuda(params, stream); 458 | }); 459 | }); 460 | } 461 | 462 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 463 | m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward"); 464 | m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward"); 465 | m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update"); 466 | } 467 | -------------------------------------------------------------------------------- /csrc/causal_conv1d.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | //////////////////////////////////////////////////////////////////////////////////////////////////// 8 | 9 | struct ConvParamsBase { 10 | using index_t = uint32_t; 11 | 12 | int batch, dim, seqlen, width; 13 | bool silu_activation; 14 | 15 | index_t x_batch_stride; 16 | index_t x_c_stride; 17 | index_t x_l_stride; 18 | index_t weight_c_stride; 19 | index_t weight_width_stride; 20 | index_t out_batch_stride; 21 | index_t out_c_stride; 22 | index_t out_l_stride; 23 | 24 | int conv_state_len; 25 | index_t conv_state_batch_stride; 26 | index_t conv_state_c_stride; 27 | index_t conv_state_l_stride; 28 | 29 | // Common data pointers. 30 | void *__restrict__ x_ptr; 31 | void *__restrict__ weight_ptr; 32 | void *__restrict__ bias_ptr; 33 | void *__restrict__ out_ptr; 34 | 35 | void *__restrict__ conv_state_ptr; 36 | int32_t *__restrict__ cache_seqlens; 37 | 38 | // Only used if the elements of the batch are gathered from a larger buffer, 39 | // which may happen for continuous batching. 40 | int32_t *__restrict__ conv_state_indices_ptr; 41 | 42 | void *__restrict__ seq_idx_ptr; 43 | 44 | // No __restrict__ since initial_states could be the same as final_states. 45 | void * initial_states_ptr; 46 | index_t initial_states_batch_stride; 47 | index_t initial_states_l_stride; 48 | index_t initial_states_c_stride; 49 | 50 | void * final_states_ptr; 51 | index_t final_states_batch_stride; 52 | index_t final_states_l_stride; 53 | index_t final_states_c_stride; 54 | }; 55 | 56 | struct ConvParamsBwd: public ConvParamsBase { 57 | index_t dx_batch_stride; 58 | index_t dx_c_stride; 59 | index_t dx_l_stride; 60 | index_t dweight_c_stride; 61 | index_t dweight_width_stride; 62 | index_t dout_batch_stride; 63 | index_t dout_c_stride; 64 | index_t dout_l_stride; 65 | 66 | // Common data pointers. 67 | void *__restrict__ dx_ptr; 68 | void *__restrict__ dweight_ptr; 69 | void *__restrict__ dbias_ptr; 70 | void *__restrict__ dout_ptr; 71 | 72 | void * dinitial_states_ptr; 73 | index_t dinitial_states_batch_stride; 74 | index_t dinitial_states_l_stride; 75 | index_t dinitial_states_c_stride; 76 | 77 | void * dfinal_states_ptr; 78 | index_t dfinal_states_batch_stride; 79 | index_t dfinal_states_l_stride; 80 | index_t dfinal_states_c_stride; 81 | }; 82 | -------------------------------------------------------------------------------- /csrc/causal_conv1d_bwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #ifndef USE_ROCM 10 | #include 11 | #include 12 | #include 13 | #else 14 | #include 15 | namespace cub = hipcub; 16 | #endif 17 | 18 | #include "causal_conv1d.h" 19 | #include "causal_conv1d_common.h" 20 | #include "static_switch.h" 21 | 22 | template 23 | struct Causal_conv1d_bwd_kernel_traits { 24 | using input_t = input_t_; 25 | using weight_t = weight_t_; 26 | static constexpr int kNThreads = kNThreads_; 27 | static constexpr int kWidth = kWidth_; 28 | static constexpr bool kSiluAct = kSiluAct_; 29 | static constexpr int kNBytes = sizeof(input_t); 30 | static_assert(kNBytes == 2 || kNBytes == 4); 31 | static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; 32 | static_assert(kWidth <= kNElts); 33 | // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits 34 | // (since then we'd have 8 values of float, and each round we can exchange 4 floats). 35 | static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t); 36 | static constexpr bool kIsVecLoad = kIsVecLoad_; 37 | using vec_t = typename BytesToType::Type; 38 | using BlockLoadT = cub::BlockLoad; 39 | using BlockLoadVecT = cub::BlockLoad; 40 | using BlockStoreT = cub::BlockStore; 41 | using BlockStoreVecT = cub::BlockStore; 42 | using BlockReduceFloatT = cub::BlockReduce; 43 | static constexpr int kSmemIOSize = kIsVecLoad 44 | ? 0 45 | : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); 46 | static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1); 47 | static constexpr int kSmemSize = custom_max({kSmemExchangeSize, 48 | int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize); 49 | }; 50 | 51 | template 52 | __global__ __launch_bounds__(Ktraits::kNThreads) 53 | void causal_conv1d_bwd_kernel(ConvParamsBwd params) { 54 | constexpr int kWidth = Ktraits::kWidth; 55 | constexpr int kNThreads = Ktraits::kNThreads; 56 | constexpr bool kSiluAct = Ktraits::kSiluAct; 57 | static constexpr int kNElts = Ktraits::kNElts; 58 | constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds; 59 | static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; 60 | using input_t = typename Ktraits::input_t; 61 | using vec_t = typename Ktraits::vec_t; 62 | using weight_t = typename Ktraits::weight_t; 63 | 64 | // Shared memory. 65 | extern __shared__ char smem_[]; 66 | auto& smem_load = reinterpret_cast(smem_); 67 | auto& smem_load_vec = reinterpret_cast(smem_); 68 | auto& smem_store = reinterpret_cast(smem_); 69 | auto& smem_store_vec = reinterpret_cast(smem_); 70 | vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 71 | vec_t *smem_exchange_x = reinterpret_cast(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds; 72 | auto& smem_reduce_float = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 73 | 74 | const int tidx = threadIdx.x; 75 | const int batch_id = blockIdx.x; 76 | const int dim_id = blockIdx.y; 77 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 78 | + dim_id * params.x_c_stride; 79 | weight_t *weight = reinterpret_cast(params.weight_ptr) + dim_id * params.weight_c_stride; 80 | input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride 81 | + dim_id * params.dout_c_stride; 82 | input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride 83 | + dim_id * params.dx_c_stride; 84 | float *dweight = reinterpret_cast(params.dweight_ptr) + dim_id * params.dweight_c_stride; 85 | float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[dim_id]); 86 | 87 | // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0. 
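// Layout of the exchange buffers (see kSmemExchangeSize above): without SiLU, smem_exchange holds
// one vec_t per thread; with SiLU it holds kNExchangeRounds vec_t per thread, because dout is
// exchanged as floats, and smem_exchange_x holds one further vec_t per thread for the x values.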
88 | if (tidx == 0) { 89 | if constexpr (!kSiluAct) { 90 | input_t zeros[kNElts] = {0}; 91 | smem_exchange[0] = reinterpret_cast(zeros)[0]; 92 | } else { 93 | float zeros[kNElts] = {0}; 94 | #pragma unroll 95 | for (int r = 0; r < kNExchangeRounds; ++r) { 96 | smem_exchange[r * kNThreads] = reinterpret_cast(zeros)[r]; 97 | } 98 | } 99 | } 100 | 101 | float weight_vals[kWidth]; 102 | #pragma unroll 103 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; } 104 | 105 | float dweight_vals[kWidth] = {0}; 106 | float dbias_val = 0; 107 | 108 | constexpr int kChunkSize = kNThreads * kNElts; 109 | const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; 110 | x += (n_chunks - 1) * kChunkSize; 111 | dout += (n_chunks - 1) * kChunkSize; 112 | dx += (n_chunks - 1) * kChunkSize; 113 | for (int chunk = n_chunks - 1; chunk >= 0; --chunk) { 114 | input_t x_vals_load[2 * kNElts] = {0}; 115 | input_t dout_vals_load[2 * kNElts] = {0}; 116 | if constexpr(kIsVecLoad) { 117 | typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); 118 | typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(dout), *reinterpret_cast(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts); 119 | } else { 120 | __syncthreads(); 121 | typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); 122 | __syncthreads(); 123 | typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize); 124 | } 125 | float dout_vals[2 * kNElts], x_vals[2 * kNElts]; 126 | if constexpr (!kSiluAct) { 127 | __syncthreads(); 128 | // Thread 0 don't write yet, so that thread kNThreads - 1 can read 129 | // the first elements of the next chunk. 130 | if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } 131 | __syncthreads(); 132 | reinterpret_cast(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0]; 133 | __syncthreads(); 134 | // Now thread 0 can write the first elements of the current chunk. 
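// Orientation sketch (editorial, not upstream code): the x gradient needs "future" dout
// values,
//
//     dx[t] = sum_{w=0}^{kWidth-1} weight[w] * dout[t + (kWidth - 1) - w],
//
// so after the exchange above each thread holds its own kNElts of dout in
// dout_vals_load[0:kNElts) and the next thread's kNElts in dout_vals_load[kNElts:2*kNElts).
// Thread kNThreads - 1 reads slot 0 instead, which still holds the first elements of the
// chunk processed on the previous iteration -- the next chunk in sequence order, since the
// chunk loop walks the sequence from the end backwards (slot 0 starts out zeroed for the
// very last chunk).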
135 | if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast(dout_vals_load)[0]; } 136 | #pragma unroll 137 | for (int i = 0; i < 2 * kNElts; ++i) { 138 | dout_vals[i] = float(dout_vals_load[i]); 139 | x_vals[i] = float(x_vals_load[i]); 140 | } 141 | } else { 142 | if (tidx == 0 && chunk > 0) { 143 | if constexpr(kIsVecLoad) { 144 | reinterpret_cast(x_vals_load)[0] = reinterpret_cast(x)[-1]; 145 | } else { 146 | #pragma unroll 147 | for (int i = 0; i < kNElts; ++i) { 148 | if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; } 149 | } 150 | } 151 | } 152 | __syncthreads(); 153 | smem_exchange_x[tidx] = reinterpret_cast(x_vals_load)[1]; 154 | __syncthreads(); 155 | if (tidx > 0) { reinterpret_cast(x_vals_load)[0] = smem_exchange_x[tidx - 1]; } 156 | #pragma unroll 157 | for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } 158 | // Recompute the output 159 | #pragma unroll 160 | for (int i = 0; i < kNElts; ++i) { 161 | float out_val = bias_val; 162 | #pragma unroll 163 | for (int w = 0; w < kWidth; ++w) { 164 | out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; 165 | } 166 | float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val)); 167 | dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val 168 | * (1.0f + out_val * (1.0f - out_sigmoid_val)); 169 | } 170 | // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange 171 | // if input_t is 16 bits (since then we'd have 8 values of float) 172 | __syncthreads(); 173 | // Thread 0 don't write yet, so that thread kNThreads - 1 can read 174 | // the first elements of the next chunk. 175 | if (tidx > 0) { 176 | #pragma unroll 177 | for (int r = 0; r < kNExchangeRounds; ++r) { 178 | smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; 179 | } 180 | } 181 | __syncthreads(); 182 | #pragma unroll 183 | for (int r = 0; r < kNExchangeRounds; ++r) { 184 | reinterpret_cast(dout_vals)[kNExchangeRounds + r] 185 | = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)]; 186 | } 187 | __syncthreads(); 188 | // Now thread 0 can write the first elements of the current chunk. 
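// Derivation note (editorial sketch): the scaling applied to dout_vals in the recompute a
// few lines above is the SiLU derivative evaluated at the recomputed pre-activation
// z = conv(x) + bias:
//
//     silu(z)  = z * sigmoid(z)
//     silu'(z) = sigmoid(z) + z * sigmoid(z) * (1 - sigmoid(z))
//              = sigmoid(z) * (1 + z * (1 - sigmoid(z)))
//
// which is exactly out_sigmoid_val * (1.0f + out_val * (1.0f - out_sigmoid_val)).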
189 | if (tidx == 0) { 190 | #pragma unroll 191 | for (int r = 0; r < kNExchangeRounds; ++r) { 192 | smem_exchange[r * kNThreads + tidx] = reinterpret_cast(dout_vals)[r]; 193 | } 194 | } 195 | } 196 | dout -= kChunkSize; 197 | x -= kChunkSize; 198 | 199 | #pragma unroll 200 | for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; } 201 | 202 | float dx_vals[kNElts] = {0}; 203 | #pragma unroll 204 | for (int i = 0; i < kNElts; ++i) { 205 | #pragma unroll 206 | for (int w = 0; w < kWidth; ++w) { 207 | dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1]; 208 | } 209 | } 210 | 211 | input_t dx_vals_store[kNElts]; 212 | #pragma unroll 213 | for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; } 214 | if constexpr(kIsVecLoad) { 215 | typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(dx), reinterpret_cast(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); 216 | } else { 217 | typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize); 218 | } 219 | dx -= kChunkSize; 220 | 221 | #pragma unroll 222 | for (int w = 0; w < kWidth; ++w) { 223 | #pragma unroll 224 | for (int i = 0; i < kNElts; ++i) { 225 | dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1]; 226 | } 227 | } 228 | } 229 | 230 | #pragma unroll 231 | for (int w = 0; w < kWidth; ++w) { 232 | __syncthreads(); 233 | dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]); 234 | if (tidx == 0) { 235 | atomicAdd(&reinterpret_cast(dweight)[w * params.dweight_width_stride], dweight_vals[w]); 236 | } 237 | } 238 | if (params.bias_ptr != nullptr) { 239 | __syncthreads(); 240 | dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val); 241 | if (tidx == 0) { 242 | atomicAdd(&reinterpret_cast(params.dbias_ptr)[dim_id], dbias_val); 243 | } 244 | } 245 | } 246 | 247 | template 248 | void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { 249 | static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; 250 | BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { 251 | BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { 252 | using Ktraits = Causal_conv1d_bwd_kernel_traits; 253 | constexpr int kSmemSize = Ktraits::kSmemSize; 254 | dim3 grid(params.batch, params.dim); 255 | auto kernel = &causal_conv1d_bwd_kernel; 256 | 257 | if (kSmemSize >= 48 * 1024) { 258 | #ifndef USE_ROCM 259 | C10_CUDA_CHECK(cudaFuncSetAttribute( 260 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 261 | #else 262 | // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. 263 | C10_CUDA_CHECK(cudaFuncSetAttribute( 264 | (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 265 | std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. 
\n" << std::endl; 266 | #endif 267 | } 268 | 269 | 270 | kernel<<>>(params); 271 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 272 | }); 273 | }); 274 | } 275 | 276 | template 277 | void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { 278 | if (params.width == 2) { 279 | causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream); 280 | } else if (params.width == 3) { 281 | causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream); 282 | } else if (params.width == 4) { 283 | causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream); 284 | } 285 | } 286 | 287 | template 288 | struct Causal_conv1d_channellast_bwd_kernel_traits { 289 | // The cache line is 128 bytes, and we try to read 16 bytes per thread. 290 | // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. 291 | // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 292 | // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. 293 | using input_t = input_t_; 294 | using weight_t = weight_t_; 295 | static constexpr bool kSiluAct = kSiluAct_; 296 | static constexpr int kNThreads = kNThreads_; 297 | static_assert(kNThreads % 32 == 0); 298 | static constexpr int kNWarps = kNThreads / 32; 299 | static constexpr int kWidth = kWidth_; 300 | static constexpr int kChunkSizeL = kChunkSizeL_; 301 | static constexpr int kNBytes = sizeof(input_t); 302 | static_assert(kNBytes == 2 || kNBytes == 4); 303 | static constexpr int kNElts = kNBytes == 4 ? 4 : 8; 304 | static constexpr int kNEltsPerRow = 128 / kNBytes; 305 | static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now 306 | static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); 307 | static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now 308 | static_assert(kNColsPerWarp * kNThreadsPerRow == 32); 309 | static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; 310 | static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; 311 | static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); 312 | static constexpr bool kIsVecLoad = kIsVecLoad_; 313 | using vec_t = typename BytesToType::Type; 314 | // using BlockLoadT = cub::BlockLoad; 315 | // using BlockStoreT = cub::BlockStore; 316 | // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), 317 | // sizeof(typename BlockStoreT::TempStorage)}); 318 | // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; 319 | }; 320 | 321 | template 322 | __global__ __launch_bounds__(Ktraits::kNThreads) 323 | void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) { 324 | constexpr int kWidth = Ktraits::kWidth; 325 | constexpr int kNThreads = Ktraits::kNThreads; 326 | constexpr bool kSiluAct = Ktraits::kSiluAct; 327 | constexpr int kNElts = Ktraits::kNElts; 328 | constexpr int kNWarp = Ktraits::kNWarps; 329 | constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; 330 | constexpr int kLPerLoad = Ktraits::kNColsPerLoad; 331 | constexpr int kChunkSizeL = Ktraits::kChunkSizeL; 332 | constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; 333 | using input_t = typename Ktraits::input_t; 334 | using vec_t = typename Ktraits::vec_t; 335 | using weight_t = typename Ktraits::weight_t; 336 | 337 | // Shared memory. 
338 | __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; 339 | __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts]; 340 | 341 | const int batch_id = blockIdx.x; 342 | const int chunk_l_id = blockIdx.y; 343 | const int chunk_c_id = blockIdx.z; 344 | const int tid = threadIdx.x; 345 | const int l_idx = tid / kNThreadsPerC; 346 | const int c_idx = tid % kNThreadsPerC; 347 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 348 | + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 349 | weight_t *weight = reinterpret_cast(params.weight_ptr) 350 | + chunk_c_id * kChunkSizeC * params.weight_c_stride; 351 | input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride 352 | + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 353 | input_t *dx = reinterpret_cast(params.dx_ptr) + batch_id * params.dx_batch_stride 354 | + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 355 | float *dweight = reinterpret_cast(params.dweight_ptr) 356 | + chunk_c_id * kChunkSizeC * params.dweight_c_stride; 357 | int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) 358 | + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; 359 | input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr 360 | : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 361 | input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr 362 | : reinterpret_cast(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 363 | input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr 364 | : reinterpret_cast(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC; 365 | 366 | #pragma unroll 367 | for (int l = 0; l < Ktraits::kNLoads; ++l) { 368 | input_t dout_vals_load[kNElts] = {0}; 369 | input_t x_vals_load[kNElts] = {0}; 370 | if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen 371 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 372 | reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + l * kLPerLoad * params.dout_l_stride); 373 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); 374 | } 375 | reinterpret_cast(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; 376 | reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; 377 | } 378 | // Load the elements from the previous chunk or next chunk that are needed for convolution. 
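// Editorial note on the boundary handling below: only threads with l_idx < kWidth - 1
// participate. They fetch the next chunk's first rows of dout and the previous chunk's last
// rows of x; for the first chunk (chunk_l_id == 0) the x halo instead comes from
// initial_states when it was provided, and otherwise stays at the zero the local buffers
// were initialized with -- i.e. implicit zero padding on the left of the sequence.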
379 | if (l_idx < kWidth - 1) { 380 | input_t dout_vals_load[kNElts] = {0}; 381 | input_t x_vals_load[kNElts] = {0}; 382 | if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen 383 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 384 | reinterpret_cast(dout_vals_load)[0] = *reinterpret_cast(dout + kChunkSizeL * params.dout_l_stride); 385 | } 386 | if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 387 | && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen 388 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 389 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); 390 | } else if (initial_states != nullptr 391 | && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 392 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 393 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); 394 | } 395 | reinterpret_cast(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(dout_vals_load)[0]; 396 | reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; 397 | } 398 | // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs 399 | if constexpr (kSiluAct) { 400 | if (l_idx < kWidth - 1) { 401 | input_t x_vals_load[kNElts] = {0}; 402 | if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen 403 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 404 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + kChunkSizeL * params.x_l_stride); 405 | } 406 | reinterpret_cast(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; 407 | } 408 | } 409 | 410 | __syncthreads(); 411 | 412 | constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); 413 | static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); 414 | constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; 415 | static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); 416 | // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity 417 | static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); 418 | static_assert((kLPerThread & (kLPerThread - 1)) == 0); 419 | static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); 420 | static_assert(kNThreadsPerRow <= 32); 421 | 422 | const int row_idx = tid / kNThreadsPerRow; 423 | const int col_idx = tid % kNThreadsPerRow; 424 | 425 | float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ?
0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); 426 | float weight_vals[kWidth] = {0}; 427 | if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { 428 | #pragma unroll 429 | for (int w = 0; w < kWidth; ++w) { 430 | weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; 431 | } 432 | } 433 | float dout_vals[kLPerThread + kWidth - 1]; 434 | float x_vals[kWidth - 1 + kLPerThread + kWidth - 1]; 435 | #pragma unroll 436 | for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { 437 | dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]); 438 | x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); 439 | } 440 | 441 | int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1]; 442 | if constexpr (kHasSeqIdx) { 443 | #pragma unroll 444 | for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { 445 | const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1); 446 | seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; 447 | } 448 | } 449 | 450 | if constexpr (kSiluAct) { // Recompute the output 451 | #pragma unroll 452 | for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) { 453 | x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); 454 | } 455 | #pragma unroll 456 | for (int i = 0; i < kLPerThread + kWidth - 1; ++i) { 457 | float out_val = bias_val; 458 | const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; 459 | #pragma unroll 460 | for (int w = 0; w < kWidth; ++w) { 461 | if constexpr (!kHasSeqIdx) { 462 | out_val += weight_vals[w] * x_vals[i + w]; 463 | } else { 464 | out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; 465 | } 466 | } 467 | float out_val_sigmoid = 1.f / (1.f + expf(-out_val)); 468 | dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid)); 469 | } 470 | } 471 | 472 | float dweight_vals[kWidth] = {0}; 473 | SumOp sum_op; 474 | #pragma unroll 475 | for (int w = 0; w < kWidth; ++w) { 476 | #pragma unroll 477 | for (int i = 0; i < kLPerThread; ++i) { 478 | if constexpr (!kHasSeqIdx) { 479 | dweight_vals[w] += x_vals[i + w] * dout_vals[i]; 480 | } else { 481 | dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f; 482 | } 483 | } 484 | dweight_vals[w] = Allreduce::run(dweight_vals[w], sum_op); 485 | if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { 486 | atomicAdd(&reinterpret_cast(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]); 487 | } 488 | } 489 | 490 | if (params.bias_ptr != nullptr) { 491 | float dbias_val = 0.f; 492 | for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; } 493 | dbias_val = Allreduce::run(dbias_val, sum_op); 494 | if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) { 495 | atomicAdd(&reinterpret_cast(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val); 496 | } 497 | } 498 | 499 | float dx_vals[kLPerThread] = {0}; 500 | #pragma unroll 501 | for (int i = 0; i < kLPerThread; ++i) { 502 | const int seq_idx_cur = !kHasSeqIdx ? 
0 : seq_idx_thread[i + kWidth - 1]; 503 | #pragma unroll 504 | for (int w = 0; w < kWidth; ++w) { 505 | if constexpr (!kHasSeqIdx) { 506 | dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w]; 507 | } else { 508 | dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f; 509 | } 510 | } 511 | // if (dfinal_states != nullptr) { 512 | if constexpr (kHasDfinalStates) { 513 | if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1 514 | && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen 515 | && chunk_c_id * kChunkSizeC + row_idx < params.dim) { 516 | dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); 517 | } 518 | } 519 | } 520 | 521 | float dxinit_vals[kWidth - 1] = {0}; 522 | static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states 523 | if (dinitial_states != nullptr && col_idx == 0) { 524 | #pragma unroll 525 | for (int i = 0; i < kWidth - 1; ++i) { 526 | #pragma unroll 527 | for (int w = 0; w < kWidth; ++w) { 528 | dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f; 529 | } 530 | // chunk_l_id must be 0 because dinitial_states != nullptr 531 | // if (dfinal_states != nullptr) { 532 | if constexpr (kHasDfinalStates) { 533 | if (i >= params.seqlen) { 534 | dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]); 535 | } 536 | } 537 | } 538 | } 539 | 540 | __syncthreads(); 541 | #pragma unroll 542 | for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; } 543 | if (dinitial_states != nullptr && col_idx == 0) { 544 | #pragma unroll 545 | for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; } 546 | } 547 | __syncthreads(); 548 | 549 | #pragma unroll 550 | for (int l = 0; l < Ktraits::kNLoads; ++l) { 551 | input_t dx_vals_store[kNElts]; 552 | reinterpret_cast(dx_vals_store)[0] = reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx]; 553 | if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen 554 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 555 | *reinterpret_cast(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast(dx_vals_store)[0]; 556 | } 557 | } 558 | if (dinitial_states != nullptr 559 | && l_idx < kWidth - 1 560 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 561 | input_t dxinit_vals_store[kNElts]; 562 | reinterpret_cast(dxinit_vals_store)[0] = reinterpret_cast(x_smem[l_idx])[c_idx]; 563 | *reinterpret_cast(dinitial_states) = reinterpret_cast(dxinit_vals_store)[0]; 564 | } 565 | 566 | } 567 | 568 | template 569 | void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) { 570 | BOOL_SWITCH(params.silu_activation, kSiluAct, [&] { 571 | BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { 572 | BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] { 573 | BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] { 574 | // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger 575 | static constexpr int kChunk = kChunkSizeL64 ? 
64 : 128; 576 | using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits; 577 | // constexpr int kSmemSize = Ktraits::kSmemSize; 578 | constexpr int kChunkSizeL = Ktraits::kChunkSizeL; 579 | constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; 580 | const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; 581 | const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; 582 | dim3 grid(params.batch, n_chunks_L, n_chunks_C); 583 | dim3 block(Ktraits::kNThreads); 584 | auto kernel = &causal_conv1d_channellast_bwd_kernel; 585 | // if (kSmemSize >= 48 * 1024) { 586 | // C10_CUDA_CHECK(cudaFuncSetAttribute( 587 | // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 588 | // } 589 | // kernel<<>>(params); 590 | kernel<<>>(params); 591 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 592 | }); 593 | }); 594 | }); 595 | }); 596 | } 597 | 598 | template 599 | void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) { 600 | if (params.width == 2) { 601 | causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream); 602 | } else if (params.width == 3) { 603 | causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream); 604 | } else if (params.width == 4) { 605 | causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream); 606 | } 607 | } 608 | 609 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 610 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 611 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 612 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 613 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 614 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 615 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 616 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 617 | template void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 618 | 619 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 620 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 621 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 622 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 623 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 624 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 625 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 626 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); 627 | template void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /csrc/causal_conv1d_common.h: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 
3 | ******************************************************************************/ 4 | 5 | #pragma once 6 | 7 | #ifndef USE_ROCM 8 | #include 9 | 10 | template 11 | __device__ inline T shuffle_xor(T val, int offset) { 12 | return __shfl_xor_sync(uint32_t(-1), val, offset); 13 | } 14 | 15 | constexpr size_t custom_max(std::initializer_list ilist) 16 | { 17 | return std::max(ilist); 18 | } 19 | 20 | template 21 | constexpr T constexpr_min(T a, T b) { 22 | return std::min(a, b); 23 | } 24 | 25 | #else 26 | #include 27 | 28 | template 29 | __device__ inline T shuffle_xor(T val, int offset) { 30 | return __shfl_xor(val, offset); 31 | } 32 | constexpr size_t custom_max(std::initializer_list ilist) 33 | { 34 | return *std::max_element(ilist.begin(), ilist.end()); 35 | } 36 | 37 | template 38 | constexpr T constexpr_min(T a, T b) { 39 | return a < b ? a : b; 40 | } 41 | #endif 42 | #include 43 | 44 | //////////////////////////////////////////////////////////////////////////////////////////////////// 45 | 46 | template struct BytesToType {}; 47 | 48 | template<> struct BytesToType<16> { 49 | using Type = uint4; 50 | static_assert(sizeof(Type) == 16); 51 | }; 52 | 53 | template<> struct BytesToType<8> { 54 | using Type = uint64_t; 55 | static_assert(sizeof(Type) == 8); 56 | }; 57 | 58 | template<> struct BytesToType<4> { 59 | using Type = uint32_t; 60 | static_assert(sizeof(Type) == 4); 61 | }; 62 | 63 | template<> struct BytesToType<2> { 64 | using Type = uint16_t; 65 | static_assert(sizeof(Type) == 2); 66 | }; 67 | 68 | template<> struct BytesToType<1> { 69 | using Type = uint8_t; 70 | static_assert(sizeof(Type) == 1); 71 | }; 72 | 73 | //////////////////////////////////////////////////////////////////////////////////////////////////// 74 | 75 | template 76 | struct SumOp { 77 | __device__ inline T operator()(T const & x, T const & y) { return x + y; } 78 | }; 79 | 80 | template 81 | struct Allreduce { 82 | static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); 83 | template 84 | static __device__ inline T run(T x, Operator &op) { 85 | constexpr int OFFSET = THREADS / 2; 86 | x = op(x, shuffle_xor(x, OFFSET)); 87 | return Allreduce::run(x, op); 88 | } 89 | }; 90 | 91 | template<> 92 | struct Allreduce<2> { 93 | template 94 | static __device__ inline T run(T x, Operator &op) { 95 | x = op(x, shuffle_xor(x, 1)); 96 | return x; 97 | } 98 | }; 99 | -------------------------------------------------------------------------------- /csrc/causal_conv1d_fwd.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2024, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #ifndef USE_ROCM 10 | #include 11 | #include 12 | #else 13 | #include 14 | namespace cub = hipcub; 15 | #endif 16 | 17 | #include "causal_conv1d.h" 18 | #include "causal_conv1d_common.h" 19 | #include "static_switch.h" 20 | 21 | template 22 | struct Causal_conv1d_fwd_kernel_traits { 23 | using input_t = input_t_; 24 | using weight_t = weight_t_; 25 | static constexpr int kNThreads = kNThreads_; 26 | static constexpr int kWidth = kWidth_; 27 | static constexpr int kNBytes = sizeof(input_t); 28 | static_assert(kNBytes == 2 || kNBytes == 4); 29 | static constexpr int kNElts = kNBytes == 4 ? 
4 : 8; 30 | static_assert(kWidth <= kNElts); 31 | static constexpr bool kIsVecLoad = kIsVecLoad_; 32 | using vec_t = typename BytesToType::Type; 33 | using BlockLoadT = cub::BlockLoad; 34 | using BlockLoadVecT = cub::BlockLoad; 35 | using BlockStoreT = cub::BlockStore; 36 | using BlockStoreVecT = cub::BlockStore; 37 | static constexpr int kSmemIOSize = kIsVecLoad 38 | ? 0 39 | : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); 40 | static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; 41 | static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; 42 | }; 43 | 44 | template 45 | __global__ __launch_bounds__(Ktraits::kNThreads) 46 | void causal_conv1d_fwd_kernel(ConvParamsBase params) { 47 | constexpr int kWidth = Ktraits::kWidth; 48 | constexpr int kNThreads = Ktraits::kNThreads; 49 | constexpr int kNElts = Ktraits::kNElts; 50 | static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; 51 | using input_t = typename Ktraits::input_t; 52 | using vec_t = typename Ktraits::vec_t; 53 | using weight_t = typename Ktraits::weight_t; 54 | 55 | // Shared memory. 56 | extern __shared__ char smem_[]; 57 | auto& smem_load = reinterpret_cast(smem_); 58 | auto& smem_load_vec = reinterpret_cast(smem_); 59 | auto& smem_store = reinterpret_cast(smem_); 60 | auto& smem_store_vec = reinterpret_cast(smem_); 61 | vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); 62 | 63 | const int tidx = threadIdx.x; 64 | const int batch_id = blockIdx.x; 65 | const int channel_id = blockIdx.y; 66 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 67 | + channel_id * params.x_c_stride; 68 | weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; 69 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 70 | + channel_id * params.out_c_stride; 71 | float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); 72 | 73 | // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 74 | if (tidx == 0) { 75 | input_t zeros[kNElts] = {0}; 76 | smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; 77 | } 78 | 79 | float weight_vals[kWidth]; 80 | #pragma unroll 81 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 82 | 83 | constexpr int kChunkSize = kNThreads * kNElts; 84 | const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; 85 | for (int chunk = 0; chunk < n_chunks; ++chunk) { 86 | input_t x_vals_load[2 * kNElts] = {0}; 87 | if constexpr(kIsVecLoad) { 88 | typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); 89 | } else { 90 | __syncthreads(); 91 | typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); 92 | } 93 | x += kChunkSize; 94 | __syncthreads(); 95 | // Thread kNThreads - 1 don't write yet, so that thread 0 can read 96 | // the last elements of the previous chunk. 97 | if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } 98 | __syncthreads(); 99 | reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; 100 | __syncthreads(); 101 | // Now thread kNThreads - 1 can write the last elements of the current chunk. 
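// Reference sketch of what one (batch, channel) row of this kernel computes, in plain C++
// (editorial and illustrative only -- `width`, `bias`, and `silu_activation` mirror the
// params fields, and the zero left-padding reflects the zero-initialized exchange slot):
//
//   for (int t = 0; t < seqlen; ++t) {
//       float acc = bias;
//       for (int w = 0; w < width; ++w) {
//           int s = t - (width - 1) + w;              // causal: current and past samples only
//           acc += weight[w] * (s >= 0 ? float(x[s]) : 0.f);
//       }
//       out[t] = silu_activation ? acc / (1.f + expf(-acc)) : acc;
//   }
//
// The exchange above exists because each thread's first outputs need the previous thread's
// last (kWidth - 1) inputs; thread 0 reads slot kNThreads - 1, which holds the previous
// chunk's tail (zeros before the first chunk).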
102 | if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } 103 | 104 | float x_vals[2 * kNElts]; 105 | #pragma unroll 106 | for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } 107 | 108 | float out_vals[kNElts]; 109 | #pragma unroll 110 | for (int i = 0; i < kNElts; ++i) { 111 | out_vals[i] = bias_val; 112 | #pragma unroll 113 | for (int w = 0; w < kWidth; ++w) { 114 | out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; 115 | } 116 | } 117 | 118 | if (params.silu_activation) { 119 | #pragma unroll 120 | for (int i = 0; i < kNElts; ++i) { 121 | out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); 122 | } 123 | } 124 | 125 | input_t out_vals_store[kNElts]; 126 | #pragma unroll 127 | for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } 128 | if constexpr(kIsVecLoad) { 129 | typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); 130 | } else { 131 | typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); 132 | } 133 | out += kChunkSize; 134 | } 135 | } 136 | 137 | template 138 | void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) { 139 | static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; 140 | BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { 141 | using Ktraits = Causal_conv1d_fwd_kernel_traits; 142 | constexpr int kSmemSize = Ktraits::kSmemSize; 143 | dim3 grid(params.batch, params.dim); 144 | 145 | auto kernel = &causal_conv1d_fwd_kernel; 146 | 147 | if (kSmemSize >= 48 * 1024) { 148 | #ifndef USE_ROCM 149 | C10_CUDA_CHECK(cudaFuncSetAttribute( 150 | kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 151 | #else 152 | // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. 153 | C10_CUDA_CHECK(cudaFuncSetAttribute( 154 | (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 155 | std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; 156 | #endif 157 | } 158 | kernel<<>>(params); 159 | 160 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 161 | }); 162 | } 163 | 164 | template 165 | void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) { 166 | if (params.width == 2) { 167 | causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); 168 | } else if (params.width == 3) { 169 | causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); 170 | } else if (params.width == 4) { 171 | causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); 172 | } 173 | } 174 | 175 | template 176 | struct Causal_conv1d_channellast_fwd_kernel_traits { 177 | // The cache line is 128 bytes, and we try to read 16 bytes per thread. 178 | // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. 179 | // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 180 | // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
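// Worked example of the constants below (editorial; assumes kNThreads = 128, matching the
// launches in this file, and kChunkSizeL = 64 for concreteness):
//   fp16/bf16: kNBytes = 2 -> kNElts = 8, kNEltsPerRow = 64, kNThreadsPerRow = 8,
//              kNWarps = 4, kNColsPerWarp = 4, kNColsPerLoad = 16, kNLoads = 64 / 16 = 4
//   fp32:      kNBytes = 4 -> kNElts = 4, kNEltsPerRow = 32, kNThreadsPerRow = 8,
//              kNWarps = 4, kNColsPerWarp = 4, kNColsPerLoad = 16, kNLoads = 4
// so each vectorized load covers a 16 (L) x 32|64 (C) tile, and four such loads cover one
// L-chunk, consistent with the comment above.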
181 | using input_t = input_t_; 182 | using weight_t = weight_t_; 183 | static constexpr int kNThreads = kNThreads_; 184 | static_assert(kNThreads % 32 == 0); 185 | static constexpr int kNWarps = kNThreads / 32; 186 | static constexpr int kWidth = kWidth_; 187 | static constexpr int kChunkSizeL = kChunkSizeL_; 188 | static constexpr int kNBytes = sizeof(input_t); 189 | static_assert(kNBytes == 2 || kNBytes == 4); 190 | static constexpr int kNElts = kNBytes == 4 ? 4 : 8; 191 | static constexpr int kNEltsPerRow = 128 / kNBytes; 192 | static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now 193 | static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); 194 | static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now 195 | static_assert(kNColsPerWarp * kNThreadsPerRow == 32); 196 | static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; 197 | static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; 198 | static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); 199 | static constexpr bool kIsVecLoad = kIsVecLoad_; 200 | using vec_t = typename BytesToType::Type; 201 | // using BlockLoadT = cub::BlockLoad; 202 | // using BlockStoreT = cub::BlockStore; 203 | // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), 204 | // sizeof(typename BlockStoreT::TempStorage)}); 205 | // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; 206 | }; 207 | 208 | template 209 | __global__ __launch_bounds__(Ktraits::kNThreads) 210 | void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { 211 | constexpr int kWidth = Ktraits::kWidth; 212 | constexpr int kNThreads = Ktraits::kNThreads; 213 | constexpr int kNElts = Ktraits::kNElts; 214 | constexpr int kNWarp = Ktraits::kNWarps; 215 | constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; 216 | constexpr int kLPerLoad = Ktraits::kNColsPerLoad; 217 | constexpr int kChunkSizeL = Ktraits::kChunkSizeL; 218 | constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; 219 | using input_t = typename Ktraits::input_t; 220 | using vec_t = typename Ktraits::vec_t; 221 | using weight_t = typename Ktraits::weight_t; 222 | 223 | // Shared memory. 224 | __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; 225 | 226 | const int batch_id = blockIdx.x; 227 | const int chunk_l_id = blockIdx.y; 228 | const int chunk_c_id = blockIdx.z; 229 | const int tid = threadIdx.x; 230 | const int l_idx = tid / kNThreadsPerC; 231 | const int c_idx = tid % kNThreadsPerC; 232 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 233 | + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 234 | weight_t *weight = reinterpret_cast(params.weight_ptr) 235 | + chunk_c_id * kChunkSizeC * params.weight_c_stride; 236 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 237 | + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 238 | int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) 239 | + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; 240 | input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? 
nullptr 241 | : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 242 | // The last L-chunk will also have enough info to write to final states, since it also contain a few x values 243 | // from the previous L-chunk. 244 | input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr 245 | : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; 246 | 247 | #pragma unroll 248 | for (int l = 0; l < Ktraits::kNLoads; ++l) { 249 | input_t x_vals_load[kNElts] = {0}; 250 | if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen 251 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 252 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); 253 | } 254 | reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; 255 | } 256 | // Load the elements from the previous chunk that are needed for convolution. 257 | if (l_idx < kWidth - 1) { 258 | input_t x_vals_load[kNElts] = {0}; 259 | if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 260 | && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen 261 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 262 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); 263 | } else if (initial_states != nullptr 264 | && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 265 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 266 | reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); 267 | } 268 | reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; 269 | } 270 | 271 | __syncthreads(); 272 | 273 | if (final_states != nullptr 274 | && l_idx < kWidth - 1 275 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 276 | // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) 277 | // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] 278 | *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; 279 | } 280 | 281 | constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); 282 | static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); 283 | constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; 284 | static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); 285 | // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity 286 | static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); 287 | static_assert((kLPerThread & (kLPerThread - 1)) == 0); 288 | static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); 289 | static_assert(kNThreadsPerRow <= 32); 290 | 291 | const int row_idx = tid / kNThreadsPerRow; 292 | const int col_idx = tid % kNThreadsPerRow; 293 | 294 | float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 
0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); 295 | float weight_vals[kWidth] = {0}; 296 | if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { 297 | #pragma unroll 298 | for (int w = 0; w < kWidth; ++w) { 299 | weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; 300 | } 301 | } 302 | float x_vals[kWidth - 1 + kLPerThread]; 303 | #pragma unroll 304 | for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { 305 | x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); 306 | } 307 | int seq_idx_thread[kWidth - 1 + kLPerThread]; 308 | if constexpr (kHasSeqIdx) { 309 | #pragma unroll 310 | for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { 311 | seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; 312 | } 313 | } 314 | 315 | float out_vals[kLPerThread]; 316 | #pragma unroll 317 | for (int i = 0; i < kLPerThread; ++i) { 318 | out_vals[i] = bias_val; 319 | const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; 320 | #pragma unroll 321 | for (int w = 0; w < kWidth; ++w) { 322 | if constexpr (!kHasSeqIdx) { 323 | out_vals[i] += weight_vals[w] * x_vals[i + w]; 324 | } else { 325 | out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; 326 | } 327 | } 328 | if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } 329 | } 330 | 331 | __syncthreads(); 332 | #pragma unroll 333 | for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } 334 | __syncthreads(); 335 | 336 | #pragma unroll 337 | for (int l = 0; l < Ktraits::kNLoads; ++l) { 338 | input_t out_vals_store[kNElts]; 339 | reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; 340 | if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen 341 | && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { 342 | *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; 343 | } 344 | } 345 | 346 | } 347 | 348 | template 349 | void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { 350 | BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { 351 | using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; 352 | // constexpr int kSmemSize = Ktraits::kSmemSize; 353 | constexpr int kChunkSizeL = Ktraits::kChunkSizeL; 354 | constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; 355 | const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; 356 | const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; 357 | dim3 grid(params.batch, n_chunks_L, n_chunks_C); 358 | dim3 block(Ktraits::kNThreads); 359 | auto kernel = &causal_conv1d_channellast_fwd_kernel; 360 | // if (kSmemSize >= 48 * 1024) { 361 | // C10_CUDA_CHECK(cudaFuncSetAttribute( 362 | // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); 363 | // } 364 | // kernel<<>>(params); 365 | kernel<<>>(params); 366 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 367 | }); 368 | } 369 | 370 | template 371 | void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { 372 | if (params.width == 2) { 373 | causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); 374 | } else if (params.width == 3) { 375 | causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); 376 | } else if (params.width 
== 4) { 377 | causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); 378 | } 379 | } 380 | 381 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 382 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 383 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 384 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 385 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 386 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 387 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 388 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 389 | template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 390 | 391 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 392 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 393 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 394 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 395 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 396 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 397 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 398 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 399 | template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -------------------------------------------------------------------------------- /csrc/causal_conv1d_update.cu: -------------------------------------------------------------------------------- 1 | /****************************************************************************** 2 | * Copyright (c) 2023, Tri Dao. 3 | ******************************************************************************/ 4 | 5 | #include 6 | #include 7 | #include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK 8 | 9 | #include "causal_conv1d.h" 10 | #include "causal_conv1d_common.h" 11 | #include "static_switch.h" 12 | 13 | template 14 | struct Causal_conv1d_update_kernel_traits { 15 | using input_t = input_t_; 16 | using weight_t = weight_t_; 17 | static constexpr int kNThreads = kNThreads_; 18 | static constexpr int kWidth = kWidth_; 19 | static constexpr int kNBytes = sizeof(input_t); 20 | static_assert(kNBytes == 2 || kNBytes == 4); 21 | }; 22 | 23 | template 24 | __global__ __launch_bounds__(Ktraits::kNThreads) 25 | void causal_conv1d_update_kernel(ConvParamsBase params) { 26 | constexpr int kWidth = Ktraits::kWidth; 27 | constexpr int kNThreads = Ktraits::kNThreads; 28 | using input_t = typename Ktraits::input_t; 29 | using weight_t = typename Ktraits::weight_t; 30 | 31 | const int tidx = threadIdx.x; 32 | const int batch_id = blockIdx.x; 33 | const int channel_id = blockIdx.y * kNThreads + tidx; 34 | if (channel_id >= params.dim) return; 35 | 36 | input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride 37 | + channel_id * params.x_c_stride; 38 | 39 | // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor 40 | // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. 
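// Illustrative example (editorial): with a conv-state pool of, say, 8 slots and three live
// requests currently occupying slots {5, 2, 7}, passing conv_state_indices_ptr = [5, 2, 7]
// makes batch row b of this launch read and update conv_state slot 5, 2, or 7 respectively,
// while x and out are still indexed by the dense batch id b. With a null pointer the slot is
// simply b itself.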
41 | const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr 42 | ? batch_id 43 | : params.conv_state_indices_ptr[batch_id]; 44 | input_t *conv_state = reinterpret_cast(params.conv_state_ptr) 45 | + conv_state_batch_coord * params.conv_state_batch_stride 46 | + channel_id * params.conv_state_c_stride; 47 | weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; 48 | input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride 49 | + channel_id * params.out_c_stride; 50 | float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); 51 | 52 | int state_len = params.conv_state_len; 53 | int advance_len = params.seqlen; 54 | int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; 55 | int update_idx = cache_seqlen - (kWidth - 1); 56 | update_idx = update_idx < 0 ? update_idx + state_len : update_idx; 57 | 58 | float weight_vals[kWidth] = {0}; 59 | #pragma unroll 60 | for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } 61 | 62 | float x_vals[kWidth] = {0}; 63 | if constexpr (!kIsCircularBuffer) { 64 | #pragma unroll 2 65 | for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { 66 | conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; 67 | } 68 | #pragma unroll 69 | for (int i = 0; i < kWidth - 1; ++i) { 70 | input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; 71 | if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { 72 | conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; 73 | } 74 | x_vals[i] = float(state_val); 75 | } 76 | } else { 77 | #pragma unroll 78 | for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { 79 | input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; 80 | x_vals[i] = float(state_val); 81 | } 82 | } 83 | #pragma unroll 2 84 | for (int i = 0; i < params.seqlen; ++i) { 85 | input_t x_val = x[i * params.x_l_stride]; 86 | if constexpr (!kIsCircularBuffer) { 87 | if (i < advance_len && state_len - advance_len + i >= 0) { 88 | conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; 89 | } 90 | } else { 91 | conv_state[update_idx * params.conv_state_l_stride] = x_val; 92 | ++update_idx; 93 | update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; 94 | } 95 | x_vals[kWidth - 1] = float(x_val); 96 | float out_val = bias_val; 97 | #pragma unroll 98 | for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } 99 | if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } 100 | out[i * params.out_l_stride] = input_t(out_val); 101 | // Shift the input buffer by 1 102 | #pragma unroll 103 | for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } 104 | } 105 | } 106 | 107 | template 108 | void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { 109 | using Ktraits = Causal_conv1d_update_kernel_traits; 110 | dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); 111 | auto kernel = params.cache_seqlens == nullptr 112 | ? 
&causal_conv1d_update_kernel 113 | : &causal_conv1d_update_kernel; 114 | kernel<<>>(params); 115 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 116 | } 117 | 118 | template 119 | void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { 120 | if (params.width == 2) { 121 | causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); 122 | } else if (params.width == 3) { 123 | causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); 124 | } else if (params.width == 4) { 125 | causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); 126 | } 127 | } 128 | 129 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 130 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 131 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 132 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 133 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 134 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 135 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 136 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 137 | template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); 138 | -------------------------------------------------------------------------------- /csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h 2 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h 3 | 4 | #pragma once 5 | 6 | /// @param COND - a boolean expression to switch by 7 | /// @param CONST_NAME - a name given for the constexpr bool variable. 8 | /// @param ... - code to execute for true and false 9 | /// 10 | /// Usage: 11 | /// ``` 12 | /// BOOL_SWITCH(flag, BoolConst, [&] { 13 | /// some_function(...); 14 | /// }); 15 | /// ``` 16 | #define BOOL_SWITCH(COND, CONST_NAME, ...) 
\ 17 | [&] { \ 18 | if (COND) { \ 19 | static constexpr bool CONST_NAME = true; \ 20 | return __VA_ARGS__(); \ 21 | } else { \ 22 | static constexpr bool CONST_NAME = false; \ 23 | return __VA_ARGS__(); \ 24 | } \ 25 | }() 26 | -------------------------------------------------------------------------------- /rocm_patch/rocm6_0.patch: -------------------------------------------------------------------------------- 1 | --- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000 2 | +++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000 3 | @@ -137,7 +137,7 @@ 4 | * \ingroup HIP_INTRINSIC_BFLOAT16_CONV 5 | * \brief Converts float to bfloat16 6 | */ 7 | -__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) { 8 | +__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) { 9 | __hip_bfloat16 ret; 10 | union { 11 | float fp32; 12 | @@ -181,7 +181,7 @@ 13 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 14 | * \brief Converts and moves bfloat162 to float2 15 | */ 16 | -__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) { 17 | +__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) { 18 | return float2{__bfloat162float(a.x), __bfloat162float(a.y)}; 19 | } 20 | 21 | @@ -209,7 +209,7 @@ 22 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 23 | * \brief Convert double to __hip_bfloat16 24 | */ 25 | -__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) { 26 | +__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) { 27 | return __float2bfloat16((float)a); 28 | } 29 | 30 | @@ -217,7 +217,7 @@ 31 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 32 | * \brief Convert float2 to __hip_bfloat162 33 | */ 34 | -__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { 35 | +__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) { 36 | return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)}; 37 | } 38 | 39 | @@ -247,7 +247,7 @@ 40 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 41 | * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result 42 | */ 43 | -__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } 44 | +__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } 45 | 46 | /** 47 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 48 | @@ -275,7 +275,7 @@ 49 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 50 | * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result 51 | */ 52 | -__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } 53 | +__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } 54 | 55 | /** 56 | * \ingroup HIP_INTRINSIC_BFLOAT162_CONV 57 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, Tri Dao. 
2 | 3 | import sys 4 | import warnings 5 | import os 6 | import re 7 | import shutil 8 | import ast 9 | from pathlib import Path 10 | from packaging.version import parse, Version 11 | import platform 12 | 13 | from setuptools import setup, find_packages 14 | import subprocess 15 | 16 | import urllib.request 17 | import urllib.error 18 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 19 | 20 | import torch 21 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME, HIP_HOME 22 | 23 | 24 | with open("README.md", "r", encoding="utf-8") as fh: 25 | long_description = fh.read() 26 | 27 | 28 | # ninja build does not work unless include_dirs are abs path 29 | this_dir = os.path.dirname(os.path.abspath(__file__)) 30 | 31 | PACKAGE_NAME = "causal_conv1d" 32 | 33 | BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}" 34 | 35 | # FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels 36 | # SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation 37 | FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE" 38 | SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE" 39 | # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI 40 | FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE" 41 | 42 | 43 | def get_platform(): 44 | """ 45 | Returns the platform name as used in wheel filenames. 46 | """ 47 | if sys.platform.startswith("linux"): 48 | return "linux_x86_64" 49 | elif sys.platform == "darwin": 50 | mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) 51 | return f"macosx_{mac_version}_x86_64" 52 | elif sys.platform == "win32": 53 | return "win_amd64" 54 | else: 55 | raise ValueError("Unsupported platform: {}".format(sys.platform)) 56 | 57 | 58 | def get_cuda_bare_metal_version(cuda_dir): 59 | raw_output = subprocess.check_output( 60 | [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True 61 | ) 62 | output = raw_output.split() 63 | release_idx = output.index("release") + 1 64 | bare_metal_version = parse(output[release_idx].split(",")[0]) 65 | 66 | return raw_output, bare_metal_version 67 | 68 | 69 | def get_hip_version(rocm_dir): 70 | 71 | hipcc_bin = "hipcc" if rocm_dir is None else os.path.join(rocm_dir, "bin", "hipcc") 72 | try: 73 | raw_output = subprocess.check_output( 74 | [hipcc_bin, "--version"], universal_newlines=True 75 | ) 76 | except Exception as e: 77 | print( 78 | f"hip installation not found: {e} ROCM_PATH={os.environ.get('ROCM_PATH')}" 79 | ) 80 | return None, None 81 | 82 | for line in raw_output.split("\n"): 83 | if "HIP version" in line: 84 | rocm_version = parse(line.split()[-1].replace("-", "+")) # local version is not parsed correctly 85 | return line, rocm_version 86 | 87 | return None, None 88 | 89 | 90 | def get_torch_hip_version(): 91 | if torch.version.hip: 92 | return parse(torch.version.hip.split()[-1].replace("-", "+")) 93 | else: 94 | return None 95 | 96 | 97 | def check_if_hip_home_none(global_option: str) -> None: 98 | 99 | if HIP_HOME is not None: 100 | return 101 | # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary 102 | # in that case. 103 | warnings.warn( 104 | f"{global_option} was requested, but hipcc was not found. Are you sure your environment has hipcc available?" 
105 | ) 106 | 107 | 108 | def check_if_cuda_home_none(global_option: str) -> None: 109 | if CUDA_HOME is not None: 110 | return 111 | # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary 112 | # in that case. 113 | warnings.warn( 114 | f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " 115 | "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " 116 | "only images whose names contain 'devel' will provide nvcc." 117 | ) 118 | 119 | 120 | def append_nvcc_threads(nvcc_extra_args): 121 | return nvcc_extra_args + ["--threads", "4"] 122 | 123 | 124 | cmdclass = {} 125 | ext_modules = [] 126 | 127 | 128 | HIP_BUILD = bool(torch.version.hip) 129 | 130 | if not SKIP_CUDA_BUILD: 131 | 132 | print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) 133 | TORCH_MAJOR = int(torch.__version__.split(".")[0]) 134 | TORCH_MINOR = int(torch.__version__.split(".")[1]) 135 | 136 | 137 | cc_flag = [] 138 | 139 | if HIP_BUILD: 140 | check_if_hip_home_none(PACKAGE_NAME) 141 | 142 | rocm_home = os.getenv("ROCM_PATH") 143 | _, hip_version = get_hip_version(rocm_home) 144 | 145 | 146 | if HIP_HOME is not None: 147 | if hip_version < Version("6.0"): 148 | raise RuntimeError( 149 | f"{PACKAGE_NAME} is only supported on ROCm 6.0 and above. " 150 | "Note: make sure HIP has a supported version by running hipcc --version." 151 | ) 152 | if hip_version == Version("6.0"): 153 | warnings.warn( 154 | f"{PACKAGE_NAME} requires a patch to be applied when running on ROCm 6.0. " 155 | "Refer to the README.md for detailed instructions.", 156 | UserWarning 157 | ) 158 | 159 | cc_flag.append("-DBUILD_PYTHON_PACKAGE") 160 | 161 | else: 162 | check_if_cuda_home_none(PACKAGE_NAME) 163 | # Check, if CUDA11 is installed for compute capability 8.0 164 | 165 | if CUDA_HOME is not None: 166 | _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 167 | if bare_metal_version < Version("11.6"): 168 | raise RuntimeError( 169 | f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. " 170 | "Note: make sure nvcc has a supported version by running nvcc -V." 
171 | ) 172 | 173 | cc_flag.append("-gencode") 174 | cc_flag.append("arch=compute_53,code=sm_53") 175 | cc_flag.append("-gencode") 176 | cc_flag.append("arch=compute_62,code=sm_62") 177 | cc_flag.append("-gencode") 178 | cc_flag.append("arch=compute_70,code=sm_70") 179 | cc_flag.append("-gencode") 180 | cc_flag.append("arch=compute_72,code=sm_72") 181 | cc_flag.append("-gencode") 182 | cc_flag.append("arch=compute_80,code=sm_80") 183 | cc_flag.append("-gencode") 184 | cc_flag.append("arch=compute_87,code=sm_87") 185 | if bare_metal_version >= Version("11.8"): 186 | cc_flag.append("-gencode") 187 | cc_flag.append("arch=compute_90,code=sm_90") 188 | if bare_metal_version >= Version("12.8"): 189 | cc_flag.append("-gencode") 190 | cc_flag.append("arch=compute_100,code=sm_100") 191 | 192 | # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as 193 | # torch._C._GLIBCXX_USE_CXX11_ABI 194 | # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 195 | if FORCE_CXX11_ABI: 196 | torch._C._GLIBCXX_USE_CXX11_ABI = True 197 | 198 | 199 | if HIP_BUILD: 200 | extra_compile_args = { 201 | "cxx": ["-O3", "-std=c++17"], 202 | "nvcc": [ 203 | "-O3", 204 | "-std=c++17", 205 | f"--offload-arch={os.getenv('HIP_ARCHITECTURES', 'native')}", 206 | "-U__CUDA_NO_HALF_OPERATORS__", 207 | "-U__CUDA_NO_HALF_CONVERSIONS__", 208 | "-fgpu-flush-denormals-to-zero", 209 | ] 210 | + cc_flag, 211 | } 212 | else: 213 | extra_compile_args = { 214 | "cxx": ["-O3"], 215 | "nvcc": append_nvcc_threads( 216 | [ 217 | "-O3", 218 | "-U__CUDA_NO_HALF_OPERATORS__", 219 | "-U__CUDA_NO_HALF_CONVERSIONS__", 220 | "-U__CUDA_NO_BFLOAT16_OPERATORS__", 221 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 222 | "-U__CUDA_NO_BFLOAT162_OPERATORS__", 223 | "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", 224 | "--expt-relaxed-constexpr", 225 | "--expt-extended-lambda", 226 | "--use_fast_math", 227 | "--ptxas-options=-v", 228 | "-lineinfo", 229 | ] 230 | + cc_flag 231 | ), 232 | } 233 | 234 | ext_modules.append( 235 | CUDAExtension( 236 | name="causal_conv1d_cuda", 237 | sources=[ 238 | "csrc/causal_conv1d.cpp", 239 | "csrc/causal_conv1d_fwd.cu", 240 | "csrc/causal_conv1d_bwd.cu", 241 | "csrc/causal_conv1d_update.cu", 242 | ], 243 | extra_compile_args=extra_compile_args, 244 | include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"], 245 | ) 246 | ) 247 | 248 | 249 | def get_package_version(): 250 | with open(Path(this_dir) / "causal_conv1d" / "__init__.py", "r") as f: 251 | version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) 252 | public_version = ast.literal_eval(version_match.group(1)) 253 | local_version = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION") 254 | if local_version: 255 | return f"{public_version}+{local_version}" 256 | else: 257 | return str(public_version) 258 | 259 | 260 | def get_wheel_url(): 261 | 262 | # Determine the version numbers that will be used to determine the correct wheel 263 | torch_version_raw = parse(torch.__version__) 264 | 265 | if HIP_BUILD: 266 | # We're using the HIP version used to build torch, not the one currently installed 267 | torch_hip_version = get_torch_hip_version() 268 | hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}" 269 | else: 270 | # We're using the CUDA version used to build torch, not the one currently installed 271 | # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) 272 | torch_cuda_version = parse(torch.version.cuda) 273 | # For CUDA 11, we only compile for CUDA 11.8, and 
for CUDA 12 we only compile for CUDA 12.3 274 | # to save CI time. Minor versions should be compatible. 275 | torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") 276 | cuda_version = f"{torch_cuda_version.major}" 277 | 278 | gpu_compute_version = hip_version if HIP_BUILD else cuda_version 279 | cuda_or_hip = "hip" if HIP_BUILD else "cu" 280 | 281 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 282 | platform_name = get_platform() 283 | causal_conv1d_version = get_package_version() 284 | 285 | torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" 286 | cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() 287 | 288 | # Determine wheel URL based on CUDA version, torch version, python version and OS 289 | wheel_filename = f"{PACKAGE_NAME}-{causal_conv1d_version}+{cuda_or_hip}{gpu_compute_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" 290 | 291 | wheel_url = BASE_WHEEL_URL.format( 292 | tag_name=f"v{causal_conv1d_version}", wheel_name=wheel_filename 293 | ) 294 | return wheel_url, wheel_filename 295 | 296 | 297 | class CachedWheelsCommand(_bdist_wheel): 298 | """ 299 | The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot 300 | find an existing wheel (which is currently the case for all installs). We use 301 | the environment parameters to detect whether there is already a pre-built version of a compatible 302 | wheel available and, if so, short-circuit the standard full build pipeline. 303 | """ 304 | 305 | def run(self): 306 | if FORCE_BUILD: 307 | return super().run() 308 | 309 | wheel_url, wheel_filename = get_wheel_url() 310 | print("Guessing wheel URL: ", wheel_url) 311 | try: 312 | urllib.request.urlretrieve(wheel_url, wheel_filename) 313 | 314 | # Make the archive 315 | # Lifted from the root wheel processing command 316 | # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 317 | if not os.path.exists(self.dist_dir): 318 | os.makedirs(self.dist_dir) 319 | 320 | impl_tag, abi_tag, plat_tag = self.get_tag() 321 | archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" 322 | 323 | wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") 324 | print("Raw wheel path", wheel_path) 325 | shutil.move(wheel_filename, wheel_path) 326 | except urllib.error.HTTPError: 327 | print("Precompiled wheel not found.
Building from source...") 328 | # If the wheel could not be downloaded, build from source 329 | super().run() 330 | 331 | 332 | setup( 333 | name=PACKAGE_NAME, 334 | version=get_package_version(), 335 | packages=find_packages( 336 | exclude=( 337 | "build", 338 | "csrc", 339 | "include", 340 | "tests", 341 | "dist", 342 | "docs", 343 | "benchmarks", 344 | "causal_conv1d.egg-info", 345 | ) 346 | ), 347 | author="Tri Dao", 348 | author_email="tri@tridao.me", 349 | description="Causal depthwise conv1d in CUDA, with a PyTorch interface", 350 | long_description=long_description, 351 | long_description_content_type="text/markdown", 352 | url="https://github.com/Dao-AILab/causal-conv1d", 353 | classifiers=[ 354 | "Programming Language :: Python :: 3", 355 | "License :: OSI Approved :: BSD License", 356 | "Operating System :: Unix", 357 | ], 358 | ext_modules=ext_modules, 359 | cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} 360 | if ext_modules 361 | else { 362 | "bdist_wheel": CachedWheelsCommand, 363 | }, 364 | python_requires=">=3.9", 365 | install_requires=[ 366 | "torch", 367 | "packaging", 368 | "ninja", 369 | ], 370 | ) 371 | -------------------------------------------------------------------------------- /tests/test_causal_conv1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2024, Tri Dao. 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import pytest 9 | 10 | from einops import rearrange 11 | 12 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_ref 13 | from causal_conv1d.causal_conv1d_interface import causal_conv1d_update, causal_conv1d_update_ref 14 | from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states, causal_conv1d_varlen_states_ref 15 | 16 | 17 | @pytest.mark.parametrize("return_final_states", [False, True]) 18 | # @pytest.mark.parametrize("return_final_states", [True]) 19 | @pytest.mark.parametrize("has_initial_states", [False, True]) 20 | # @pytest.mark.parametrize("has_initial_states", [False]) 21 | @pytest.mark.parametrize("channel_last", [False, True]) 22 | # @pytest.mark.parametrize('channel_last', [True]) 23 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 24 | # @pytest.mark.parametrize('itype', [torch.float16]) 25 | @pytest.mark.parametrize("silu_activation", [False, True]) 26 | # @pytest.mark.parametrize('silu_activation', [True]) 27 | @pytest.mark.parametrize("has_bias", [False, True]) 28 | # @pytest.mark.parametrize('has_bias', [True]) 29 | @pytest.mark.parametrize("width", [2, 3, 4]) 30 | # @pytest.mark.parametrize('width', [3]) 31 | @pytest.mark.parametrize( 32 | "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 33 | ) 34 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 35 | # @pytest.mark.parametrize('seqlen', [128]) 36 | @pytest.mark.parametrize('dim', [64, 4096 + 32]) 37 | # @pytest.mark.parametrize('dim', [64]) 38 | def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states): 39 | if not channel_last and (has_initial_states or return_final_states): 40 | pytest.skip("Only channel_last support initial_states or return_final_states") 41 | device = "cuda" 42 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 43 | if itype == torch.bfloat16: 44 | rtol, atol = 1e-2, 5e-2 45 | 
rtolw, atolw = (1e-3, 1e-3) 46 | # set seed 47 | torch.random.manual_seed(0) 48 | batch = 2 49 | # batch = 1 50 | if not channel_last: 51 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 52 | else: 53 | x = rearrange( 54 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 55 | ).requires_grad_() 56 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 57 | if has_bias: 58 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 59 | else: 60 | bias = None 61 | if has_initial_states: 62 | initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_() 63 | else: 64 | initial_states = None 65 | x_ref = x.detach().clone().requires_grad_() 66 | weight_ref = weight.detach().clone().requires_grad_() 67 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 68 | initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None 69 | activation = None if not silu_activation else "silu" 70 | out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states, 71 | activation=activation) 72 | out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation) 73 | if return_final_states: 74 | out, final_states = out 75 | out_ref, final_states_ref = out_ref 76 | print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}") 77 | print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}") 78 | assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) 79 | 80 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 81 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 82 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 83 | 84 | if return_final_states: 85 | out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) 86 | out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) 87 | 88 | g = torch.randn_like(out) 89 | out.backward(g) 90 | out_ref.backward(g) 91 | 92 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 93 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 94 | if has_bias: 95 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 96 | if has_initial_states: 97 | print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}") 98 | 99 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 100 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 101 | if has_bias: 102 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 103 | if has_initial_states: 104 | assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 105 | 106 | 107 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 108 | # @pytest.mark.parametrize('itype', [torch.float16]) 109 | @pytest.mark.parametrize("silu_activation", [False, True]) 110 | # @pytest.mark.parametrize('silu_activation', [True]) 111 | @pytest.mark.parametrize("has_bias", [False, True]) 112 | # @pytest.mark.parametrize('has_bias', [True]) 113 | 
@pytest.mark.parametrize("has_cache_seqlens", [False, True]) 114 | # @pytest.mark.parametrize('has_cache_seqlens', [True]) 115 | @pytest.mark.parametrize("seqlen", [1, 4, 5]) 116 | # @pytest.mark.parametrize('seqlen', [4]) 117 | @pytest.mark.parametrize("width", [2, 3, 4]) 118 | # @pytest.mark.parametrize('width', [4]) 119 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 120 | # @pytest.mark.parametrize("dim", [2048]) 121 | def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype): 122 | device = "cuda" 123 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 124 | if itype == torch.bfloat16: 125 | rtol, atol = 1e-2, 5e-2 126 | rtolw, atolw = (1e-3, 1e-3) 127 | # set seed 128 | torch.random.manual_seed(0) 129 | batch = 64 130 | # batch = 1 131 | # dim = 64 132 | x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2) 133 | state_len = torch.randint(width - 1, width + 10, (1,)).item() 134 | conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2) 135 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 136 | if has_bias: 137 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 138 | else: 139 | bias = None 140 | conv_state_ref = conv_state.detach().clone() 141 | activation = None if not silu_activation else "silu" 142 | cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device) 143 | if has_cache_seqlens else None) 144 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens) 145 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens) 146 | 147 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 148 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 149 | assert torch.equal(conv_state, conv_state_ref) 150 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 151 | 152 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 153 | # @pytest.mark.parametrize('itype', [torch.float16]) 154 | @pytest.mark.parametrize("silu_activation", [False, True]) 155 | # @pytest.mark.parametrize('silu_activation', [True]) 156 | @pytest.mark.parametrize("has_bias", [False, True]) 157 | # @pytest.mark.parametrize('has_bias', [True]) 158 | @pytest.mark.parametrize("has_cache_seqlens", [False, True]) 159 | # @pytest.mark.parametrize('has_cache_seqlens', [True]) 160 | @pytest.mark.parametrize("seqlen", [1, 4, 5]) 161 | # @pytest.mark.parametrize('seqlen', [4]) 162 | @pytest.mark.parametrize("width", [2, 3, 4]) 163 | # @pytest.mark.parametrize('width', [4]) 164 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 165 | # @pytest.mark.parametrize("dim", [2048]) 166 | def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype): 167 | device = "cuda" 168 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 169 | if itype == torch.bfloat16: 170 | rtol, atol = 1e-2, 5e-2 171 | rtolw, atolw = (1e-3, 1e-3) 172 | # set seed 173 | torch.random.manual_seed(0) 174 | batch = 64 175 | # batch = 1 176 | # dim = 64 177 | x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2) 178 | state_len = torch.randint(width - 1, width + 10, (1,)).item() 179 | 180 | total_entries = 10 * batch 181 | conv_state = 
torch.randn(total_entries, state_len, dim, device=device, dtype=itype).transpose(-1, -2) 182 | conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32, device=device) 183 | 184 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 185 | if has_bias: 186 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 187 | else: 188 | bias = None 189 | conv_state_ref = conv_state[conv_state_indices, :].detach().clone() 190 | activation = None if not silu_activation else "silu" 191 | cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device) 192 | if has_cache_seqlens else None) 193 | out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, 194 | cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices) 195 | out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens) 196 | 197 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 198 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 199 | assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) 200 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 201 | 202 | 203 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 204 | # @pytest.mark.parametrize('itype', [torch.float16]) 205 | @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) 206 | # @pytest.mark.parametrize("dim", [2048]) 207 | def test_causal_conv1d_get_states(dim, itype): 208 | device = "cuda" 209 | # set seed 210 | torch.random.manual_seed(0) 211 | seqlens = torch.randint(1, 32, (100,), device=device) 212 | total_seqlen = seqlens.sum().item() 213 | x = torch.randn(total_seqlen, dim, device=device, dtype=itype) 214 | cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0)) 215 | state_len = 20 216 | out = causal_conv1d_varlen_states(x, cu_seqlens, state_len) 217 | out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len) 218 | assert torch.equal(out, out_ref) 219 | 220 | 221 | # @pytest.mark.parametrize("channel_last", [False, True]) 222 | @pytest.mark.parametrize('channel_last', [True]) 223 | # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 224 | @pytest.mark.parametrize('itype', [torch.bfloat16]) 225 | # @pytest.mark.parametrize("silu_activation", [False, True]) 226 | @pytest.mark.parametrize('silu_activation', [True]) 227 | # @pytest.mark.parametrize("has_bias", [False, True]) 228 | @pytest.mark.parametrize('has_bias', [True]) 229 | # @pytest.mark.parametrize("width", [2, 3, 4]) 230 | @pytest.mark.parametrize('width', [4]) 231 | @pytest.mark.parametrize( 232 | # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 233 | "seqlen", [2048] 234 | ) 235 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 236 | # @pytest.mark.parametrize('seqlen', [128]) 237 | def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last): 238 | device = "cuda" 239 | # set seed 240 | torch.random.manual_seed(0) 241 | batch = 2 242 | # batch = 1 243 | dim = 4096 + 32 # Try dim not divisible by 64 244 | # dim = 64 245 | if not channel_last: 246 | x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_() 247 | else: 248 | x = rearrange( 249 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + 
dim], "b s d -> b d s" 250 | ).requires_grad_() 251 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 252 | if has_bias: 253 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 254 | else: 255 | bias = None 256 | activation = None if not silu_activation else "silu" 257 | out0 = causal_conv1d_fn(x, weight, bias, activation=activation) 258 | g = torch.randn_like(out0) 259 | dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g) 260 | dw_atol = 1e-4 261 | db_atol = 1e-4 262 | 263 | for i in range(10000): 264 | out = causal_conv1d_fn(x, weight, bias, activation=activation) 265 | dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g) 266 | dw_equal = torch.allclose(dw, dw0, atol=dw_atol) 267 | # if not dw_equal: 268 | # breakpoint() 269 | if has_bias: 270 | db_equal = torch.allclose(db, db0, atol=db_atol) 271 | # if not db_equal: 272 | # breakpoint() 273 | assert torch.equal(out, out0) 274 | assert torch.equal(dx, dx0) 275 | assert dw_equal 276 | if has_bias: 277 | assert dw_equal 278 | 279 | 280 | @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) 281 | # @pytest.mark.parametrize('itype', [torch.float16]) 282 | @pytest.mark.parametrize("silu_activation", [False, True]) 283 | # @pytest.mark.parametrize('silu_activation', [False]) 284 | @pytest.mark.parametrize("has_bias", [False, True]) 285 | # @pytest.mark.parametrize('has_bias', [False]) 286 | @pytest.mark.parametrize("width", [2, 3, 4]) 287 | # @pytest.mark.parametrize('width', [2]) 288 | @pytest.mark.parametrize( 289 | "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096] 290 | ) 291 | # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) 292 | # @pytest.mark.parametrize('seqlen', [2048]) 293 | @pytest.mark.parametrize('dim', [64, 4096 + 32]) 294 | # @pytest.mark.parametrize('dim', [64]) 295 | def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype): 296 | device = "cuda" 297 | rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) 298 | if itype == torch.bfloat16: 299 | rtol, atol = 1e-2, 5e-2 300 | rtolw, atolw = (1e-3, 1e-3) 301 | # set seed 302 | torch.random.manual_seed(seqlen + dim + width) 303 | batch = 3 304 | seqlens = [] 305 | for b in range(batch): 306 | nsplits = torch.randint(1, 5, (1,)).item() 307 | eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values 308 | seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) 309 | assert sum(seqlens[-1]) == seqlen 310 | assert all(s > 0 for s in seqlens[-1]) 311 | # Only support channel_last 312 | x = rearrange( 313 | torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s" 314 | ).requires_grad_() 315 | weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True) 316 | if has_bias: 317 | bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True) 318 | else: 319 | bias = None 320 | seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0) 321 | for sl in seqlens], dim=0) 322 | x_ref = x.detach().clone().requires_grad_() 323 | weight_ref = weight.detach().clone().requires_grad_() 324 | bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None 325 | activation = None if not silu_activation else "silu" 326 | out = 
causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation) 327 | out_ref = [] 328 | for b in range(batch): 329 | out_ref_b = [] 330 | for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2): 331 | out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation)) 332 | out_ref.append(torch.cat(out_ref_b, dim=2)) 333 | out_ref = torch.cat(out_ref, dim=0) 334 | 335 | print(f"Output max diff: {(out - out_ref).abs().max().item()}") 336 | print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") 337 | assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) 338 | 339 | g = torch.randn_like(out) 340 | out_ref.backward(g) 341 | out.backward(g) 342 | 343 | print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}") 344 | print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}") 345 | if has_bias: 346 | print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}") 347 | 348 | assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol) 349 | assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw) 350 | if has_bias: 351 | assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw) 352 | --------------------------------------------------------------------------------
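For orientation, here is a minimal usage sketch of the Python API that the tests above exercise (causal_conv1d_fn for a full sequence, causal_conv1d_update for single-step decoding). The imports, tensor layouts, dtypes, and keyword arguments are taken from tests/test_causal_conv1d.py; the concrete sizes and the zero-initialized conv_state are illustrative assumptions, not code from the repository.

import torch
from causal_conv1d.causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update

device = "cuda"
batch, dim, seqlen, width = 2, 64, 128, 4  # illustrative sizes
x = torch.randn(batch, dim, seqlen, device=device, dtype=torch.float16)  # (batch, dim, seqlen)
weight = torch.randn(dim, width, device=device, dtype=torch.float32)     # one length-`width` filter per channel
bias = torch.randn(dim, device=device, dtype=torch.float32)

# Full-sequence causal depthwise convolution; activation is None or "silu".
out = causal_conv1d_fn(x, weight, bias, activation="silu")               # (batch, dim, seqlen)

# Single-step decoding: conv_state caches the trailing inputs per channel and is updated in place.
state_len = width - 1                                                    # the tests use state_len >= width - 1
conv_state = torch.zeros(batch, state_len, dim, device=device, dtype=torch.float16).transpose(-1, -2)
x_step = torch.randn(batch, 1, dim, device=device, dtype=torch.float16).transpose(-1, -2)
out_step = causal_conv1d_update(x_step, conv_state, weight, bias, activation="silu")  # (batch, dim, 1)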