├── .flake8 ├── .github ├── scripts │ ├── install_cuda.sh │ ├── install_cudnn.sh │ └── install_torch.sh └── workflows │ ├── run_tests_cpu.yml │ ├── run_tests_cuda.yml │ └── style_check.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── cmake ├── Modules │ ├── FetchContent.cmake │ ├── FetchContent │ │ └── CMakeLists.cmake.in │ └── README.md ├── googletest.cmake ├── pybind11.cmake ├── select_compute_arch.cmake ├── torch.cmake └── transform.cmake ├── fast_rnnt ├── CMakeLists.txt ├── csrc │ ├── CMakeLists.txt │ ├── device_guard.h │ ├── mutual_information.h │ ├── mutual_information_cpu.cu │ └── mutual_information_cuda.cu └── python │ ├── CMakeLists.txt │ ├── csrc │ ├── CMakeLists.txt │ ├── fast_rnnt.cu │ ├── fast_rnnt.h │ ├── mutual_information.cu │ └── mutual_information.h │ ├── fast_rnnt │ ├── __init__.py │ ├── mutual_information.py │ └── rnnt_loss.py │ └── tests │ ├── CMakeLists.txt │ ├── mutual_information_test.py │ └── rnnt_loss_test.py ├── package.sh ├── requirements.txt └── setup.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | show-source=true 3 | statistics=true 4 | max-line-length=80 5 | 6 | exclude = 7 | .git, 8 | .github, 9 | setup.py, 10 | build, 11 | -------------------------------------------------------------------------------- /.github/scripts/install_cuda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | echo "cuda version: $cuda" 18 | 19 | case "$cuda" in 20 | 10.0) 21 | url=https://developer.nvidia.com/compute/cuda/10.0/Prod/local_installers/cuda_10.0.130_410.48_linux 22 | ;; 23 | 10.1) 24 | # WARNING: there are bugs in 25 | # https://developer.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.105_418.39_linux.run 26 | # with GCC 7. Please use the following version 27 | url=http://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda_10.1.243_418.87.00_linux.run 28 | ;; 29 | 10.2) 30 | url=http://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda_10.2.89_440.33.01_linux.run 31 | ;; 32 | 11.0) 33 | url=http://developer.download.nvidia.com/compute/cuda/11.0.2/local_installers/cuda_11.0.2_450.51.05_linux.run 34 | ;; 35 | 11.1) 36 | # url=https://developer.download.nvidia.com/compute/cuda/11.1.0/local_installers/cuda_11.1.0_455.23.05_linux.run 37 | url=https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run 38 | ;; 39 | 11.3) 40 | # url=https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run 41 | url=https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.19.01_linux.run 42 | ;; 43 | 11.5) 44 | url=https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda_11.5.2_495.29.05_linux.run 45 | ;; 46 | 11.6) 47 | url=https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run 48 | ;; 49 | 11.7) 50 | url=https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run 51 | ;; 52 | *) 53 | echo "Unknown cuda version: $cuda" 54 | exit 1 55 | ;; 56 | esac 57 | 58 | function retry() { 59 | $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) 60 | } 61 | 62 | retry curl -LSs -O $url 63 | filename=$(basename $url) 64 | echo "filename: $filename" 65 | chmod +x ./$filename 66 | sudo ./$filename --toolkit --silent 67 | rm -fv ./$filename 68 | 69 | export CUDA_HOME=/usr/local/cuda 70 | export PATH=$CUDA_HOME/bin:$PATH 71 | export LD_LIBRARY_PATH=$CUDA_HOME/lib:$LD_LIBRARY_PATH 72 | export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 73 | -------------------------------------------------------------------------------- /.github/scripts/install_cudnn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | case $cuda in 18 | 10.0) 19 | filename=cudnn-10.0-linux-x64-v7.6.5.32.tgz 20 | ;; 21 | 10.1) 22 | filename=cudnn-10.1-linux-x64-v8.0.2.39.tgz 23 | ;; 24 | 10.2) 25 | filename=cudnn-10.2-linux-x64-v8.0.2.39.tgz 26 | ;; 27 | 11.0) 28 | filename=cudnn-11.0-linux-x64-v8.0.5.39.tgz 29 | ;; 30 | 11.1) 31 | filename=cudnn-11.1-linux-x64-v8.0.4.30.tgz 32 | ;; 33 | 11.3) 34 | filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz 35 | ;; 36 | 11.5) 37 | filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz 38 | ;; 39 | 11.6) 40 | filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz 41 | ;; 42 | 11.7) 43 | filename=cudnn-11.3-linux-x64-v8.2.0.53.tgz 44 | ;; 45 | *) 46 | echo "Unsupported cuda version: $cuda" 47 | exit 1 48 | ;; 49 | esac 50 | 51 | command -v git-lfs >/dev/null 2>&1 || { echo >&2 "\nPlease install 'git-lfs' first."; exit 2; } 52 | 53 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/cudnn 54 | cd cudnn 55 | git lfs pull --include="$filename" 56 | 57 | sudo tar xf ./$filename --strip-components=1 -C /usr/local/cuda 58 | 59 | # save disk space 60 | git lfs prune && cd .. && rm -rf cudnn 61 | 62 | sudo sed -i '59i#define CUDNN_MAJOR 8' /usr/local/cuda/include/cudnn.h 63 | -------------------------------------------------------------------------------- /.github/scripts/install_torch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | echo "torch version: $torch" 18 | echo "cuda version: $cuda" 19 | 20 | case ${torch} in 21 | 1.5.*) 22 | case ${cuda} in 23 | 10.1) 24 | package="torch==${torch}+cu101" 25 | url=https://download.pytorch.org/whl/torch_stable.html 26 | ;; 27 | 10.2) 28 | package="torch==${torch}" 29 | # Leave url empty to use PyPI. 30 | # torch_stable provides cu92 but we want cu102 31 | url= 32 | ;; 33 | esac 34 | ;; 35 | 1.6.0) 36 | case ${cuda} in 37 | 10.1) 38 | package="torch==1.6.0+cu101" 39 | url=https://download.pytorch.org/whl/torch_stable.html 40 | ;; 41 | 10.2) 42 | package="torch==1.6.0" 43 | # Leave it empty to use PyPI. 44 | # torch_stable provides cu92 but we want cu102 45 | url= 46 | ;; 47 | esac 48 | ;; 49 | 1.7.*) 50 | case ${cuda} in 51 | 10.1) 52 | package="torch==${torch}+cu101" 53 | url=https://download.pytorch.org/whl/torch_stable.html 54 | ;; 55 | 10.2) 56 | package="torch==${torch}" 57 | # Leave it empty to use PyPI. 58 | # torch_stable provides cu92 but we want cu102 59 | url= 60 | ;; 61 | 11.0) 62 | package="torch==${torch}+cu110" 63 | url=https://download.pytorch.org/whl/torch_stable.html 64 | ;; 65 | esac 66 | ;; 67 | 1.8.*) 68 | case ${cuda} in 69 | 10.1) 70 | package="torch==${torch}+cu101" 71 | url=https://download.pytorch.org/whl/torch_stable.html 72 | ;; 73 | 10.2) 74 | package="torch==${torch}" 75 | # Leave it empty to use PyPI. 76 | url= 77 | ;; 78 | 11.1) 79 | package="torch==${torch}+cu111" 80 | url=https://download.pytorch.org/whl/torch_stable.html 81 | ;; 82 | esac 83 | ;; 84 | 1.9.*) 85 | case ${cuda} in 86 | 10.2) 87 | package="torch==${torch}" 88 | # Leave it empty to use PyPI. 89 | url= 90 | ;; 91 | 11.1) 92 | package="torch==${torch}+cu111" 93 | url=https://download.pytorch.org/whl/torch_stable.html 94 | ;; 95 | esac 96 | ;; 97 | 1.10.*) 98 | case ${cuda} in 99 | 10.2) 100 | package="torch==${torch}" 101 | # Leave it empty to use PyPI. 102 | url= 103 | ;; 104 | 11.1) 105 | package="torch==${torch}+cu111" 106 | url=https://download.pytorch.org/whl/torch_stable.html 107 | ;; 108 | 11.3) 109 | package="torch==${torch}+cu113" 110 | url=https://download.pytorch.org/whl/torch_stable.html 111 | ;; 112 | esac 113 | ;; 114 | 1.11.*) 115 | case ${cuda} in 116 | 10.2) 117 | package="torch==${torch}" 118 | # Leave it empty to use PyPI. 119 | url= 120 | ;; 121 | 11.3) 122 | package="torch==${torch}+cu113" 123 | url=https://download.pytorch.org/whl/torch_stable.html 124 | ;; 125 | 11.5) 126 | package="torch==${torch}+cu115" 127 | url=https://download.pytorch.org/whl/torch_stable.html 128 | ;; 129 | esac 130 | ;; 131 | 1.12.*) 132 | case ${cuda} in 133 | 10.2) 134 | package="torch==${torch}" 135 | # Leave it empty to use PyPI. 136 | url= 137 | ;; 138 | 11.3) 139 | package="torch==${torch}+cu113" 140 | url=https://download.pytorch.org/whl/torch_stable.html 141 | ;; 142 | 11.6) 143 | package="torch==${torch}+cu116" 144 | url=https://download.pytorch.org/whl/torch_stable.html 145 | ;; 146 | esac 147 | ;; 148 | 1.13.*) 149 | case ${cuda} in 150 | 11.6) 151 | package="torch==${torch}+cu116" 152 | url=https://download.pytorch.org/whl/torch_stable.html 153 | ;; 154 | 11.7) 155 | package="torch==${torch}" 156 | # Leave it empty to use PyPI. 157 | url= 158 | ;; 159 | esac 160 | ;; 161 | 2.0.*) 162 | case ${cuda} in 163 | 11.7) 164 | package="torch==${torch}+cu117" 165 | url=https://download.pytorch.org/whl/torch_stable.html 166 | ;; 167 | 11.8) 168 | package="torch==${torch}+cu118" 169 | url=https://download.pytorch.org/whl/torch_stable.html 170 | ;; 171 | esac 172 | ;; 173 | *) 174 | echo "Unsupported PyTorch version: ${torch}" 175 | exit 1 176 | ;; 177 | esac 178 | 179 | function retry() { 180 | $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) 181 | } 182 | 183 | if [ x"${url}" == "x" ]; then 184 | retry python3 -m pip install -q $package 185 | else 186 | retry python3 -m pip install -q $package -f $url 187 | fi 188 | 189 | rm -rfv ~/.cache/pip 190 | -------------------------------------------------------------------------------- /.github/workflows/run_tests_cpu.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Xiaomi Corp. (Wei Kang) 2 | 3 | # See ../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # refer to https://github.com/actions/starter-workflows/pull/47/files 18 | 19 | name: run-tests-cpu 20 | 21 | on: 22 | push: 23 | branches: 24 | - master 25 | paths: 26 | - '.github/workflows/run_tests_cpu.yml' 27 | - 'CMakeLists.txt' 28 | - 'cmake/**' 29 | - 'fast_rnnt/csrc/**' 30 | - 'fast_rnnt/python/**' 31 | pull_request: 32 | branches: 33 | - master 34 | paths: 35 | - '.github/workflows/run_tests_cpu.yml' 36 | - 'CMakeLists.txt' 37 | - 'cmake/**' 38 | - 'fast_rnnt/csrc/**' 39 | - 'fast_rnnt/python/**' 40 | 41 | concurrency: 42 | group: run-tests-cpu-${{ github.ref }} 43 | cancel-in-progress: true 44 | 45 | jobs: 46 | run-tests-cpu: 47 | runs-on: ${{ matrix.os }} 48 | strategy: 49 | fail-fast: false 50 | matrix: 51 | os: [ubuntu-latest, macos-latest] 52 | torch: ["1.12.1"] 53 | torchaudio: ["0.12.1"] 54 | python-version: ["3.9"] 55 | build_type: ["Release", "Debug"] 56 | 57 | steps: 58 | # refer to https://github.com/actions/checkout 59 | - uses: actions/checkout@v2 60 | 61 | - name: Display GCC version 62 | run: | 63 | gcc --version 64 | 65 | - name: Display clang version 66 | if: startsWith(matrix.os, 'macos') 67 | run: | 68 | clang --version 69 | 70 | - name: Setup Python ${{ matrix.python-version }} 71 | uses: actions/setup-python@v2 72 | with: 73 | python-version: ${{ matrix.python-version }} 74 | 75 | - name: Display Python version 76 | run: python -c "import sys; print(sys.version)" 77 | 78 | - name: Install PyTorch ${{ matrix.torch }} 79 | if: startsWith(matrix.os, 'ubuntu') 80 | shell: bash 81 | run: | 82 | python3 -m pip install -qq --upgrade pip 83 | python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html 84 | python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} -f https://download.pytorch.org/whl/cpu/torch_stable.html 85 | python3 -c "import torch; print('torch version:', torch.__version__)" 86 | 87 | python3 -m torch.utils.collect_env 88 | 89 | - name: Install PyTorch ${{ matrix.torch }} 90 | if: startsWith(matrix.os, 'macos') 91 | shell: bash 92 | run: | 93 | python3 -m pip install -qq --upgrade pip 94 | python3 -m pip install -qq torch==${{ matrix.torch }} 95 | python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} 96 | python3 -c "import torch; print('torch version:', torch.__version__)" 97 | 98 | python3 -m torch.utils.collect_env 99 | 100 | - name: Configure CMake 101 | shell: bash 102 | env: 103 | torch: ${{ matrix.torch }} 104 | run: | 105 | mkdir build 106 | cd build 107 | cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DFT_WITH_CUDA=OFF .. 108 | 109 | - name: ${{ matrix.build_type }} Build 110 | shell: bash 111 | run: | 112 | cd build 113 | make -j2 VERBOSE=1 114 | 115 | - name: Display Build Information 116 | shell: bash 117 | run: | 118 | export PYTHONPATH=$PWD/fast_rnnt/python:$PWD/build/lib:$PYTHONPATH 119 | 120 | - name: Run Tests 121 | shell: bash 122 | run: | 123 | cd build 124 | ctest --output-on-failure 125 | -------------------------------------------------------------------------------- /.github/workflows/run_tests_cuda.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Xiaomi Corp. (Wei Kang) 2 | 3 | # See ../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | name: run-tests-cuda 18 | 19 | on: 20 | push: 21 | branches: 22 | - master 23 | paths: 24 | - '.github/workflows/run_tests_cuda.yml' 25 | - 'CMakeLists.txt' 26 | - 'cmake/**' 27 | - 'fast_rnnt/csrc/**' 28 | - 'fast_rnnt/python/**' 29 | pull_request: 30 | branches: 31 | - master 32 | paths: 33 | - '.github/workflows/run_tests_cuda.yml' 34 | - 'CMakeLists.txt' 35 | - 'cmake/**' 36 | - 'fast_rnnt/csrc/**' 37 | - 'fast_rnnt/python/**' 38 | 39 | concurrency: 40 | group: run-tests-${{ github.ref }} 41 | cancel-in-progress: true 42 | 43 | jobs: 44 | run-tests: 45 | runs-on: ${{ matrix.os }} 46 | strategy: 47 | fail-fast: false 48 | matrix: 49 | os: [ubuntu-latest] 50 | cuda: ["11.6"] 51 | torch: ["1.12.1"] 52 | python-version: ["3.9"] 53 | build_type: ["Release", "Debug"] 54 | 55 | steps: 56 | # refer to https://github.com/actions/checkout 57 | - uses: actions/checkout@v2 58 | 59 | - name: Install CUDA Toolkit ${{ matrix.cuda }} 60 | env: 61 | cuda: ${{ matrix.cuda }} 62 | run: | 63 | source ./.github/scripts/install_cuda.sh 64 | echo "CUDA_HOME=${CUDA_HOME}" >> $GITHUB_ENV 65 | echo "${CUDA_HOME}/bin" >> $GITHUB_PATH 66 | echo "LD_LIBRARY_PATH=${CUDA_HOME}/lib:${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}" >> $GITHUB_ENV 67 | shell: bash 68 | 69 | - name: Display NVCC version 70 | run: | 71 | which nvcc 72 | nvcc --version 73 | 74 | - name: Display GCC version 75 | run: | 76 | gcc --version 77 | 78 | - name: Setup Python ${{ matrix.python-version }} 79 | uses: actions/setup-python@v2 80 | with: 81 | python-version: ${{ matrix.python-version }} 82 | 83 | - name: Display Python version 84 | run: python -c "import sys; print(sys.version)" 85 | 86 | - name: Install PyTorch ${{ matrix.torch }} 87 | env: 88 | cuda: ${{ matrix.cuda }} 89 | torch: ${{ matrix.torch }} 90 | shell: bash 91 | run: | 92 | python3 -m pip install -qq --upgrade pip 93 | 94 | ./.github/scripts/install_torch.sh 95 | python3 -c "import torch; print('torch version:', torch.__version__)" 96 | 97 | - name: Install git lfs 98 | run: | 99 | sudo apt-get install -y git-lfs 100 | 101 | - name: Download cudnn 8.0 102 | env: 103 | cuda: ${{ matrix.cuda }} 104 | run: | 105 | ./.github/scripts/install_cudnn.sh 106 | 107 | - name: Configure CMake 108 | shell: bash 109 | env: 110 | torch: ${{ matrix.torch }} 111 | run: | 112 | mkdir build 113 | cd build 114 | cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} .. 115 | 116 | - name: ${{ matrix.build_type }} Build 117 | shell: bash 118 | run: | 119 | echo "number of cores: $(nproc)" 120 | cd build 121 | # we cannot use -j here because of limited RAM 122 | # of the VM provided by GitHub actions 123 | make VERBOSE=1 -j2 124 | 125 | - name: Display Build Information 126 | shell: bash 127 | run: | 128 | export PYTHONPATH=$PWD/fast_rnnt/python:$PWD/build/lib:$PYTHONPATH 129 | 130 | - name: Run Tests 131 | shell: bash 132 | run: | 133 | cd build 134 | ctest --output-on-failure 135 | -------------------------------------------------------------------------------- /.github/workflows/style_check.yml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Xiaomi Corp. (Fangjun Kuang 2 | # Wei Kang) 3 | # See ../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | name: style_check 18 | 19 | on: 20 | push: 21 | branches: 22 | - master 23 | pull_request: 24 | branches: 25 | - master 26 | 27 | concurrency: 28 | group: style_check-${{ github.ref }} 29 | cancel-in-progress: true 30 | 31 | jobs: 32 | style_check: 33 | runs-on: ${{ matrix.os }} 34 | strategy: 35 | matrix: 36 | os: [ubuntu-latest] 37 | python-version: [3.8] 38 | fail-fast: false 39 | 40 | steps: 41 | - uses: actions/checkout@v2 42 | with: 43 | fetch-depth: 0 44 | 45 | - name: Setup Python ${{ matrix.python-version }} 46 | uses: actions/setup-python@v1 47 | with: 48 | python-version: ${{ matrix.python-version }} 49 | 50 | - name: Install Python dependencies 51 | run: | 52 | python3 -m pip install --upgrade pip black==22.3.0 flake8==5.0.4 click==8.1.0 53 | # Click issue fixed in https://github.com/psf/black/pull/2966 54 | 55 | - name: Run flake8 56 | shell: bash 57 | working-directory: ${{github.workspace}} 58 | run: | 59 | # stop the build if there are Python syntax errors or undefined names 60 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 61 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 62 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 \ 63 | --statistics --extend-ignore=E203,E266,E501,F401,E402,F403,F841,W503 64 | 65 | - name: Run black 66 | shell: bash 67 | working-directory: ${{github.workspace}} 68 | run: | 69 | black -l 80 --check --diff . 70 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | venv* 4 | deploy* 5 | **/__pycache__ 6 | **/build* 7 | Testing* 8 | dist/* 9 | *egg-info*/* 10 | 11 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | if("x${CMAKE_SOURCE_DIR}" STREQUAL "x${CMAKE_BINARY_DIR}") 2 | message(FATAL_ERROR "\ 3 | In-source build is not a good practice. 4 | Please use: 5 | mkdir build 6 | cd build 7 | cmake .. 8 | to build this project" 9 | ) 10 | endif() 11 | 12 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR) 13 | 14 | set(CMAKE_DISABLE_FIND_PACKAGE_MKL TRUE) 15 | set(languages CXX) 16 | set(_FT_WITH_CUDA ON) 17 | 18 | # the following settings are modified from cub/CMakeLists.txt 19 | set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") 20 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 21 | set(CMAKE_CXX_EXTENSIONS OFF) 22 | 23 | message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}") 24 | 25 | 26 | find_program(FT_HAS_NVCC nvcc) 27 | if(NOT FT_HAS_NVCC AND "$ENV{CUDACXX}" STREQUAL "") 28 | message(STATUS "No NVCC detected. Disable CUDA support") 29 | set(_FT_WITH_CUDA OFF) 30 | endif() 31 | 32 | if(APPLE OR (DEFINED FT_WITH_CUDA AND NOT FT_WITH_CUDA)) 33 | if(_FT_WITH_CUDA) 34 | message(STATUS "Disable CUDA support") 35 | set(_FT_WITH_CUDA OFF) 36 | endif() 37 | endif() 38 | 39 | if(_FT_WITH_CUDA) 40 | set(languages ${languages} CUDA) 41 | if(NOT DEFINED FT_WITH_CUDA) 42 | set(FT_WITH_CUDA ON) 43 | endif() 44 | endif() 45 | 46 | message(STATUS "Enabled languages: ${languages}") 47 | 48 | project(fast_rnnt ${languages}) 49 | 50 | set(FT_VERSION "1.2") 51 | 52 | set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel) 53 | set(DEFAULT_BUILD_TYPE "Release") 54 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "${ALLOWABLE_BUILD_TYPES}") 55 | if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) 56 | # CMAKE_CONFIGURATION_TYPES: with config type values from other generators (IDE). 57 | message(STATUS "No CMAKE_BUILD_TYPE given, default to ${DEFAULT_BUILD_TYPE}") 58 | set(CMAKE_BUILD_TYPE "${DEFAULT_BUILD_TYPE}") 59 | elseif(NOT CMAKE_BUILD_TYPE IN_LIST ALLOWABLE_BUILD_TYPES) 60 | message(FATAL_ERROR "Invalid build type: ${CMAKE_BUILD_TYPE}, \ 61 | choose one from ${ALLOWABLE_BUILD_TYPES}") 62 | endif() 63 | 64 | option(FT_BUILD_TESTS "Whether to build tests or not" ON) 65 | option(BUILD_SHARED_LIBS "Whether to build shared libs" ON) 66 | 67 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 68 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib") 69 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin") 70 | 71 | set(CMAKE_SKIP_BUILD_RPATH FALSE) 72 | set(BUILD_RPATH_USE_ORIGIN TRUE) 73 | set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE) 74 | set(CMAKE_INSTALL_RPATH "$ORIGIN") 75 | set(CMAKE_BUILD_RPATH "$ORIGIN") 76 | 77 | if(FT_WITH_CUDA) 78 | add_definitions(-DFT_WITH_CUDA) 79 | # Force CUDA C++ standard to be the same as the C++ standard used. 80 | # 81 | # Now, CMake is unaligned with reality on standard versions: https://gitlab.kitware.com/cmake/cmake/issues/18597 82 | # which means that using standard CMake methods, it's impossible to actually sync the CXX and CUDA versions for pre-11 83 | # versions of C++; CUDA accepts 98 but translates that to 03, while CXX doesn't accept 03 (and doesn't translate that to 03). 84 | # In case this gives You, dear user, any trouble, please escalate the above CMake bug, so we can support reality properly. 85 | if(DEFINED CMAKE_CUDA_STANDARD) 86 | message(WARNING "You've set CMAKE_CUDA_STANDARD; please note that this variable is ignored, and CMAKE_CXX_STANDARD" 87 | " is used as the C++ standard version for both C++ and CUDA.") 88 | endif() 89 | 90 | 91 | unset(CMAKE_CUDA_STANDARD CACHE) 92 | set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD}) 93 | 94 | include(cmake/select_compute_arch.cmake) 95 | cuda_select_nvcc_arch_flags(FT_COMPUTE_ARCH_FLAGS) 96 | message(STATUS "FT_COMPUTE_ARCH_FLAGS: ${FT_COMPUTE_ARCH_FLAGS}") 97 | 98 | # set(OT_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72) 99 | # message(WARNING "arch 62/72 are not supported for now") 100 | 101 | # see https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/ 102 | # https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/ 103 | set(FT_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75) 104 | if(CUDA_VERSION VERSION_GREATER "11.0") 105 | list(APPEND FT_COMPUTE_ARCH_CANDIDATES 80 86) 106 | endif() 107 | message(STATUS "FT_COMPUTE_ARCH_CANDIDATES ${FT_COMPUTE_ARCH_CANDIDATES}") 108 | 109 | set(FT_COMPUTE_ARCHS) 110 | 111 | foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCH_CANDIDATES) 112 | if("${FT_COMPUTE_ARCH_FLAGS}" MATCHES ${COMPUTE_ARCH}) 113 | message(STATUS "Adding arch ${COMPUTE_ARCH}") 114 | list(APPEND FT_COMPUTE_ARCHS ${COMPUTE_ARCH}) 115 | else() 116 | message(STATUS "Skipping arch ${COMPUTE_ARCH}") 117 | endif() 118 | endforeach() 119 | 120 | if(NOT FT_COMPUTE_ARCHS) 121 | set(FT_COMPUTE_ARCHS ${FT_COMPUTE_ARCH_CANDIDATES}) 122 | endif() 123 | 124 | message(STATUS "FT_COMPUTE_ARCHS: ${FT_COMPUTE_ARCHS}") 125 | 126 | foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCHS) 127 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda -gencode arch=compute_${COMPUTE_ARCH},code=sm_${COMPUTE_ARCH}") 128 | set(CMAKE_CUDA_ARCHITECTURES "${COMPUTE_ARCH}-real;${COMPUTE_ARCH}-virtual;${CMAKE_CUDA_ARCHITECTURES}") 129 | endforeach() 130 | endif() 131 | 132 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) 133 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) 134 | 135 | include(pybind11) 136 | include(torch) 137 | 138 | if(FT_BUILD_TESTS) 139 | enable_testing() 140 | include(googletest) 141 | endif() 142 | 143 | add_subdirectory(fast_rnnt) 144 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Legal Notices 3 | 4 | NOTE (this is not from the Apache License): The copyright model is that 5 | authors (or their employers, if noted in individual files) own their 6 | individual contributions. The authors' contributions can be discerned 7 | from the git history. 8 | 9 | ------------------------------------------------------------------------- 10 | 11 | Apache License 12 | Version 2.0, January 2004 13 | http://www.apache.org/licenses/ 14 | 15 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 16 | 17 | 1. Definitions. 18 | 19 | "License" shall mean the terms and conditions for use, reproduction, 20 | and distribution as defined by Sections 1 through 9 of this document. 21 | 22 | "Licensor" shall mean the copyright owner or entity authorized by 23 | the copyright owner that is granting the License. 24 | 25 | "Legal Entity" shall mean the union of the acting entity and all 26 | other entities that control, are controlled by, or are under common 27 | control with that entity. For the purposes of this definition, 28 | "control" means (i) the power, direct or indirect, to cause the 29 | direction or management of such entity, whether by contract or 30 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 31 | outstanding shares, or (iii) beneficial ownership of such entity. 32 | 33 | "You" (or "Your") shall mean an individual or Legal Entity 34 | exercising permissions granted by this License. 35 | 36 | "Source" form shall mean the preferred form for making modifications, 37 | including but not limited to software source code, documentation 38 | source, and configuration files. 39 | 40 | "Object" form shall mean any form resulting from mechanical 41 | transformation or translation of a Source form, including but 42 | not limited to compiled object code, generated documentation, 43 | and conversions to other media types. 44 | 45 | "Work" shall mean the work of authorship, whether in Source or 46 | Object form, made available under the License, as indicated by a 47 | copyright notice that is included in or attached to the work 48 | (an example is provided in the Appendix below). 49 | 50 | "Derivative Works" shall mean any work, whether in Source or Object 51 | form, that is based on (or derived from) the Work and for which the 52 | editorial revisions, annotations, elaborations, or other modifications 53 | represent, as a whole, an original work of authorship. For the purposes 54 | of this License, Derivative Works shall not include works that remain 55 | separable from, or merely link (or bind by name) to the interfaces of, 56 | the Work and Derivative Works thereof. 57 | 58 | "Contribution" shall mean any work of authorship, including 59 | the original version of the Work and any modifications or additions 60 | to that Work or Derivative Works thereof, that is intentionally 61 | submitted to Licensor for inclusion in the Work by the copyright owner 62 | or by an individual or Legal Entity authorized to submit on behalf of 63 | the copyright owner. For the purposes of this definition, "submitted" 64 | means any form of electronic, verbal, or written communication sent 65 | to the Licensor or its representatives, including but not limited to 66 | communication on electronic mailing lists, source code control systems, 67 | and issue tracking systems that are managed by, or on behalf of, the 68 | Licensor for the purpose of discussing and improving the Work, but 69 | excluding communication that is conspicuously marked or otherwise 70 | designated in writing by the copyright owner as "Not a Contribution." 71 | 72 | "Contributor" shall mean Licensor and any individual or Legal Entity 73 | on behalf of whom a Contribution has been received by Licensor and 74 | subsequently incorporated within the Work. 75 | 76 | 2. Grant of Copyright License. Subject to the terms and conditions of 77 | this License, each Contributor hereby grants to You a perpetual, 78 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 79 | copyright license to reproduce, prepare Derivative Works of, 80 | publicly display, publicly perform, sublicense, and distribute the 81 | Work and such Derivative Works in Source or Object form. 82 | 83 | 3. Grant of Patent License. Subject to the terms and conditions of 84 | this License, each Contributor hereby grants to You a perpetual, 85 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 86 | (except as stated in this section) patent license to make, have made, 87 | use, offer to sell, sell, import, and otherwise transfer the Work, 88 | where such license applies only to those patent claims licensable 89 | by such Contributor that are necessarily infringed by their 90 | Contribution(s) alone or by combination of their Contribution(s) 91 | with the Work to which such Contribution(s) was submitted. If You 92 | institute patent litigation against any entity (including a 93 | cross-claim or counterclaim in a lawsuit) alleging that the Work 94 | or a Contribution incorporated within the Work constitutes direct 95 | or contributory patent infringement, then any patent licenses 96 | granted to You under this License for that Work shall terminate 97 | as of the date such litigation is filed. 98 | 99 | 4. Redistribution. You may reproduce and distribute copies of the 100 | Work or Derivative Works thereof in any medium, with or without 101 | modifications, and in Source or Object form, provided that You 102 | meet the following conditions: 103 | 104 | (a) You must give any other recipients of the Work or 105 | Derivative Works a copy of this License; and 106 | 107 | (b) You must cause any modified files to carry prominent notices 108 | stating that You changed the files; and 109 | 110 | (c) You must retain, in the Source form of any Derivative Works 111 | that You distribute, all copyright, patent, trademark, and 112 | attribution notices from the Source form of the Work, 113 | excluding those notices that do not pertain to any part of 114 | the Derivative Works; and 115 | 116 | (d) If the Work includes a "NOTICE" text file as part of its 117 | distribution, then any Derivative Works that You distribute must 118 | include a readable copy of the attribution notices contained 119 | within such NOTICE file, excluding those notices that do not 120 | pertain to any part of the Derivative Works, in at least one 121 | of the following places: within a NOTICE text file distributed 122 | as part of the Derivative Works; within the Source form or 123 | documentation, if provided along with the Derivative Works; or, 124 | within a display generated by the Derivative Works, if and 125 | wherever such third-party notices normally appear. The contents 126 | of the NOTICE file are for informational purposes only and 127 | do not modify the License. You may add Your own attribution 128 | notices within Derivative Works that You distribute, alongside 129 | or as an addendum to the NOTICE text from the Work, provided 130 | that such additional attribution notices cannot be construed 131 | as modifying the License. 132 | 133 | You may add Your own copyright statement to Your modifications and 134 | may provide additional or different license terms and conditions 135 | for use, reproduction, or distribution of Your modifications, or 136 | for any such Derivative Works as a whole, provided Your use, 137 | reproduction, and distribution of the Work otherwise complies with 138 | the conditions stated in this License. 139 | 140 | 5. Submission of Contributions. Unless You explicitly state otherwise, 141 | any Contribution intentionally submitted for inclusion in the Work 142 | by You to the Licensor shall be under the terms and conditions of 143 | this License, without any additional terms or conditions. 144 | Notwithstanding the above, nothing herein shall supersede or modify 145 | the terms of any separate license agreement you may have executed 146 | with Licensor regarding such Contributions. 147 | 148 | 6. Trademarks. This License does not grant permission to use the trade 149 | names, trademarks, service marks, or product names of the Licensor, 150 | except as required for reasonable and customary use in describing the 151 | origin of the Work and reproducing the content of the NOTICE file. 152 | 153 | 7. Disclaimer of Warranty. Unless required by applicable law or 154 | agreed to in writing, Licensor provides the Work (and each 155 | Contributor provides its Contributions) on an "AS IS" BASIS, 156 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 157 | implied, including, without limitation, any warranties or conditions 158 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 159 | PARTICULAR PURPOSE. You are solely responsible for determining the 160 | appropriateness of using or redistributing the Work and assume any 161 | risks associated with Your exercise of permissions under this License. 162 | 163 | 8. Limitation of Liability. In no event and under no legal theory, 164 | whether in tort (including negligence), contract, or otherwise, 165 | unless required by applicable law (such as deliberate and grossly 166 | negligent acts) or agreed to in writing, shall any Contributor be 167 | liable to You for damages, including any direct, indirect, special, 168 | incidental, or consequential damages of any character arising as a 169 | result of this License or out of the use or inability to use the 170 | Work (including but not limited to damages for loss of goodwill, 171 | work stoppage, computer failure or malfunction, or any and all 172 | other commercial damages or losses), even if such Contributor 173 | has been advised of the possibility of such damages. 174 | 175 | 9. Accepting Warranty or Additional Liability. While redistributing 176 | the Work or Derivative Works thereof, You may choose to offer, 177 | and charge a fee for, acceptance of support, warranty, indemnity, 178 | or other liability obligations and/or rights consistent with this 179 | License. However, in accepting such obligations, You may act only 180 | on Your own behalf and on Your sole responsibility, not on behalf 181 | of any other Contributor, and only if You agree to indemnify, 182 | defend, and hold each Contributor harmless for any liability 183 | incurred by, or claims asserted against, such Contributor by reason 184 | of your accepting any such warranty or additional liability. 185 | 186 | END OF TERMS AND CONDITIONS 187 | 188 | APPENDIX: How to apply the Apache License to your work. 189 | 190 | To apply the Apache License to your work, attach the following 191 | boilerplate notice, with the fields enclosed by brackets "[]" 192 | replaced with your own identifying information. (Don't include 193 | the brackets!) The text should be enclosed in the appropriate 194 | comment syntax for the file format. We also recommend that a 195 | file or class name and description of purpose be included on the 196 | same "printed page" as the copyright notice for easier 197 | identification within third-party archives. 198 | 199 | Copyright [yyyy] [name of copyright owner] 200 | 201 | Licensed under the Apache License, Version 2.0 (the "License"); 202 | you may not use this file except in compliance with the License. 203 | You may obtain a copy of the License at 204 | 205 | http://www.apache.org/licenses/LICENSE-2.0 206 | 207 | Unless required by applicable law or agreed to in writing, software 208 | distributed under the License is distributed on an "AS IS" BASIS, 209 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 210 | See the License for the specific language governing permissions and 211 | limitations under the License. 212 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include README.md 3 | include LICENSE* 4 | include CMakeLists.txt 5 | recursive-include fast_rnnt *.* 6 | recursive-include cmake *.* 7 | global-exclude *.pyc 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | This project implements a method for faster and more memory-efficient RNN-T loss computation, called `pruned rnnt`. 3 | 4 | Note: There is also a fast RNN-T loss implementation in [k2](https://github.com/k2-fsa/k2) project, which shares the same code here. We make `fast_rnnt` a stand-alone project in case someone wants only this rnnt loss. 5 | 6 | ## How does the pruned-rnnt work ? 7 | 8 | We first obtain pruning bounds for the RNN-T recursion using a simple joiner network that is just an addition of the encoder and decoder, then we use those pruning bounds to evaluate the full, non-linear joiner network. 9 | 10 | The picture below display the gradients (obtained by `rnnt_loss_simple` with `return_grad=true`) of lattice nodes, at each time frame, only a small set of nodes have a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame. 11 | 12 | 13 | 14 | > This picture is taken from [here](https://github.com/k2-fsa/icefall/pull/251) 15 | 16 | ## Installation 17 | 18 | You can install it via `pip`: 19 | 20 | ``` 21 | pip install fast_rnnt 22 | ``` 23 | 24 | You can also install from source: 25 | 26 | ``` 27 | git clone https://github.com/danpovey/fast_rnnt.git 28 | cd fast_rnnt 29 | python setup.py install 30 | ``` 31 | 32 | To check that `fast_rnnt` was installed successfully, please run 33 | 34 | ``` 35 | python3 -c "import fast_rnnt; print(fast_rnnt.__version__)" 36 | ``` 37 | 38 | which should print the version of the installed `fast_rnnt`, e.g., `1.0`. 39 | 40 | 41 | ### How to display installation log ? 42 | 43 | Use 44 | 45 | ``` 46 | pip install --verbose fast_rnnt 47 | ``` 48 | 49 | ### How to reduce installation time ? 50 | 51 | Use 52 | 53 | ``` 54 | export FT_MAKE_ARGS="-j" 55 | pip install --verbose fast_rnnt 56 | ``` 57 | 58 | It will pass `-j` to `make`. 59 | 60 | ### Which version of PyTorch is supported ? 61 | 62 | It has been tested on PyTorch >= 1.5.0. 63 | 64 | Note: The cuda version of the Pytorch should be the same as the cuda version in your environment, 65 | or it will cause a compilation error. 66 | 67 | 68 | ### How to install a CPU version of `fast_rnnt` ? 69 | 70 | Use 71 | 72 | ``` 73 | export FT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF" 74 | export FT_MAKE_ARGS="-j" 75 | pip install --verbose fast_rnnt 76 | ``` 77 | 78 | It will pass `-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF` to `cmake`. 79 | 80 | ### Where to get help if I have problems with the installation ? 81 | 82 | Please file an issue at 83 | and describe your problem there. 84 | 85 | 86 | ## Usage 87 | 88 | ### For rnnt_loss_simple 89 | 90 | This is a simple case of the RNN-T loss, where the joiner network is just 91 | addition. 92 | 93 | Note: termination_symbol plays the role of blank in other RNN-T loss implementations, we call it termination_symbol as it terminates symbols of current frame. 94 | 95 | ```python 96 | am = torch.randn((B, T, C), dtype=torch.float32) 97 | lm = torch.randn((B, S + 1, C), dtype=torch.float32) 98 | symbols = torch.randint(0, C, (B, S)) 99 | termination_symbol = 0 100 | 101 | boundary = torch.zeros((B, 4), dtype=torch.int64) 102 | boundary[:, 2] = target_lengths 103 | boundary[:, 3] = num_frames 104 | 105 | loss = fast_rnnt.rnnt_loss_simple( 106 | lm=lm, 107 | am=am, 108 | symbols=symbols, 109 | termination_symbol=termination_symbol, 110 | boundary=boundary, 111 | reduction="sum", 112 | ) 113 | ``` 114 | 115 | ### For rnnt_loss_smoothed 116 | 117 | The same as `rnnt_loss_simple`, except that it supports `am_only` & `lm_only` smoothing 118 | that allows you to make the loss-function one of the form: 119 | 120 | lm_only_scale * lm_probs + 121 | am_only_scale * am_probs + 122 | (1-lm_only_scale-am_only_scale) * combined_probs 123 | 124 | where `lm_probs` and `am_probs` are the probabilities given the lm and acoustic model independently. 125 | 126 | ```python 127 | am = torch.randn((B, T, C), dtype=torch.float32) 128 | lm = torch.randn((B, S + 1, C), dtype=torch.float32) 129 | symbols = torch.randint(0, C, (B, S)) 130 | termination_symbol = 0 131 | 132 | boundary = torch.zeros((B, 4), dtype=torch.int64) 133 | boundary[:, 2] = target_lengths 134 | boundary[:, 3] = num_frames 135 | 136 | loss = fast_rnnt.rnnt_loss_smoothed( 137 | lm=lm, 138 | am=am, 139 | symbols=symbols, 140 | termination_symbol=termination_symbol, 141 | lm_only_scale=0.25, 142 | am_only_scale=0.0 143 | boundary=boundary, 144 | reduction="sum", 145 | ) 146 | ``` 147 | 148 | ### For rnnt_loss_pruned 149 | 150 | `rnnt_loss_pruned` can not be used alone, it needs the gradients returned by `rnnt_loss_simple/rnnt_loss_smoothed` to get pruning bounds. 151 | 152 | ```python 153 | am = torch.randn((B, T, C), dtype=torch.float32) 154 | lm = torch.randn((B, S + 1, C), dtype=torch.float32) 155 | symbols = torch.randint(0, C, (B, S)) 156 | termination_symbol = 0 157 | 158 | boundary = torch.zeros((B, 4), dtype=torch.int64) 159 | boundary[:, 2] = target_lengths 160 | boundary[:, 3] = num_frames 161 | 162 | # rnnt_loss_simple can be also rnnt_loss_smoothed 163 | simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( 164 | lm=lm, 165 | am=am, 166 | symbols=symbols, 167 | termination_symbol=termination_symbol, 168 | boundary=boundary, 169 | reduction="sum", 170 | return_grad=True, 171 | ) 172 | s_range = 5 # can be other values 173 | ranges = fast_rnnt.get_rnnt_prune_ranges( 174 | px_grad=px_grad, 175 | py_grad=py_grad, 176 | boundary=boundary, 177 | s_range=s_range, 178 | ) 179 | 180 | am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges) 181 | 182 | logits = model.joiner(am_pruned, lm_pruned) 183 | pruned_loss = fast_rnnt.rnnt_loss_pruned( 184 | logits=logits, 185 | symbols=symbols, 186 | ranges=ranges, 187 | termination_symbol=termination_symbol, 188 | boundary=boundary, 189 | reduction="sum", 190 | ) 191 | ``` 192 | 193 | You can also find recipes [here](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless) that uses `rnnt_loss_pruned` to train a model. 194 | 195 | 196 | ### For rnnt_loss 197 | 198 | The `unprund rnnt_loss` is the same as `torchaudio rnnt_loss`, it produces same output as torchaudio for the same input. 199 | 200 | ```python 201 | logits = torch.randn((B, S, T, C), dtype=torch.float32) 202 | symbols = torch.randint(0, C, (B, S)) 203 | termination_symbol = 0 204 | 205 | boundary = torch.zeros((B, 4), dtype=torch.int64) 206 | boundary[:, 2] = target_lengths 207 | boundary[:, 3] = num_frames 208 | 209 | loss = fast_rnnt.rnnt_loss( 210 | logits=logits, 211 | symbols=symbols, 212 | termination_symbol=termination_symbol, 213 | boundary=boundary, 214 | reduction="sum", 215 | ) 216 | ``` 217 | 218 | 219 | ## Benchmarking 220 | 221 | The [repo](https://github.com/csukuangfj/transducer-loss-benchmarking) compares the speed and memory usage of several transducer losses, the summary in the following table is taken from there, you can check the repository for more details. 222 | 223 | Note: As we declared above, `fast_rnnt` is also implemented in [k2](https://github.com/k2-fsa/k2) project, so `k2` and `fast_rnnt` are equivalent in the benchmarking. 224 | 225 | |Name |Average step time (us) | Peak memory usage (MB)| 226 | |--------------------|-----------------------|-----------------------| 227 | |torchaudio |601447 |12959.2 | 228 | |fast_rnnt(unpruned) |274407 |15106.5 | 229 | |fast_rnnt(pruned) |38112 |2647.8 | 230 | |optimized_transducer|567684 |10903.1 | 231 | |warprnnt_numba |229340 |13061.8 | 232 | |warp-transducer |210772 |13061.8 | 233 | -------------------------------------------------------------------------------- /cmake/Modules/FetchContent/CMakeLists.cmake.in: -------------------------------------------------------------------------------- 1 | # Distributed under the OSI-approved BSD 3-Clause License. See accompanying 2 | # file Copyright.txt or https://cmake.org/licensing for details. 3 | 4 | cmake_minimum_required(VERSION ${CMAKE_VERSION}) 5 | 6 | # We name the project and the target for the ExternalProject_Add() call 7 | # to something that will highlight to the user what we are working on if 8 | # something goes wrong and an error message is produced. 9 | 10 | project(${contentName}-populate NONE) 11 | 12 | include(ExternalProject) 13 | ExternalProject_Add(${contentName}-populate 14 | ${ARG_EXTRA} 15 | SOURCE_DIR "${ARG_SOURCE_DIR}" 16 | BINARY_DIR "${ARG_BINARY_DIR}" 17 | CONFIGURE_COMMAND "" 18 | BUILD_COMMAND "" 19 | INSTALL_COMMAND "" 20 | TEST_COMMAND "" 21 | ) 22 | -------------------------------------------------------------------------------- /cmake/Modules/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## FetchContent 3 | 4 | `FetchContent.cmake` and `FetchContent/CMakeLists.cmake.in` 5 | are copied from `cmake/3.11.0/share/cmake-3.11/Modules`. 6 | -------------------------------------------------------------------------------- /cmake/googletest.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) 2 | # See ../LICENSE for clarification regarding multiple authors 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | function(download_googltest) 17 | if(CMAKE_VERSION VERSION_LESS 3.11) 18 | # FetchContent is available since 3.11, 19 | # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules 20 | # so that it can be used in lower CMake versions. 21 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) 22 | endif() 23 | 24 | include(FetchContent) 25 | 26 | set(googletest_URL "https://github.com/google/googletest/archive/release-1.10.0.tar.gz") 27 | set(googletest_HASH "SHA256=9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb") 28 | 29 | set(BUILD_GMOCK ON CACHE BOOL "" FORCE) 30 | set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) 31 | set(gtest_disable_pthreads ON CACHE BOOL "" FORCE) 32 | set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) 33 | 34 | FetchContent_Declare(googletest 35 | URL ${googletest_URL} 36 | URL_HASH ${googletest_HASH} 37 | ) 38 | 39 | FetchContent_GetProperties(googletest) 40 | if(NOT googletest_POPULATED) 41 | message(STATUS "Downloading googletest") 42 | FetchContent_Populate(googletest) 43 | endif() 44 | message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}") 45 | message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}") 46 | 47 | if(APPLE) 48 | set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS 49 | endif() 50 | #[==[ 51 | -- Generating done 52 | Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake 53 | --help-policy CMP0042" for policy details. Use the cmake_policy command to 54 | set the policy and suppress this warning. 55 | 56 | MACOSX_RPATH is not specified for the following targets: 57 | 58 | gmock 59 | gmock_main 60 | gtest 61 | gtest_main 62 | 63 | This warning is for project developers. Use -Wno-dev to suppress it. 64 | ]==] 65 | 66 | add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) 67 | 68 | target_include_directories(gtest 69 | INTERFACE 70 | ${googletest_SOURCE_DIR}/googletest/include 71 | ${googletest_SOURCE_DIR}/googlemock/include 72 | ) 73 | endfunction() 74 | 75 | download_googltest() 76 | -------------------------------------------------------------------------------- /cmake/pybind11.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) 2 | # See ../LICENSE for clarification regarding multiple authors 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | function(download_pybind11) 17 | if(CMAKE_VERSION VERSION_LESS 3.11) 18 | list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) 19 | endif() 20 | 21 | include(FetchContent) 22 | 23 | set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz") 24 | set(pybind11_HASH "SHA256=90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571") 25 | 26 | set(double_quotes "\"") 27 | set(dollar "\$") 28 | set(semicolon "\;") 29 | if(NOT WIN32) 30 | FetchContent_Declare(pybind11 31 | URL ${pybind11_URL} 32 | URL_HASH ${pybind11_HASH} 33 | ) 34 | else() 35 | FetchContent_Declare(pybind11 36 | URL ${pybind11_URL} 37 | URL_HASH ${pybind11_HASH} 38 | ) 39 | endif() 40 | 41 | FetchContent_GetProperties(pybind11) 42 | if(NOT pybind11_POPULATED) 43 | message(STATUS "Downloading pybind11") 44 | FetchContent_Populate(pybind11) 45 | endif() 46 | message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") 47 | add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL) 48 | endfunction() 49 | 50 | download_pybind11() 51 | -------------------------------------------------------------------------------- /cmake/select_compute_arch.cmake: -------------------------------------------------------------------------------- 1 | # 2 | # This file is copied from 3 | # https://github.com/pytorch/pytorch/blob/master/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake 4 | # 5 | # 6 | # Synopsis: 7 | # CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures]) 8 | # -- Selects GPU arch flags for nvcc based on target_CUDA_architectures 9 | # target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...) 10 | # - "Auto" detects local machine GPU compute arch at runtime. 11 | # - "Common" and "All" cover common and entire subsets of architectures 12 | # ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX 13 | # NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere 14 | # NUM: Any number. Only those pairs are currently accepted by NVCC though: 15 | # 3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0 16 | # Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable} 17 | # Additionally, sets ${out_variable}_readable to the resulting numeric list 18 | # Example: 19 | # CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell) 20 | # LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS}) 21 | # 22 | # More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA 23 | # 24 | 25 | if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language 26 | if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA" 27 | AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)") 28 | set(CUDA_VERSION "${CMAKE_MATCH_1}") 29 | endif() 30 | endif() 31 | 32 | # See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list 33 | 34 | # This list will be used for CUDA_ARCH_NAME = All option 35 | set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell") 36 | 37 | # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) 38 | set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0") 39 | 40 | if(CUDA_VERSION VERSION_LESS "7.0") 41 | set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2") 42 | endif() 43 | 44 | # This list is used to filter CUDA archs when autodetecting 45 | set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0") 46 | 47 | if(CUDA_VERSION VERSION_GREATER "6.5") 48 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra") 49 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2") 50 | 51 | if(CUDA_VERSION VERSION_LESS "8.0") 52 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX") 53 | set(CUDA_LIMIT_GPU_ARCHITECTURE "6.0") 54 | endif() 55 | endif() 56 | 57 | if(CUDA_VERSION VERSION_GREATER "7.5") 58 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal") 59 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1") 60 | list(APPEND CUDA_ALL_GPU_ARCHITECTURES "6.0" "6.1" "6.2") 61 | 62 | if(CUDA_VERSION VERSION_LESS "9.0") 63 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.2+PTX") 64 | set(CUDA_LIMIT_GPU_ARCHITECTURE "7.0") 65 | endif() 66 | endif () 67 | 68 | if(CUDA_VERSION VERSION_GREATER "8.5") 69 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Volta") 70 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0") 71 | list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0" "7.2") 72 | 73 | if(CUDA_VERSION VERSION_LESS "10.0") 74 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.2+PTX") 75 | set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0") 76 | endif() 77 | endif() 78 | 79 | if(CUDA_VERSION VERSION_GREATER "9.5") 80 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Turing") 81 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5") 82 | list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.5") 83 | 84 | if(CUDA_VERSION VERSION_LESS "11.0") 85 | set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0") 86 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5+PTX") 87 | endif() 88 | endif() 89 | 90 | if(CUDA_VERSION VERSION_GREATER "10.5") 91 | list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere") 92 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0") 93 | list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0") 94 | 95 | if(CUDA_VERSION VERSION_LESS "11.1") 96 | set(CUDA_LIMIT_GPU_ARCHITECTURE "8.6") 97 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX") 98 | endif() 99 | endif() 100 | 101 | if(CUDA_VERSION VERSION_GREATER "11.0") 102 | list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX") 103 | list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6") 104 | 105 | if(CUDA_VERSION VERSION_LESS "12.0") 106 | set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") 107 | endif() 108 | endif() 109 | 110 | ################################################################################################ 111 | # A function for automatic detection of GPUs installed (if autodetection is enabled) 112 | # Usage: 113 | # CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE) 114 | # 115 | function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE) 116 | if(NOT CUDA_GPU_DETECT_OUTPUT) 117 | if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language 118 | set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu") 119 | else() 120 | set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp") 121 | endif() 122 | 123 | file(WRITE ${file} "" 124 | "#include \n" 125 | "#include \n" 126 | "int main()\n" 127 | "{\n" 128 | " int count = 0;\n" 129 | " if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n" 130 | " if (count == 0) return -1;\n" 131 | " for (int device = 0; device < count; ++device)\n" 132 | " {\n" 133 | " cudaDeviceProp prop;\n" 134 | " if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n" 135 | " std::printf(\"%d.%d \", prop.major, prop.minor);\n" 136 | " }\n" 137 | " return 0;\n" 138 | "}\n") 139 | 140 | if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language 141 | try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} 142 | RUN_OUTPUT_VARIABLE compute_capabilities) 143 | else() 144 | try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file} 145 | CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}" 146 | LINK_LIBRARIES ${CUDA_LIBRARIES} 147 | RUN_OUTPUT_VARIABLE compute_capabilities) 148 | endif() 149 | 150 | # Filter unrelated content out of the output. 151 | string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}") 152 | 153 | if(run_result EQUAL 0) 154 | string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}") 155 | set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities} 156 | CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE) 157 | endif() 158 | endif() 159 | 160 | if(NOT CUDA_GPU_DETECT_OUTPUT) 161 | message(STATUS "Automatic GPU detection failed. Building for common architectures.") 162 | set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE) 163 | else() 164 | # Filter based on CUDA version supported archs 165 | set(CUDA_GPU_DETECT_OUTPUT_FILTERED "") 166 | separate_arguments(CUDA_GPU_DETECT_OUTPUT) 167 | foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT}) 168 | if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR 169 | ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE)) 170 | list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM) 171 | string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}") 172 | else() 173 | string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}") 174 | endif() 175 | endforeach() 176 | 177 | set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE) 178 | endif() 179 | endfunction() 180 | 181 | 182 | ################################################################################################ 183 | # Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list 184 | # Usage: 185 | # SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs]) 186 | function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) 187 | set(CUDA_ARCH_LIST "${ARGN}") 188 | 189 | if("X${CUDA_ARCH_LIST}" STREQUAL "X" ) 190 | set(CUDA_ARCH_LIST "Auto") 191 | endif() 192 | 193 | set(cuda_arch_bin) 194 | set(cuda_arch_ptx) 195 | 196 | if("${CUDA_ARCH_LIST}" STREQUAL "All") 197 | set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES}) 198 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Common") 199 | set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES}) 200 | elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto") 201 | CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST) 202 | message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}") 203 | endif() 204 | 205 | # Now process the list and look for names 206 | string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") 207 | list(REMOVE_DUPLICATES CUDA_ARCH_LIST) 208 | foreach(arch_name ${CUDA_ARCH_LIST}) 209 | set(arch_bin) 210 | set(arch_ptx) 211 | set(add_ptx FALSE) 212 | # Check to see if we are compiling PTX 213 | if(arch_name MATCHES "(.*)\\+PTX$") 214 | set(add_ptx TRUE) 215 | set(arch_name ${CMAKE_MATCH_1}) 216 | endif() 217 | if(arch_name MATCHES "^([0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$") 218 | set(arch_bin ${CMAKE_MATCH_1}) 219 | set(arch_ptx ${arch_bin}) 220 | else() 221 | # Look for it in our list of known architectures 222 | if(${arch_name} STREQUAL "Kepler+Tesla") 223 | set(arch_bin 3.7) 224 | elseif(${arch_name} STREQUAL "Kepler") 225 | set(arch_bin 3.5) 226 | set(arch_ptx 3.5) 227 | elseif(${arch_name} STREQUAL "Maxwell+Tegra") 228 | set(arch_bin 5.3) 229 | elseif(${arch_name} STREQUAL "Maxwell") 230 | set(arch_bin 5.0 5.2) 231 | set(arch_ptx 5.2) 232 | elseif(${arch_name} STREQUAL "Pascal") 233 | set(arch_bin 6.0 6.1) 234 | set(arch_ptx 6.1) 235 | elseif(${arch_name} STREQUAL "Volta") 236 | set(arch_bin 7.0 7.0) 237 | set(arch_ptx 7.0) 238 | elseif(${arch_name} STREQUAL "Turing") 239 | set(arch_bin 7.5) 240 | set(arch_ptx 7.5) 241 | elseif(${arch_name} STREQUAL "Ampere") 242 | set(arch_bin 8.0) 243 | set(arch_ptx 8.0) 244 | else() 245 | message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") 246 | endif() 247 | endif() 248 | if(NOT arch_bin) 249 | message(SEND_ERROR "arch_bin wasn't set for some reason") 250 | endif() 251 | list(APPEND cuda_arch_bin ${arch_bin}) 252 | if(add_ptx) 253 | if (NOT arch_ptx) 254 | set(arch_ptx ${arch_bin}) 255 | endif() 256 | list(APPEND cuda_arch_ptx ${arch_ptx}) 257 | endif() 258 | endforeach() 259 | 260 | # remove dots and convert to lists 261 | string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}") 262 | string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}") 263 | string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}") 264 | string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}") 265 | 266 | if(cuda_arch_bin) 267 | list(REMOVE_DUPLICATES cuda_arch_bin) 268 | endif() 269 | if(cuda_arch_ptx) 270 | list(REMOVE_DUPLICATES cuda_arch_ptx) 271 | endif() 272 | 273 | set(nvcc_flags "") 274 | set(nvcc_archs_readable "") 275 | 276 | # Tell NVCC to add binaries for the specified GPUs 277 | foreach(arch ${cuda_arch_bin}) 278 | if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)") 279 | # User explicitly specified ARCH for the concrete CODE 280 | list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1}) 281 | list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1}) 282 | else() 283 | # User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE 284 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch}) 285 | list(APPEND nvcc_archs_readable sm_${arch}) 286 | endif() 287 | endforeach() 288 | 289 | # Tell NVCC to add PTX intermediate code for the specified architectures 290 | foreach(arch ${cuda_arch_ptx}) 291 | list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch}) 292 | list(APPEND nvcc_archs_readable compute_${arch}) 293 | endforeach() 294 | 295 | string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}") 296 | set(${out_variable} ${nvcc_flags} PARENT_SCOPE) 297 | set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE) 298 | endfunction() 299 | -------------------------------------------------------------------------------- /cmake/torch.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) 2 | # PYTHON_EXECUTABLE is set by pybind11.cmake 3 | message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") 4 | execute_process( 5 | COMMAND "${PYTHON_EXECUTABLE}" -c "import os; import torch; print(os.path.dirname(torch.__file__))" 6 | OUTPUT_STRIP_TRAILING_WHITESPACE 7 | OUTPUT_VARIABLE TORCH_DIR 8 | ) 9 | 10 | list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}") 11 | find_package(Torch REQUIRED) 12 | 13 | # set the global CMAKE_CXX_FLAGS so that 14 | # optimized_transducer uses the same abi flag as PyTorch 15 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 16 | if(OT_WITH_CUDA) 17 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${TORCH_CXX_FLAGS}") 18 | endif() 19 | 20 | 21 | execute_process( 22 | COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])" 23 | OUTPUT_STRIP_TRAILING_WHITESPACE 24 | OUTPUT_VARIABLE OT_TORCH_VERSION_MAJOR 25 | ) 26 | 27 | execute_process( 28 | COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])" 29 | OUTPUT_STRIP_TRAILING_WHITESPACE 30 | OUTPUT_VARIABLE OT_TORCH_VERSION_MINOR 31 | ) 32 | 33 | execute_process( 34 | COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)" 35 | OUTPUT_STRIP_TRAILING_WHITESPACE 36 | OUTPUT_VARIABLE TORCH_VERSION 37 | ) 38 | 39 | message(STATUS "PyTorch version: ${TORCH_VERSION}") 40 | 41 | if(OT_WITH_CUDA) 42 | execute_process( 43 | COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.version.cuda)" 44 | OUTPUT_STRIP_TRAILING_WHITESPACE 45 | OUTPUT_VARIABLE TORCH_CUDA_VERSION 46 | ) 47 | 48 | message(STATUS "PyTorch cuda version: ${TORCH_CUDA_VERSION}") 49 | 50 | if(NOT CUDA_VERSION VERSION_EQUAL TORCH_CUDA_VERSION) 51 | message(FATAL_ERROR 52 | "PyTorch ${TORCH_VERSION} is compiled with CUDA ${TORCH_CUDA_VERSION}.\n" 53 | "But you are using CUDA ${CUDA_VERSION} to compile optimized_transducer.\n" 54 | "Please try to use the same CUDA version for PyTorch and optimized_transducer.\n" 55 | "**You can remove this check if you are sure this will not cause " 56 | "problems**\n" 57 | ) 58 | endif() 59 | 60 | # Solve the following error for NVCC: 61 | # unknown option `-Wall` 62 | # 63 | # It contains only some -Wno-* flags, so it is OK 64 | # to set them to empty 65 | set_property(TARGET torch_cuda 66 | PROPERTY 67 | INTERFACE_COMPILE_OPTIONS "" 68 | ) 69 | set_property(TARGET torch_cpu 70 | PROPERTY 71 | INTERFACE_COMPILE_OPTIONS "" 72 | ) 73 | endif() 74 | -------------------------------------------------------------------------------- /cmake/transform.cmake: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) 2 | # See ../LICENSE for clarification regarding multiple authors 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # This function is used to copy foo.cu to foo.cc 17 | # Usage: 18 | # 19 | # transform(OUTPUT_VARIABLE output_variable_name SRCS foo.cu bar.cu) 20 | # 21 | function(transform) 22 | set(optional_args "") # there are no optional arguments 23 | set(one_value_arg OUTPUT_VARIABLE) 24 | set(multi_value_args SRCS) 25 | 26 | cmake_parse_arguments(MY "${optional_args}" "${one_value_arg}" "${multi_value_args}" ${ARGN}) 27 | foreach(src IN LISTS MY_SRCS) 28 | get_filename_component(src_name ${src} NAME_WE) 29 | get_filename_component(src_dir ${src} DIRECTORY) 30 | set(dst ${CMAKE_CURRENT_BINARY_DIR}/${src_dir}/${src_name}.cc) 31 | 32 | list(APPEND ans ${dst}) 33 | message(STATUS "Renaming ${CMAKE_CURRENT_SOURCE_DIR}/${src} to ${dst}") 34 | configure_file(${src} ${dst}) 35 | endforeach() 36 | set(${MY_OUTPUT_VARIABLE} ${ans} PARENT_SCOPE) 37 | endfunction() 38 | -------------------------------------------------------------------------------- /fast_rnnt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(csrc) 2 | add_subdirectory(python) 3 | -------------------------------------------------------------------------------- /fast_rnnt/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include_directories(${CMAKE_SOURCE_DIR}) 2 | 3 | # it is located in fast_rnnt/cmake/transform.cmake 4 | include(transform) 5 | 6 | set(srcs 7 | mutual_information_cpu.cu 8 | ) 9 | 10 | if(NOT FT_WITH_CUDA) 11 | transform(OUTPUT_VARIABLE srcs SRCS ${srcs}) 12 | else() 13 | list(APPEND srcs mutual_information_cuda.cu) 14 | endif() 15 | 16 | add_library(mutual_information_core ${srcs}) 17 | target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES}) 18 | # for 19 | target_include_directories(mutual_information_core PUBLIC ${PYTHON_INCLUDE_DIRS}) 20 | -------------------------------------------------------------------------------- /fast_rnnt/csrc/device_guard.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang, Wei Kang) 3 | * 4 | * See LICENSE for clarification regarding multiple authors 5 | * 6 | * Licensed under the Apache License, Version 2.0 (the "License"); 7 | * you may not use this file except in compliance with the License. 8 | * You may obtain a copy of the License at 9 | * 10 | * http://www.apache.org/licenses/LICENSE-2.0 11 | * 12 | * Unless required by applicable law or agreed to in writing, software 13 | * distributed under the License is distributed on an "AS IS" BASIS, 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | * See the License for the specific language governing permissions and 16 | * limitations under the License. 17 | */ 18 | 19 | #ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_ 20 | #define FAST_RNNT_CSRC_DEVICE_GUARD_H_ 21 | 22 | #include "torch/script.h" 23 | 24 | // This file is modified from 25 | // https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h 26 | namespace fast_rnnt { 27 | 28 | // DeviceGuard is an RAII class. Its sole purpose is to restore 29 | // the previous default cuda device if a CUDA context changes the 30 | // current default cuda device. 31 | class DeviceGuard { 32 | public: 33 | explicit DeviceGuard(torch::Device device) { 34 | if (device.type() == torch::kCUDA) { 35 | old_device_ = GetDevice(); 36 | new_device_ = device.index(); 37 | if (old_device_ != new_device_) 38 | SetDevice(new_device_); 39 | } 40 | // else do nothing 41 | } 42 | 43 | explicit DeviceGuard(int32_t new_device) : new_device_(new_device) { 44 | if (new_device != -1) { 45 | old_device_ = GetDevice(); 46 | if (old_device_ != new_device) 47 | SetDevice(new_device); 48 | } 49 | } 50 | 51 | ~DeviceGuard() { 52 | if (old_device_ != -1 && old_device_ != new_device_) { 53 | // restore the previous device 54 | SetDevice(old_device_); 55 | } 56 | // else it was either a CPU context or the device IDs 57 | // were the same 58 | } 59 | 60 | DeviceGuard(const DeviceGuard &) = delete; 61 | DeviceGuard &operator=(const DeviceGuard &) = delete; 62 | 63 | DeviceGuard(DeviceGuard &&) = delete; 64 | DeviceGuard &operator=(DeviceGuard &&) = delete; 65 | 66 | private: 67 | static int32_t GetDevice() { 68 | #ifdef FT_WITH_CUDA 69 | int32_t device; 70 | auto s = cudaGetDevice(&device); 71 | TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s)); 72 | return device; 73 | #else 74 | return -1; 75 | #endif 76 | } 77 | 78 | static void SetDevice(int32_t device) { 79 | #ifdef FT_WITH_CUDA 80 | auto s = cudaSetDevice(device); 81 | TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s)); 82 | #else 83 | return; 84 | #endif 85 | } 86 | 87 | private: 88 | int32_t old_device_ = -1; 89 | int32_t new_device_ = -1; 90 | }; 91 | 92 | } // namespace fast_rnnt 93 | 94 | #endif // FAST_RNNT_CSRC_DEVICE_GUARD_H_ 95 | -------------------------------------------------------------------------------- /fast_rnnt/csrc/mutual_information.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_ 22 | #define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_ 23 | 24 | #include 25 | #include 26 | #include "torch/extension.h" 27 | 28 | #ifdef __CUDA_ARCH__ 29 | #define FT_CUDA_HOSTDEV __host__ __device__ 30 | #else 31 | #define FT_CUDA_HOSTDEV 32 | #endif 33 | 34 | namespace fast_rnnt { 35 | 36 | FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) { 37 | double diff; 38 | if (x < y) { 39 | diff = x - y; 40 | x = y; 41 | } else { 42 | diff = y - x; 43 | } 44 | // diff is negative. x is now the larger one. 45 | if (diff - diff != 0) 46 | return x; // x and y are probably -inf. Return the larger one. 47 | else 48 | return x + log1p(exp(diff)); 49 | } 50 | 51 | // returns log(exp(x) + exp(y)). 52 | FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) { 53 | float diff; 54 | if (x < y) { 55 | diff = x - y; 56 | x = y; 57 | } else { 58 | diff = y - x; 59 | } 60 | // diff is negative. x is now the larger one. 61 | if (diff - diff != 0) 62 | return x; // x and y are probably -inf. Return the larger one. 63 | else 64 | return x + log1p(exp(diff)); 65 | } 66 | 67 | /* 68 | Forward of mutual_information. See also comment of `mutual_information` 69 | in ../pyhton/fast_rnnt/mutual_information.py. This is the core recursion 70 | in the sequence-to-sequence mutual information computation. 71 | 72 | @param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if 73 | modified. `modified` can be worked out from this. In not-modified case, 74 | it can be thought of as the log-odds ratio of generating the next x in 75 | the sequence, i.e. 76 | px[b][s][t] is the log of 77 | p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s), 78 | i.e. the log-prob of generating x_s given subsequences of 79 | lengths (s, t), divided by the prior probability of generating x_s. 80 | (See mutual_information.py for more info). 81 | @param py The log-odds ratio of generating the next y in the sequence. 82 | Shape [B][S + 1][T] 83 | @param p This function writes to p[b][s][t] the mutual information between 84 | sub-sequences of x and y of length s and t respectively, from the 85 | b'th sequences in the batch. Its shape is [B][S + 1][T + 1]. 86 | Concretely, this function implements the following recursion, 87 | in the case where s_begin == t_begin == 0: 88 | 89 | p[b,0,0] = 0.0 90 | if not modified: 91 | p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 92 | p[b,s,t-1] + py[b,s,t-1]) 93 | if modified: 94 | p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], 95 | p[b,s,t-1] + py[b,s,t-1]) 96 | ... treating values with any -1 index as -infinity. 97 | .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0. 98 | @param boundary If set, a tensor of shape [B][4] of type int64_t, which 99 | contains, where for each batch element b, boundary[b] 100 | equals [s_begin, t_begin, s_end, t_end] 101 | which are the beginning and end (i.e. one-past-the-last) 102 | of the x and y sequences that we should process. 103 | Alternatively, may be a tensor of shape [0][0] and type 104 | int64_t; the elements will default to (0, 0, S, T). 105 | @return A tensor `ans` of shape [B], where this function will set 106 | ans[b] = p[b][s_end][t_end], 107 | with s_end and t_end being (S, T) if `boundary` was specified, 108 | and (boundary[b][2], boundary[b][3]) otherwise. 109 | `ans` represents the mutual information between each pair of 110 | sequences (i.e. x[b] and y[b], although the sequences are not 111 | supplied directly to this function). 112 | 113 | The block-dim and grid-dim must both be 1-dimensional, and the block-dim must 114 | be at least 128. 115 | */ 116 | torch::Tensor MutualInformationCpu( 117 | torch::Tensor px, // [B][S][T+1] 118 | torch::Tensor py, // [B][S+1][T] 119 | torch::optional boundary, // [B][4], int64_t. 120 | torch::Tensor p); // [B][S+1][T+1]; an output 121 | 122 | torch::Tensor MutualInformationCuda( 123 | torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified. 124 | torch::Tensor py, // [B][S+1][T] 125 | torch::optional boundary, // [B][4], int64_t. 126 | torch::Tensor p); // [B][S+1][T+1]; an output 127 | 128 | /* 129 | backward of mutual_information; returns (grad_px, grad_py) 130 | 131 | if overwrite_ans_grad == true, this function will overwrite ans_grad with a 132 | value that, if the computation worked correctly, should be identical to or 133 | very close to the value of ans_grad at entry. This can be used 134 | to validate the correctness of this code. 135 | */ 136 | std::vector 137 | MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py, 138 | torch::optional boundary, 139 | torch::Tensor p, torch::Tensor ans_grad); 140 | 141 | std::vector MutualInformationBackwardCuda( 142 | torch::Tensor px, torch::Tensor py, torch::optional boundary, 143 | torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad); 144 | 145 | } // namespace fast_rnnt 146 | 147 | #endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_ 148 | -------------------------------------------------------------------------------- /fast_rnnt/csrc/mutual_information_cpu.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #include 22 | #include "fast_rnnt/csrc/mutual_information.h" 23 | 24 | namespace fast_rnnt { 25 | 26 | // forward of mutual_information. See """... """ comment of 27 | // `mutual_information_recursion` in 28 | // in python/fast_rnnt/mutual_information.py for documentation of the 29 | // behavior of this function. 30 | 31 | // px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out 32 | // `modified` from this. 33 | // py: of shape [B, S+1, T] 34 | // boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end) 35 | // defaulting to (0, 0, S, T). 36 | // p: of shape (S+1, T+1) 37 | // Computes the recursion: 38 | // if !modified: 39 | // p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 40 | // p[b,s,t-1] + py[b,s,t-1]) 41 | // if modified: 42 | // p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], 43 | // p[b,s,t-1] + py[b,s,t-1]) 44 | 45 | // .. treating out-of-range elements as -infinity and with special cases: 46 | // p[b, s_begin, t_begin] = 0.0 47 | // 48 | // and this function returns a tensor of shape (B,) consisting of elements 49 | // p[b, s_end, t_end] 50 | torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py, 51 | torch::optional opt_boundary, 52 | torch::Tensor p) { 53 | TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); 54 | TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); 55 | TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); 56 | TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && 57 | p.device().is_cpu(), 58 | "inputs must be CPU tensors"); 59 | 60 | bool modified = (px.size(2) == py.size(2)); 61 | 62 | auto scalar_t = px.scalar_type(); 63 | auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); 64 | 65 | const int B = px.size(0), S = px.size(1), T = py.size(2); 66 | TORCH_CHECK(px.size(2) == (modified ? T : T + 1)); 67 | TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); 68 | TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); 69 | 70 | auto boundary = opt_boundary.value_or( 71 | torch::tensor({0, 0, S, T}, 72 | torch::dtype(torch::kInt64).device(torch::kCPU)) 73 | .reshape({1, 4}) 74 | .expand({B, 4})); 75 | TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional."); 76 | TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4); 77 | TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64); 78 | 79 | torch::Tensor ans = torch::empty({B}, opts); 80 | 81 | AT_DISPATCH_FLOATING_TYPES( 82 | px.scalar_type(), "mutual_information_cpu_loop", ([&] { 83 | auto px_a = px.accessor(), 84 | py_a = py.accessor(), p_a = p.accessor(); 85 | auto boundary_a = boundary.accessor(); 86 | auto ans_a = ans.accessor(); 87 | 88 | int t_offset = (modified ? -1 : 0); 89 | for (int b = 0; b < B; b++) { 90 | int s_begin = boundary_a[b][0]; 91 | int t_begin = boundary_a[b][1]; 92 | int s_end = boundary_a[b][2]; 93 | int t_end = boundary_a[b][3]; 94 | p_a[b][s_begin][t_begin] = 0.0; 95 | if (modified) { 96 | for (int s = s_begin + 1; s <= s_end; ++s) 97 | p_a[b][s][t_begin] = -std::numeric_limits::infinity(); 98 | } else { 99 | // note: t_offset = 0 so don't need t_begin + t_offset below. 100 | for (int s = s_begin + 1; s <= s_end; ++s) 101 | p_a[b][s][t_begin] = 102 | p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin]; 103 | } 104 | for (int t = t_begin + 1; t <= t_end; ++t) 105 | p_a[b][s_begin][t] = 106 | p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1]; 107 | for (int s = s_begin + 1; s <= s_end; ++s) { 108 | scalar_t p_s_t1 = p_a[b][s][t_begin]; 109 | for (int t = t_begin + 1; t <= t_end; ++t) { 110 | // The following statement is a small optimization of: 111 | // p_a[b][s][t] = LogAdd( 112 | // p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset], 113 | // p_a[b][s][t - 1] + py_a[b][s][t - 1]); 114 | // .. which obtains p_a[b][s][t - 1] from a register. 115 | p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] + 116 | px_a[b][s - 1][t + t_offset], 117 | p_s_t1 + py_a[b][s][t - 1]); 118 | } 119 | } 120 | ans_a[b] = p_a[b][s_end][t_end]; 121 | } 122 | })); 123 | return ans; 124 | } 125 | 126 | // backward of mutual_information. Returns (px_grad, py_grad). 127 | // p corresponds to what we computed in the forward pass. 128 | std::vector 129 | MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py, 130 | torch::optional opt_boundary, 131 | torch::Tensor p, torch::Tensor ans_grad) { 132 | TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); 133 | TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); 134 | TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); 135 | TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional."); 136 | 137 | bool modified = (px.size(2) == py.size(2)); 138 | 139 | TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() && 140 | p.device().is_cpu() && ans_grad.device().is_cpu(), 141 | "inputs must be CPU tensors"); 142 | 143 | auto scalar_t = px.scalar_type(); 144 | auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); 145 | 146 | const int B = px.size(0), S = px.size(1), T = py.size(2); 147 | TORCH_CHECK(px.size(2) == (modified ? T : T + 1)); 148 | TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1); 149 | TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); 150 | 151 | auto boundary = opt_boundary.value_or( 152 | torch::tensor({0, 0, S, T}, 153 | torch::dtype(torch::kInt64).device(torch::kCPU)) 154 | .reshape({1, 4}) 155 | .expand({B, 4})); 156 | TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional."); 157 | TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4); 158 | TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64); 159 | 160 | bool has_boundary = opt_boundary.has_value(); 161 | int T1 = T + (modified ? 0 : 1); 162 | torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts), 163 | px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts) 164 | : torch::empty({B, S, T1}, opts)), 165 | py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) 166 | : torch::empty({B, S + 1, T}, opts)); 167 | 168 | AT_DISPATCH_FLOATING_TYPES( 169 | px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] { 170 | auto px_a = px.accessor(), p_a = p.accessor(), 171 | p_grad_a = p_grad.accessor(), 172 | px_grad_a = px_grad.accessor(), 173 | py_grad_a = py_grad.accessor(); 174 | 175 | auto ans_grad_a = ans_grad.accessor(); 176 | auto boundary_a = boundary.accessor(); 177 | int t_offset = (modified ? -1 : 0); 178 | 179 | for (int b = 0; b < B; b++) { 180 | int s_begin = boundary_a[b][0]; 181 | int t_begin = boundary_a[b][1]; 182 | int s_end = boundary_a[b][2]; 183 | int t_end = boundary_a[b][3]; 184 | // Backprop for: ans_a[b] = p_a[b][s_end][t_end]; 185 | p_grad_a[b][s_end][t_end] = ans_grad_a[b]; 186 | 187 | for (int s = s_end; s > s_begin; --s) { 188 | for (int t = t_end; t > t_begin; --t) { 189 | // The s,t indexes correspond to 190 | // The statement we are backpropagating here is: 191 | // p_a[b][s][t] = LogAdd( 192 | // p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset], 193 | // p_a[b][s][t - 1] + py_a[b][s][t - 1]); 194 | // .. which obtains p_a[b][s][t - 1] from a register. 195 | scalar_t term1 = p_a[b][s - 1][t + t_offset] + 196 | px_a[b][s - 1][t + t_offset], 197 | // term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not 198 | // actually needed.. 199 | total = p_a[b][s][t]; 200 | if (total - total != 0) 201 | total = 0; 202 | scalar_t term1_deriv = exp(term1 - total), 203 | term2_deriv = 1.0 - term1_deriv, 204 | grad = p_grad_a[b][s][t]; 205 | scalar_t term1_grad, term2_grad; 206 | if (term1_deriv - term1_deriv == 0.0) { 207 | term1_grad = term1_deriv * grad; 208 | term2_grad = term2_deriv * grad; 209 | } else { 210 | // could happen if total == -inf 211 | term1_grad = term2_grad = 0.0; 212 | } 213 | px_grad_a[b][s - 1][t + t_offset] = term1_grad; 214 | p_grad_a[b][s - 1][t + t_offset] = term1_grad; 215 | py_grad_a[b][s][t - 1] = term2_grad; 216 | p_grad_a[b][s][t - 1] += term2_grad; 217 | } 218 | } 219 | for (int t = t_end; t > t_begin; --t) { 220 | // Backprop for: 221 | // p_a[b][s_begin][t] = 222 | // p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1]; 223 | scalar_t this_p_grad = p_grad_a[b][s_begin][t]; 224 | p_grad_a[b][s_begin][t - 1] += this_p_grad; 225 | py_grad_a[b][s_begin][t - 1] = this_p_grad; 226 | } 227 | if (!modified) { 228 | for (int s = s_end; s > s_begin; --s) { 229 | // Backprop for: 230 | // p_a[b][s][t_begin] = 231 | // p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin]; 232 | scalar_t this_p_grad = p_grad_a[b][s][t_begin]; 233 | p_grad_a[b][s - 1][t_begin] += this_p_grad; 234 | px_grad_a[b][s - 1][t_begin] = this_p_grad; 235 | } 236 | } // else these were all -infinity's and there is nothing to 237 | // backprop. 238 | // There is no backprop for: 239 | // p_a[b][s_begin][t_begin] = 0.0; 240 | // .. but we can use this for a check, that the grad at the beginning 241 | // of the sequence is equal to the grad at the end of the sequence. 242 | if (ans_grad_a[b] != 0.0) { 243 | float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b]; 244 | if (fabs(grad_ratio - 1.0) > 0.01) { 245 | std::cout 246 | << "Warning: mutual_information backprop: expected these " 247 | << "numbers to be the same:" 248 | << static_cast(p_grad_a[b][s_begin][t_begin]) << " vs " 249 | << static_cast(ans_grad_a[b]); 250 | } 251 | } 252 | } 253 | })); 254 | 255 | return std::vector({px_grad, py_grad}); 256 | } 257 | } // namespace fast_rnnt 258 | -------------------------------------------------------------------------------- /fast_rnnt/csrc/mutual_information_cuda.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2021 Xiaomi Corporation (authors: Daniel Povey) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #include // for getCurrentCUDAStream() 22 | #include 23 | 24 | #include "fast_rnnt/csrc/mutual_information.h" 25 | 26 | namespace fast_rnnt { 27 | /* 28 | Forward of mutual_information. Each thread block computes blocks of the 'p' 29 | array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32). 30 | Thread-blocks loop over such blocks, but they might loop only once if there is 31 | not that much data to process. We sequentially launch thread groups in 32 | such a way that thread-blocks within a group do not depend on each other 33 | (see the "iter" parameter). The blocks of the 'image' (i.e. of the p matrix) 34 | that each group handles are arranged in a diagonal. 35 | 36 | Template args: 37 | scalar_t: the floating-point type, e.g. float, double; maybe eventually 38 | half, although I think we don't support LogAdd for half yet. 39 | BLOCK_SIZE: an integer power of two no greater than 32 (this limitation 40 | is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading 41 | code). 42 | Args: 43 | px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified; 44 | may be interpreted as the log-odds ratio of 45 | generating the next x in the sequence, i.e. 46 | px[b][s][t] is the log of 47 | p(x_s | x_0..x_{s-1}, y_0..y_{s-1}) / p(x_s), 48 | i.e. the log-prob of generating x_s given subsequences of lengths 49 | (s, t), divided by the prior probability of generating x_s. (See 50 | mutual_information.py for more info). 51 | py: The log-odds ratio of generating the next y in the sequence. 52 | Shape [B][S + 1][T] 53 | p: This function writes to p[b][s][t] the mutual information between 54 | sub-sequences of x and y of length s and t respectively, from the 55 | b'th sequences in the batch. Its shape is [B][S + 1][T + 1]. 56 | Concretely, this function implements the following recursion, 57 | in the case where s_begin == t_begin == 0: 58 | 59 | p[b,0,0] = 0.0 60 | if not `modified`: 61 | p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 62 | p[b,s,t-1] + py[b,s,t-1]) (eq. 0) 63 | if `modified`: 64 | p[b,s,t] = log_add(p[b,s-1,t-t] + px[b,s-1,t-1], 65 | p[b,s,t-1] + py[b,s,t-1]) (eq. 0) 66 | 67 | treating values with any -1 index as -infinity. 68 | .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0. 69 | boundary: If set, a tensor of shape [B][4] of type int64_t, which 70 | contains, where for each batch element b, boundary[b] equals 71 | [s_begin, t_begin, s_end, t_end] 72 | which are the beginning and end (i.e. one-past-the-last) of the 73 | x and y sequences that we should process. Otherwise, must be 74 | a tensor of shape [0][0] of type int64_t; the values will 75 | default to (0, 0, S, T). 76 | ans: a tensor `ans` of shape [B], where this function will set 77 | ans[b] = p[b][s_end][t_end], 78 | with s_end and t_end being (S, T) if `boundary` was specified, 79 | and (boundary[b][2], boundary[b][3]) otherwise. 80 | `ans` represents the mutual information between each pair of 81 | sequences (i.e. x[b] and y[b], although the sequences are not 82 | supplied directly to this function). 83 | 84 | The block-dim and grid-dim must both be 1-dimensional, and the block-dim must 85 | be at least 128. 86 | */ 87 | template // e.g. BLOCK_SIZE == 16 or 32. 89 | __global__ void mutual_information_kernel( 90 | // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1 91 | torch::PackedTensorAccessor32 px, 92 | torch::PackedTensorAccessor32 py, // B, S + 1, T. 93 | // B, S + 1, T + 1. This is an output. 94 | torch::PackedTensorAccessor32 p, 95 | // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T) 96 | torch::PackedTensorAccessor32 boundary, 97 | torch::PackedTensorAccessor32 ans, // [B] 98 | int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and 99 | // so on, up to num_iters - 1 where num_iters = num_s_blocks + 100 | // num_t_blocks - 1 num_s_blocks = S / BLOCK_SIZE + 1 101 | // num_t_blocks = T / BLOCK_SIZE + 1 102 | // so that each group depends on the previous group... 103 | const int B = px.size(0), S = px.size(1), T = py.size(2); 104 | const bool modified = (px.size(2) == T); 105 | const int t_offset = (modified ? -1 : 0); // see CPU code to understand. 106 | 107 | // num_s_blocks and num_t_blocks are the number of blocks we need to cover the 108 | // array of size (S, T) with blocks of this size, in the s and t directions 109 | // respectively. 110 | // You can read the following expressions as simplifications of, for example, 111 | // num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE, 112 | // i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for (T + 113 | // 1). 114 | const int num_s_blocks = S / BLOCK_SIZE + 1; 115 | //, num_t_blocks = T / BLOCK_SIZE + 1; 116 | 117 | // num_blocks_this_iter is an upper bound on the number of blocks of size 118 | // (BLOCK_SIZE by BLOCK_SIZE) that might be active on this iteration (`iter`). 119 | // These iterations start from the bottom left of the image so that on iter == 120 | // 0 we process only one block with block-index (0, 0) then on iter == 1 we 121 | // process block-indexes (1, 0) and (0, 1); and then on iter==2 we process (2, 122 | // 0), (1, 1) and (0, 2); and so on. We also will never have more than 123 | // `num_s_blocks` blocks (We'll never have more than num_t_blocks either, but 124 | // the numbering we use corresponds to s and not t, so when we hit the 125 | // num_t_blocks limit, the blocks with the lowest s indexes would just not be 126 | // active and we'll 'continue' in the loop below). 127 | int num_blocks_this_iter = min(iter + 1, num_s_blocks); 128 | 129 | // For the block with s_block_begin == 0 and t_block_begin == 0 (for 130 | // easy illustration), px_buf[s][t] will contain px[s - 1][t + t_offset]; or 131 | // -infinity. for out-of-range indexes into px. Likewise, py_buf[s][t] will 132 | // contain (py[s][t - 1]). 133 | __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], 134 | py_buf[BLOCK_SIZE][BLOCK_SIZE]; 135 | 136 | // p_buf[s][t] == p[s+s_block_begin-1][t+t_block_begin-1] 137 | // 1st row/col of p_buf correspond to the previously computed blocks (lower 138 | // `iter`), or to negative indexes into p. So, for the origin block, 139 | // p_buf[s][t] corresponds to p[s - 1][t - 1]; or -inf for 140 | // out-of-range values. 141 | __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; 142 | 143 | // boundary_buf will be used to store the b'th row of `boundary` if we have 144 | // boundary information supplied; or (0, 0, S, T) otherwise. 145 | __shared__ int64_t boundary_buf[4]; 146 | 147 | if (threadIdx.x == 0) { 148 | boundary_buf[0] = 0; 149 | boundary_buf[1] = 0; 150 | boundary_buf[2] = S; 151 | boundary_buf[3] = T; 152 | } 153 | 154 | // batch_block_iter iterates over batch elements (index b) and block 155 | // indexes in the range [0..num_blocks_this_iter-1], combining both 156 | // batch and block indexes. 157 | for (int batch_block_iter = blockIdx.x; 158 | batch_block_iter < B * num_blocks_this_iter; 159 | batch_block_iter += gridDim.x) { 160 | int block = batch_block_iter / B, 161 | b = batch_block_iter % B; // b is the index into the batch 162 | 163 | // Note: `block` can be no greater than `iter` because num_blocks_this_iter 164 | // <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and 165 | // block < num_blocks_this_iter, so iter - block >= 0. 166 | int s_block_begin = block * BLOCK_SIZE, 167 | t_block_begin = (iter - block) * BLOCK_SIZE; 168 | bool is_origin_block = (s_block_begin + t_block_begin == 0); 169 | 170 | __syncthreads(); 171 | 172 | if (threadIdx.x < 4) 173 | boundary_buf[threadIdx.x] = boundary[b][threadIdx.x]; 174 | 175 | __syncthreads(); 176 | 177 | int s_begin = boundary_buf[0], t_begin = boundary_buf[1], 178 | s_end = boundary_buf[2], t_end = boundary_buf[3]; 179 | 180 | s_block_begin += s_begin; 181 | t_block_begin += t_begin; 182 | 183 | // block_S and block_T are the actual sizes of this block (the block of `p` 184 | // that we will write), no greater than (BLOCK_SIZE, BLOCK_SIZE) but 185 | // possibly less than that if we are towards the end of the sequence. The 186 | // last element in the output matrix p that we need to write is (s_end, 187 | // t_end), i.e. the one-past-the-end index is (s_end + 1, t_end + 1). 188 | int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), 189 | block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); 190 | 191 | if (block_S <= 0 || block_T <= 0) 192 | continue; 193 | 194 | // Load px_buf and py_buf. 195 | for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { 196 | int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE, 197 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin, 198 | t_off = t + t_offset; 199 | // comparing as unsigned int makes sure the index is nonnegative. 200 | // Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px 201 | // and py values that are outside the proper boundaries that we need, but 202 | // the corresponding p_buf values will end up being 0 so this won't 203 | // matter. 204 | scalar_t this_px = -INFINITY; 205 | // Below, "&& t <= t_end" can be interpreted as: 206 | // "&& (modified ? t_off < t_end : t_off <= t_end) 207 | // [since px's last valid index is t_end - 1 if modified, else t_end. 208 | if (s > s_begin && s <= s_end && t_off >= t_begin && t <= t_end) 209 | this_px = px[b][s - 1][t_off]; 210 | 211 | px_buf[s_in_block][t_in_block] = this_px; 212 | 213 | scalar_t this_py = -INFINITY; 214 | if (t > t_begin && t <= t_end && s <= s_end) 215 | this_py = py[b][s][t - 1]; 216 | py_buf[s_in_block][t_in_block] = this_py; 217 | } 218 | 219 | // Load the 1st row and 1st column of p_buf. 220 | // This is the context from previously computed blocks of the 221 | // image. Remember: p_buf[s][t] will correspond to p[s + s_block_begin - 222 | // 1][t + t_block_begin - 1] 223 | if (threadIdx.x <= BLOCK_SIZE) { 224 | // s_in_p_buf and t_in_pbuf are simply the indexes into p_buf 225 | int s_in_p_buf = threadIdx.x, t_in_p_buf = 0, 226 | s = s_in_p_buf + s_block_begin - 1, 227 | t = t_in_p_buf + t_block_begin - 1; 228 | 229 | scalar_t this_p = -INFINITY; 230 | if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end) 231 | this_p = p[b][s][t]; 232 | p_buf[s_in_p_buf][t_in_p_buf] = this_p; 233 | } else if (static_cast(static_cast(threadIdx.x) - 64) <= 234 | static_cast(BLOCK_SIZE)) { 235 | // Another warp handles the other leg. Checking as unsigned 236 | // tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE 237 | int s_in_p_buf = 0, t_in_p_buf = static_cast(threadIdx.x) - 64, 238 | s = s_in_p_buf + s_block_begin - 1, 239 | t = t_in_p_buf + t_block_begin - 1; 240 | 241 | scalar_t this_p = -INFINITY; 242 | if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end) 243 | this_p = p[b][s][t]; 244 | p_buf[s_in_p_buf][t_in_p_buf] = this_p; 245 | } 246 | 247 | __syncthreads(); 248 | 249 | // from here to the next __syncthreads(), only the 1st warp should be active 250 | // so we shouldn't need to synchronize. (implicit within-warp 251 | // synchronization). 252 | 253 | if (threadIdx.x == 0) { 254 | // This if-statement is an optimization and modification of the loop below 255 | // for the value i == 0, i.e. inner-iteration == 0. The modification is 256 | // to set p_buf to 1.0 = exp(0.0) if this is the "origin block", 257 | // i.e. s == s_begin, t == t_begin. This corresponds to the 258 | // probability of the pair of sequences of length (0, 0). 259 | p_buf[1][1] = 260 | (is_origin_block ? 0.0 261 | : LogAdd( 262 | // px_buf has t_offset applied. 263 | p_buf[0][1 + t_offset] + px_buf[0][0], 264 | p_buf[1][0] + py_buf[0][0])); 265 | } 266 | 267 | int s = threadIdx.x; 268 | for (int i = 1; i < block_S + block_T - 1; ++i) { 269 | __syncwarp(); 270 | // i is the inner iteration, which corresponds to the (s + t) indexes of 271 | // the elements within the block that we write. So i == 0 writes 272 | // positions (s, t) == (0, 0) (but we treated i == 0 as a special case 273 | // above); i == 1 writes (0, 1) and (1, 0); i == 2 writes (0, 2), (1, 1) 274 | // and (2, 1); and so on. Note: not many threads participate in this 275 | // part, only up to BLOCK_SIZE at most. Unfortunately we couldn't figure 276 | // out a very meaningful way for more threads to do work, that looked like 277 | // it would really speed things up. 278 | // So this kernel does (2 * BLOCK_SIZE) iterations, which may seem a lot, 279 | // but we do at least do the I/O in an efficient way and keep the 280 | // inner loop simple and fast (e.g. no exp() or log()). 281 | int t = i - s; 282 | if (s < block_S && 283 | static_cast(t) < static_cast(block_T)) { 284 | // p_buf is indexed by s + 1 and t + 1 because it has an extra initial 285 | // row and column for context from previous blocks. Taking into account 286 | // the way these buffers relate to the tensors p, px and py, 287 | // can be interpreted as follows, 288 | // writing sbb for s_block_begin and tbb for t_block_begin: 289 | // 290 | // p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] + 291 | // px[s+sbb-1][t+tbb], 292 | // p[b][s+sbb][t+tbb-1] + 293 | // py[s+sbb][t+tbb-1] 294 | // 295 | // where you can see that apart from the offsets of tbb and sbb, this is 296 | // the same as the recursion defined for p in 297 | // mutual_information.py:mutual_information_recursion(); and (eq. 0) 298 | // above. 299 | 300 | // note: px_buf has t_offset applied.. 301 | p_buf[s + 1][t + 1] = LogAdd(p_buf[s][t + 1 + t_offset] + px_buf[s][t], 302 | p_buf[s + 1][t] + py_buf[s][t]); 303 | // We don't need to do __syncthreads() in this loop because all the 304 | // threads that are active are in the same warp. (However, in future, 305 | // if NVidia changes some things, we might need to sync here). 306 | } 307 | } 308 | __syncthreads(); 309 | 310 | // Write out the data to p; 311 | for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { 312 | int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE, 313 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 314 | if (s_in_block < block_S && t_in_block < block_T) { 315 | scalar_t this_p = p_buf[s_in_block + 1][t_in_block + 1]; 316 | p[b][s][t] = this_p; 317 | } 318 | } 319 | 320 | __syncthreads(); 321 | 322 | if (threadIdx.x == 0) { 323 | // Write `ans`, if this is the final (top-right) block in its sequence 324 | // Logically, the following equation corresponds to: 325 | // ans[b] = p[b][s_end][t_end] 326 | if (s_block_begin + block_S - 1 == s_end && 327 | t_block_begin + block_T - 1 == t_end) { 328 | // you could read block_S below as block_S - 1 + 1, meaning, 329 | // it's the last index in a block of size block_S, but the indexes into 330 | // p_buf have a "+ 1". Likewise for block_T. 331 | ans[b] = p_buf[block_S][block_T]; 332 | } 333 | } 334 | } 335 | } 336 | 337 | // like exp(), but returns 0 if arg is inf/nan, or if result would be 338 | // infinity or nan (note: this can happen for out-of-range elements 339 | // when setting px_buf and py_buf is block_S != BLOCK_SIZE or 340 | // block_T != BLOCK_SIZE, and it's a problem because even though 341 | // out-of-range gradients are zero, if we multiply them by infinity 342 | // we get NaN. 343 | template __forceinline__ __device__ Real safe_exp(Real x) { 344 | if (x - x != 0) 345 | return 0; 346 | else { 347 | Real ans = exp(x); 348 | if (ans - ans != 0.0) 349 | return 0; 350 | return ans; 351 | } 352 | } 353 | 354 | /* 355 | Backward of mutual_information. 356 | 357 | The forward pass is: 358 | 359 | p[b,s,t] = log_add(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset], 360 | p[b,s,t-1] + py[b,s,t-1]) (eq. 0) 361 | 362 | where t_offset = (modified ? -1 : 0) 363 | 364 | The backprop for the above, implemented in the obvious way, would be as 365 | follows (note, we define term1 and term2 with offsets in the indexes, which 366 | will be convenient later..): 367 | 368 | term1(b,s-1,t+t_offset) = 369 | exp(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset] - p[b,s,t]) (0a) 370 | term2(b,s,t-1) = exp(p[b,s,t-1] + py[b,s,t-1] - p[b,s,t]) (0b) 371 | 372 | p_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1a) 373 | px_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1b) 374 | p_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1) (1c) 375 | py_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1) (1d) 376 | 377 | Adding 1 and -t_offset to the s and t indexes of (1a) an (1b), and 378 | 1 to the t index of (1c) and (1d), the equations become: 379 | 380 | p_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t) (2a) 381 | px_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t) (2b) 382 | p_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t) (2c) 383 | py_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t) (2d) 384 | 385 | .. and replacing "+=" with "=", we can write: 386 | 387 | p_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) + (3a) 388 | p_grad[b,s,t+1] * term2(b,s,t) 389 | px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) (3b) 390 | py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t) (3c) 391 | 392 | Writing the definitions of term1 and term2 in a more convenient way: 393 | term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset]) (4a) 394 | term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1]) (4b) 395 | 396 | The backward pass will be slightly different from the forward pass in terms of 397 | how we store and index p (and p_grad), because for writing a particular block 398 | of p_grad, we need context on the top and right instead of the bottom and 399 | left. So there are offsets of 1. 400 | */ 401 | template 402 | __global__ void mutual_information_backward_kernel( 403 | torch::PackedTensorAccessor32 404 | px, // B, S, T + 1 if !modified; B, S, T if modified. 405 | torch::PackedTensorAccessor32 py, // B, S + 1, T. 406 | // B, S + 1, T + 1. Produced in forward pass. 407 | torch::PackedTensorAccessor32 p, 408 | // [B]. This is an input. 409 | torch::PackedTensorAccessor32 ans_grad, 410 | torch::PackedTensorAccessor32 411 | p_grad, // B, S + 1, T + 1 if !modified; B, S, T if modified. 412 | torch::PackedTensorAccessor32 px_grad, // B, S, T + 1. 413 | torch::PackedTensorAccessor32 py_grad, // B, S + 1, T. 414 | // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T) 415 | torch::PackedTensorAccessor32 boundary, 416 | int iter, // This kernel is sequentially called with 'iter' = num_iters 417 | // - 1, num_iters - 2, .. 0, where num_iters can be taken to 418 | // be any sufficiently large number but will actually be: 419 | // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S / 420 | // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1 421 | bool overwrite_ans_grad) { // If overwrite_ans_grad == true, this function 422 | // will overwrite ans_grad with a value which, 423 | // if everything is working correctly, should be 424 | // identical or very close to the value of 425 | // ans_grad that was passed in. 426 | const int B = px.size(0), S = px.size(1), T = py.size(2); 427 | const bool modified = (px.size(2) == T); 428 | const int neg_t_offset = (modified ? 1 : 0); 429 | 430 | // For statements that are the same as the forward pass, we are omitting some 431 | // comments. We'll focus, in the comments, on differences from the forward 432 | // pass. 433 | const int num_s_blocks = S / BLOCK_SIZE + 1, 434 | // num_t_blocks = T / BLOCK_SIZE + 1, 435 | num_blocks_this_iter = min(iter + 1, num_s_blocks); 436 | 437 | // px_buf and py_buf are used temporarily to store the px and py values, 438 | // but then modified to store the "xderiv" and "yderiv" values defined 439 | // in (eq. 5) and (eq. 6) above. For out-of-range values, we'll write 0.0 440 | // here. 441 | // Initially (before xderiv/yderiv are written): 442 | // px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin]; 443 | // py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin]. 444 | // Later (see eq. 4 and eq. 5): 445 | // px_buf[s][t] contains term1(b,ss,tt) == 446 | // exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt-t_offset]), 447 | // py_buf[s][t] contains term2(b,ss,tt) == 448 | 449 | // where ss == s + s_block_begin, tt = t + t_block_begin. 450 | // Unlike in the forward code, there is no offset of 1 in the indexes. 451 | __shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE], 452 | py_buf[BLOCK_SIZE][BLOCK_SIZE]; 453 | 454 | // p_buf is initially used to store p, and then (after we are done putting 455 | // term1 and term2 into px_buf and py_buf) it is repurposed to store 456 | // p_grad. 457 | // 458 | // Unlike in the forward pass, p_buf has the same numbering as px_buf and 459 | // py_buf, it's not offset by 1: e.g., for the origin block, p_buf[0][0] 460 | // refers to p[0][0] and not p[-1][-1]. The p_buf block is larger by 1 than 461 | // the block for px_buf and py_buf; unlike in the forward pass, we store 462 | // context on the top and right, not the bottom and left, i.e. the elements at 463 | // (one past the largest indexes in the block). 464 | // 465 | // For out-of-range elements of p_buf, we'll put zero. 466 | __shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1]; 467 | 468 | // boundary_buf will be used to store the b'th row of `boundary` if we have 469 | // boundary information supplied; or (0, 0, S, T) if not. 470 | __shared__ int64_t boundary_buf[4]; 471 | 472 | if (threadIdx.x == 0) { 473 | boundary_buf[0] = 0; 474 | boundary_buf[1] = 0; 475 | boundary_buf[2] = S; 476 | boundary_buf[3] = T; 477 | } 478 | 479 | // batch_block_iter iterates over both batch elements (index b), and block 480 | // indexes in the range [0..num_blocks_this_iter-1]. The order here 481 | // doesn't matter, since there are no interdependencies between these 482 | // blocks (they are on a diagonal). 483 | for (int batch_block_iter = blockIdx.x; 484 | batch_block_iter < B * num_blocks_this_iter; 485 | batch_block_iter += gridDim.x) { 486 | int block = batch_block_iter / B, b = batch_block_iter % B; 487 | int s_block_begin = block * BLOCK_SIZE, 488 | t_block_begin = (iter - block) * BLOCK_SIZE; 489 | 490 | if (threadIdx.x < 4) 491 | boundary_buf[threadIdx.x] = boundary[b][threadIdx.x]; 492 | __syncthreads(); 493 | 494 | int s_begin = boundary_buf[0], t_begin = boundary_buf[1], 495 | s_end = boundary_buf[2], t_end = boundary_buf[3]; 496 | s_block_begin += s_begin; 497 | t_block_begin += t_begin; 498 | 499 | // block_S and block_T are the actual sizes of this block, no greater than 500 | // (BLOCK_SIZE, BLOCK_SIZE) but possibly less than that if we are towards 501 | // the end of the sequence. 502 | // The last element of the output matrix p_grad we write is (s_end, t_end), 503 | // i.e. the one-past-the-end index of p_grad is (s_end + 1, t_end + 1). 504 | int block_S = min(BLOCK_SIZE, s_end + 1 - s_block_begin), 505 | block_T = min(BLOCK_SIZE, t_end + 1 - t_block_begin); 506 | 507 | if (block_S <= 0 || block_T <= 0) 508 | continue; 509 | 510 | // Load px_buf and py_buf. At this point we just set them to the px and py 511 | // for this block. 512 | for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { 513 | int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE, 514 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 515 | // We let px and py default to -infinity if they are out of range, which 516 | // will cause xderiv and yderiv for out-of-range values to be zero, and 517 | // cause correct behavior in edge cases (for the top and right blocks). 518 | // The issue is that p and p_grad are of larger size than px and py. 519 | scalar_t this_px = -INFINITY; 520 | if (s < s_end && t <= t_end) 521 | this_px = px[b][s][t]; 522 | px_buf[s_in_block][t_in_block] = this_px; 523 | scalar_t this_py = -INFINITY; 524 | if (s <= s_end && t < t_end) 525 | this_py = py[b][s][t]; 526 | py_buf[s_in_block][t_in_block] = this_py; 527 | } 528 | __syncthreads(); 529 | 530 | // load p. 531 | for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); 532 | i += blockDim.x) { 533 | int s_in_block = i / (BLOCK_SIZE + 1), t_in_block = i % (BLOCK_SIZE + 1), 534 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 535 | // Setting 0.0 for out-of-bounds elements of p, together with setting 536 | // -INFINITY for out-of-bounds elements of px_buf and py_buf, will 537 | // ensure that we do the right thing in top and right edge cases, 538 | // i.e. that no derivatives will be propagated from out-of-bounds points 539 | // because the corresponding xderiv and yderiv values will be zero. 540 | scalar_t this_p = 0.0; 541 | if (s <= s_end && t <= t_end) 542 | this_p = p[b][s][t]; 543 | // if this_p is -inf, replace with large finite negative value, to avoid 544 | // NaN's below. 545 | // TODO: use a value that would work correctly in half precision 546 | if (this_p < -1.0e+30) 547 | this_p = -1.0e+30; 548 | p_buf[s_in_block][t_in_block] = this_p; 549 | } 550 | __syncthreads(); 551 | 552 | // Set term1 and term2; see equations (4a) and (4b) above. 553 | for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { 554 | // We can apply this formula to the entire block even if we are processing 555 | // a partial block; we have ensured that x_buf and y_buf contain 556 | // -infinity, and p contains 0, for out-of-range elements, so we'll get 557 | // x_buf and y_buf containing 0 after applying the following formulas. 558 | int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE; 559 | // Mathematically the following is doing: 560 | // term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset]) (4a) 561 | // (with an offset on the s and t indexes) 562 | // Use safe_exp() not exp(), as we could have (-inf) - (-inf) = nan, want 563 | // any finite number in this case as derivs would be zero. 564 | // Also want -inf->zero. 565 | px_buf[s][t] = 566 | safe_exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t + neg_t_offset]); 567 | // Mathematically the following is doing: 568 | // term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1]) (4b) 569 | // (with an offset on the s and t indexes) 570 | py_buf[s][t] = safe_exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]); 571 | } 572 | 573 | __syncthreads(); 574 | 575 | // Load p_grad for the top and right elements in p_buf: i.e. for elements 576 | // p_buf[s][t] where s == block_S (exclusive-or) t == block_T. 577 | // These are the p_grad values computed by previous instances of this kernel 578 | // If this is one of the top or right blocks, some or all of the p_grad 579 | // values we'd be reading here will be out of range, and we use zeros 580 | // to ensure no gradient gets propagated from those positions. 581 | if (threadIdx.x <= block_S) { 582 | int s_in_block = threadIdx.x, t_in_block = block_T, 583 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 584 | p_buf[s_in_block][t_in_block] = 585 | (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0); 586 | } else if (static_cast(static_cast(threadIdx.x) - 64) < 587 | static_cast(block_T)) { 588 | // casting to unsigned before the comparison tests for both negative and 589 | // out-of-range values of (int)threadIdx.x - 64. 590 | int s_in_block = block_S, t_in_block = static_cast(threadIdx.x) - 64, 591 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 592 | p_buf[s_in_block][t_in_block] = 593 | (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0); 594 | } 595 | 596 | __syncthreads(); 597 | 598 | // The highest-numbered value in p_buf that we need (corresponding, 599 | // of course, to p_grad), is: 600 | // p_buf[block_S - 1][block_T - 1], 601 | // and the inner iteration number (i) on which we set this is the sum of 602 | // these indexes, i.e. (block_S - 1) + (block_T - 1). 603 | bool is_final_block = (s_block_begin + block_S == s_end + 1 && 604 | t_block_begin + block_T == t_end + 1); 605 | 606 | int first_iter = block_S + block_T - 2; 607 | if (is_final_block) { 608 | // The following statement corresponds to: 609 | // p_grad[b][s_end][t_end] = ans_grad[b] 610 | // Normally this element of p_buf would be set by the first iteration of 611 | // the loop below, so if it's set this way we have to decrement first_iter 612 | // to prevent it from being overwritten. 613 | p_buf[block_S - 1][block_T - 1] = ans_grad[b]; 614 | --first_iter; 615 | } 616 | 617 | { 618 | int s = threadIdx.x; 619 | for (int i = first_iter; i >= 0; --i) { 620 | __syncwarp(); 621 | int t = i - s; 622 | if (s < block_S && 623 | static_cast(t) < static_cast(block_T)) { 624 | // The following statement is really operating on the gradients; 625 | // it corresponds, with offsets of s_block_begin and t_block_begin 626 | // on the indexes, to equation (3a) above, i.e.: 627 | // p_grad[b,s,t] = 628 | // p_grad[b,s+1,t-t_offset] * term1(b,s,t) + (3a) 629 | // p_grad[b,s,t+1] * term2(b,s,t) 630 | p_buf[s][t] = (p_buf[s + 1][t + neg_t_offset] * px_buf[s][t] + 631 | p_buf[s][t + 1] * py_buf[s][t]); 632 | } 633 | } 634 | } 635 | 636 | __syncthreads(); 637 | 638 | // Write out p_grad, px_grad and py_grad. 639 | for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) { 640 | int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE, 641 | s = s_in_block + s_block_begin, t = t_in_block + t_block_begin; 642 | // s_end and t_end are the one-past-the-end of the (x,y) sequences, but 643 | // the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1). 644 | if (t <= t_end && s <= s_end) { 645 | p_grad[b][s][t] = p_buf[s_in_block][t_in_block]; 646 | 647 | if (s < s_end && t <= t_end - neg_t_offset) { 648 | // write px_grad, which is of shape [B][S][T + 1] if !modified, 649 | // [B][S][T] if modified. the condition "t <= t_end - neg_t_offset" 650 | // becomes "t <= t_end" if !modified, and "t <= t_end - 1" if 651 | // modified, keeping us within the bounds of px_grad. 652 | 653 | // From (eq. 3b): 654 | // px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) 655 | px_grad[b][s][t] = (p_buf[s_in_block + 1][t_in_block + neg_t_offset] * 656 | px_buf[s_in_block][t_in_block]); 657 | } 658 | if (t < t_end) { // write py_grad, which is of shape [B][S + 1][T] 659 | // from (eq. 3c): 660 | // py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t) 661 | py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] * 662 | py_buf[s_in_block][t_in_block]); 663 | } 664 | } 665 | } 666 | 667 | if (threadIdx.x == 0 && s_block_begin == s_begin && 668 | t_block_begin == t_begin && overwrite_ans_grad) 669 | ans_grad[b] = p_buf[0][0]; 670 | } 671 | } 672 | 673 | // forward of mutual_information. See """... """ comment of 674 | // `mutual_information` in mutual_information.py for documentation of the 675 | // behavior of this function. 676 | torch::Tensor MutualInformationCuda(torch::Tensor px, torch::Tensor py, 677 | torch::optional opt_boundary, 678 | torch::Tensor p) { 679 | TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); 680 | TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); 681 | TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); 682 | TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && 683 | p.device().is_cuda(), 684 | "inputs must be CUDA tensors"); 685 | 686 | auto scalar_t = px.scalar_type(); 687 | auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); 688 | 689 | const int B = px.size(0), S = px.size(1), T = py.size(2); 690 | TORCH_CHECK(px.size(2) == T || px.size(2) == T + 1); 691 | TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T); 692 | TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); 693 | 694 | auto boundary = opt_boundary.value_or( 695 | torch::tensor({0, 0, S, T}, 696 | torch::dtype(torch::kInt64).device(px.device())) 697 | .reshape({1, 4}) 698 | .expand({B, 4})); 699 | TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4); 700 | TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64); 701 | 702 | torch::Tensor ans = torch::empty({B}, opts); 703 | 704 | // num_threads and num_blocks and BLOCK_SIZE can be tuned. 705 | // (however, num_threads may not be less than 128). 706 | const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32; 707 | 708 | // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), 709 | // so dividing by BLOCK_SIZE rounding up we get e.g. 710 | // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1 711 | const int num_s_blocks = S / BLOCK_SIZE + 1, 712 | num_t_blocks = T / BLOCK_SIZE + 1, 713 | num_iters = num_s_blocks + num_t_blocks - 1; 714 | 715 | AT_DISPATCH_FLOATING_TYPES( 716 | px.scalar_type(), "mutual_information_cuda_stub", ([&] { 717 | for (int iter = 0; iter < num_iters; ++iter) { 718 | mutual_information_kernel 719 | <<>>( 720 | px.packed_accessor32(), 721 | py.packed_accessor32(), 722 | p.packed_accessor32(), 723 | boundary.packed_accessor32(), 724 | ans.packed_accessor32(), iter); 725 | } 726 | })); 727 | return ans; 728 | } 729 | 730 | // backward of mutual_information; returns (grad_px, grad_py) 731 | // If overwrite_ans_grad == true, will overwrite ans_grad with a value which 732 | // should be identical to the original ans_grad if the computation worked 733 | // as it should. 734 | std::vector 735 | MutualInformationBackwardCuda(torch::Tensor px, torch::Tensor py, 736 | torch::optional opt_boundary, 737 | torch::Tensor p, torch::Tensor ans_grad, 738 | bool overwrite_ans_grad) { 739 | TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional"); 740 | TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional."); 741 | TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional."); 742 | TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional."); 743 | 744 | TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() && 745 | p.device().is_cuda() && ans_grad.device().is_cuda() && 746 | "inputs must be CUDA tensors"); 747 | 748 | auto scalar_t = px.scalar_type(); 749 | auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device()); 750 | 751 | const int B = px.size(0), S = px.size(1), T = py.size(2); 752 | 753 | TORCH_CHECK(px.size(2) == T || 754 | px.size(2) == T + 1); // modified case || not-modified case 755 | const bool modified = (px.size(2) == T); 756 | TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1); 757 | TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1); 758 | 759 | auto boundary = opt_boundary.value_or( 760 | torch::tensor({0, 0, S, T}, 761 | torch::dtype(torch::kInt64).device(px.device())) 762 | .reshape({1, 4}) 763 | .expand({B, 4})); 764 | TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4); 765 | TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64); 766 | TORCH_CHECK(ans_grad.size(0) == B); 767 | 768 | bool has_boundary = opt_boundary.has_value(); 769 | 770 | int T1 = T + (modified ? 0 : 1); 771 | torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts), 772 | px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts) 773 | : torch::empty({B, S, T1}, opts)), 774 | py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts) 775 | : torch::empty({B, S + 1, T}, opts)); 776 | 777 | // num_threads and num_blocks and BLOCK_SIZE can be tuned. 778 | // (however, num_threads may not be less than 128). 779 | const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32; 780 | 781 | // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1), 782 | // so dividing by BLOCK_SIZE rounding up we get e.g. 783 | // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1 784 | const int num_s_blocks = S / BLOCK_SIZE + 1, 785 | num_t_blocks = T / BLOCK_SIZE + 1, 786 | num_iters = num_s_blocks + num_t_blocks - 1; 787 | 788 | AT_DISPATCH_FLOATING_TYPES( 789 | px.scalar_type(), "mutual_information_backward_stub", ([&] { 790 | for (int iter = num_iters - 1; iter >= 0; --iter) { 791 | mutual_information_backward_kernel 792 | <<>>( 793 | px.packed_accessor32(), 794 | py.packed_accessor32(), 795 | p.packed_accessor32(), 796 | ans_grad.packed_accessor32(), 797 | p_grad.packed_accessor32(), 798 | px_grad.packed_accessor32(), 799 | py_grad.packed_accessor32(), 800 | boundary.packed_accessor32(), iter, 801 | overwrite_ans_grad); 802 | } 803 | })); 804 | return std::vector({px_grad, py_grad}); 805 | } 806 | } // namespace fast_rnnt 807 | -------------------------------------------------------------------------------- /fast_rnnt/python/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(csrc) 2 | add_subdirectory(tests) 3 | -------------------------------------------------------------------------------- /fast_rnnt/python/csrc/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | include_directories(${CMAKE_SOURCE_DIR}) 2 | 3 | include(transform) 4 | 5 | # please keep the list sorted 6 | set(fast_rnnt_srcs 7 | fast_rnnt.cu 8 | mutual_information.cu 9 | ) 10 | 11 | if(NOT FT_WITH_CUDA) 12 | transform(OUTPUT_VARIABLE fast_rnnt_srcs SRCS ${fast_rnnt_srcs}) 13 | endif() 14 | 15 | 16 | pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs}) 17 | target_link_libraries(_fast_rnnt PRIVATE mutual_information_core) 18 | 19 | if(APPLE) 20 | target_link_libraries(_fast_rnnt 21 | PRIVATE 22 | ${TORCH_DIR}/lib/libtorch_python.dylib 23 | ) 24 | elseif(UNIX) 25 | target_link_libraries(_fast_rnnt 26 | PRIVATE 27 | ${PYTHON_LIBRARY} 28 | ${TORCH_DIR}/lib/libtorch_python.so 29 | ) 30 | endif() 31 | -------------------------------------------------------------------------------- /fast_rnnt/python/csrc/fast_rnnt.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #include "fast_rnnt/python/csrc/fast_rnnt.h" 22 | #include "fast_rnnt/python/csrc/mutual_information.h" 23 | 24 | PYBIND11_MODULE(_fast_rnnt, m) { 25 | m.doc() = "Python wrapper for Fast Rnnt."; 26 | 27 | fast_rnnt::PybindMutualInformation(m); 28 | } 29 | -------------------------------------------------------------------------------- /fast_rnnt/python/csrc/fast_rnnt.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #ifndef FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_ 22 | #define FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_ 23 | 24 | #include "pybind11/pybind11.h" 25 | 26 | namespace py = pybind11; 27 | 28 | #endif // FAST_RNNT_PYTHON_CSRC_FAST_RNNT_H_ 29 | -------------------------------------------------------------------------------- /fast_rnnt/python/csrc/mutual_information.cu: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #include "fast_rnnt/csrc/device_guard.h" 22 | #include "fast_rnnt/csrc/mutual_information.h" 23 | #include "fast_rnnt/python/csrc/mutual_information.h" 24 | 25 | namespace fast_rnnt { 26 | void PybindMutualInformation(py::module &m) { 27 | m.def( 28 | "mutual_information_forward", 29 | [](torch::Tensor px, torch::Tensor py, 30 | torch::optional boundary, 31 | torch::Tensor p) -> torch::Tensor { 32 | fast_rnnt::DeviceGuard guard(px.device()); 33 | if (px.device().is_cpu()) { 34 | return MutualInformationCpu(px, py, boundary, p); 35 | } else { 36 | #ifdef FT_WITH_CUDA 37 | return MutualInformationCuda(px, py, boundary, p); 38 | #else 39 | TORCH_CHECK(false, "Failed to find native CUDA module, make sure " 40 | "that you compiled the code with K2_WITH_CUDA."); 41 | return torch::Tensor(); 42 | #endif 43 | } 44 | }, 45 | py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p")); 46 | 47 | m.def( 48 | "mutual_information_backward", 49 | [](torch::Tensor px, torch::Tensor py, 50 | torch::optional boundary, torch::Tensor p, 51 | torch::Tensor ans_grad) -> std::vector { 52 | fast_rnnt::DeviceGuard guard(px.device()); 53 | if (px.device().is_cpu()) { 54 | return MutualInformationBackwardCpu(px, py, boundary, p, ans_grad); 55 | } else { 56 | #ifdef FT_WITH_CUDA 57 | return MutualInformationBackwardCuda(px, py, boundary, p, ans_grad, 58 | true); 59 | #else 60 | TORCH_CHECK(false, "Failed to find native CUDA module, make sure " 61 | "that you compiled the code with K2_WITH_CUDA."); 62 | return std::vector(); 63 | #endif 64 | } 65 | }, 66 | py::arg("px"), py::arg("py"), py::arg("boundary"), py::arg("p"), 67 | py::arg("ans_grad")); 68 | 69 | m.def("with_cuda", []() -> bool { 70 | #ifdef FT_WITH_CUDA 71 | return true; 72 | #else 73 | return false; 74 | #endif 75 | }); 76 | } 77 | } // namespace fast_rnnt 78 | -------------------------------------------------------------------------------- /fast_rnnt/python/csrc/mutual_information.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @copyright 3 | * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) 4 | * 5 | * @copyright 6 | * See LICENSE for clarification regarding multiple authors 7 | * 8 | * Licensed under the Apache License, Version 2.0 (the "License"); 9 | * you may not use this file except in compliance with the License. 10 | * You may obtain a copy of the License at 11 | * 12 | * http://www.apache.org/licenses/LICENSE-2.0 13 | * 14 | * Unless required by applicable law or agreed to in writing, software 15 | * distributed under the License is distributed on an "AS IS" BASIS, 16 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | * See the License for the specific language governing permissions and 18 | * limitations under the License. 19 | */ 20 | 21 | #ifndef FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_ 22 | #define FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_ 23 | 24 | #include "fast_rnnt/python/csrc/fast_rnnt.h" 25 | 26 | namespace fast_rnnt { 27 | 28 | void PybindMutualInformation(py::module &m); 29 | 30 | } // namespace fast_rnnt 31 | 32 | #endif // FAST_RNNT_PYTHON_CSRC_MUTUAL_INFORMATION_H_ 33 | -------------------------------------------------------------------------------- /fast_rnnt/python/fast_rnnt/__init__.py: -------------------------------------------------------------------------------- 1 | from _fast_rnnt import with_cuda 2 | 3 | from .mutual_information import mutual_information_recursion 4 | from .mutual_information import joint_mutual_information_recursion 5 | 6 | from .rnnt_loss import do_rnnt_pruning 7 | from .rnnt_loss import get_rnnt_logprobs 8 | from .rnnt_loss import get_rnnt_logprobs_joint 9 | from .rnnt_loss import get_rnnt_logprobs_pruned 10 | from .rnnt_loss import get_rnnt_logprobs_smoothed 11 | from .rnnt_loss import get_rnnt_prune_ranges 12 | from .rnnt_loss import rnnt_loss 13 | from .rnnt_loss import rnnt_loss_pruned 14 | from .rnnt_loss import rnnt_loss_simple 15 | from .rnnt_loss import rnnt_loss_smoothed 16 | -------------------------------------------------------------------------------- /fast_rnnt/python/fast_rnnt/mutual_information.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Wei Kang) 2 | # 3 | # See ../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import torch 19 | import _fast_rnnt 20 | from torch import Tensor 21 | from typing import Tuple, Optional, Sequence, Union, List 22 | 23 | 24 | class MutualInformationRecursionFunction(torch.autograd.Function): 25 | """A recursion that is useful in computing mutual information between two 26 | sequences of real vectors, but may be useful more generally in 27 | sequence-to-sequence tasks where monotonic alignment between pairs of 28 | sequences is desired. 29 | """ 30 | 31 | @staticmethod 32 | def forward( 33 | ctx, 34 | px: torch.Tensor, 35 | py: torch.Tensor, 36 | pxy_grads: List[Optional[torch.Tensor]], 37 | boundary: Optional[torch.Tensor] = None, 38 | return_grad: bool = False, 39 | ) -> torch.Tensor: 40 | """ 41 | Computing mutual information between two sequences of real vectors. 42 | Args: 43 | px: 44 | A torch.Tensor of some floating point type, with shape 45 | ``[B][S][T+1]`` where ``B`` is the batch size, ``S`` is the 46 | length of the ``x`` sequence (including representations of 47 | ``EOS`` symbols but not ``BOS`` symbols), and ``T`` is the 48 | length of the ``y`` sequence (including representations of 49 | ``EOS`` symbols but not ``BOS`` symbols). In the mutual 50 | information application, ``px[b][s][t]`` would represent the 51 | following log odds ratio; ignoring the b index on the right 52 | to make the notation more 53 | compact:: 54 | 55 | px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ] 56 | 57 | This expression also implicitly includes the log-probability of 58 | choosing to generate an ``x`` value as opposed to a ``y`` value. In 59 | practice it might be computed as ``a + b``, where ``a`` is the log 60 | probability of choosing to extend the sequence of length ``(s,t)`` 61 | with an ``x`` as opposed to a ``y`` value; and ``b`` might in 62 | practice be of the form:: 63 | 64 | log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t')) 65 | 66 | where ``N`` is the number of terms that the sum over ``t'`` 67 | included, which might include some or all of the other sequences as 68 | well as this one. 69 | 70 | Note: 71 | we don't require ``px`` and py to be contiguous, but the 72 | code assumes for optimization purposes that the ``T`` axis has 73 | stride 1. 74 | 75 | py: 76 | A torch.Tensor of the same dtype as ``px``, with shape 77 | ``[B][S+1][T]``, representing:: 78 | 79 | py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ] 80 | 81 | This function does not treat ``x`` and ``y`` differently; the only 82 | difference is that for optimization purposes we assume the last axis 83 | (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` 84 | are contiguous. 85 | 86 | pxy_grads: 87 | A List to store the return grads of ``px`` and ``py`` 88 | if return_grad == True. 89 | Remain unchanged if return_grad == False. 90 | 91 | See `this PR ` for more 92 | information about why we add this parameter. 93 | 94 | Note: 95 | the length of the list must be 2, where the first element 96 | represents the grads of ``px`` and the second one represents 97 | the grads of ``py``. 98 | 99 | boundary: 100 | If supplied, a torch.LongTensor of shape ``[B][4]``, where each 101 | row contains ``[s_begin, t_begin, s_end, t_end]``, 102 | with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T`` 103 | (this implies that empty sequences are allowed). 104 | If not supplied, the values ``[0, 0, S, T]`` will be assumed. 105 | These are the beginning and one-past-the-last positions in the 106 | ``x`` and ``y`` sequences respectively, and can be used if not 107 | all sequences are 108 | of the same length. 109 | 110 | return_grad: 111 | Whether to return grads of ``px`` and ``py``, this grad standing 112 | for the occupation probability is the output of the backward with a 113 | ``fake gradient`` the ``fake gradient`` is the same as the gradient 114 | you'd get if you did 115 | ``torch.autograd.grad((scores.sum()), [px, py])``. 116 | This is useful to implement the pruned version of rnnt loss. 117 | 118 | Returns: 119 | Returns a torch.Tensor of shape ``[B]``, containing the log of 120 | the mutual information between the b'th pair of sequences. This is 121 | defined by the following recursion on ``p[b,s,t]`` (where ``p`` 122 | is of shape ``[B,S+1,T+1]``), representing a mutual information 123 | between sub-sequences of lengths ``s`` and ``t``:: 124 | 125 | p[b,0,0] = 0.0 126 | p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 127 | p[b,s,t-1] + py[b,s,t-1]) 128 | (if s > 0 or t > 0) 129 | 130 | where we handle edge cases by treating quantities with negative 131 | indexes as **-infinity**. The extension to cases where the 132 | boundaries are specified should be obvious; it just works on 133 | shorter sequences with offsets into ``px`` and ``py``. 134 | """ 135 | (B, S, T1) = px.shape 136 | T = py.shape[-1] 137 | assert T1 in [T, T + 1] 138 | assert py.shape == (B, S + 1, T) 139 | if boundary is not None: 140 | assert boundary.shape == (B, 4) 141 | 142 | # p is a tensor of shape (B, S + 1, T + 1) were p[s][t] is the 143 | # the mutual information of the pair of subsequences of x and y that 144 | # are of length s and t respectively. p[0][0] will be 0.0 and p[S][T] 145 | # is the mutual information of the entire pair of sequences, 146 | # i.e. of lengths S and T respectively. 147 | # It is computed as follows (in C++ and CUDA): 148 | # p[b,0,0] = 0.0 149 | # p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 150 | # p[b,s,t-1] + py[b,s,t-1]) 151 | # if s > 0 or t > 0, 152 | # treating values with any -1 index as -infinity. 153 | # .. if `boundary` is set, we start fom p[b,s_begin,t_begin]=0.0. 154 | 155 | p = torch.empty(B, S + 1, T + 1, device=px.device, dtype=px.dtype) 156 | 157 | ans = _fast_rnnt.mutual_information_forward(px, py, boundary, p) 158 | 159 | px_grad, py_grad = None, None 160 | if return_grad or px.requires_grad or py.requires_grad: 161 | ans_grad = torch.ones(B, device=px.device, dtype=px.dtype) 162 | (px_grad, py_grad) = _fast_rnnt.mutual_information_backward( 163 | px, py, boundary, p, ans_grad 164 | ) 165 | ctx.save_for_backward(px_grad, py_grad) 166 | assert len(pxy_grads) == 2 167 | pxy_grads[0] = px_grad 168 | pxy_grads[1] = py_grad 169 | 170 | return ans 171 | 172 | @staticmethod 173 | def backward( 174 | ctx, ans_grad: Tensor 175 | ) -> Tuple[torch.Tensor, torch.Tensor, None, None, None]: 176 | (px_grad, py_grad) = ctx.saved_tensors 177 | (B,) = ans_grad.shape 178 | ans_grad = ans_grad.reshape(B, 1, 1) # (B, 1, 1) 179 | px_grad *= ans_grad 180 | py_grad *= ans_grad 181 | return (px_grad, py_grad, None, None, None) 182 | 183 | 184 | def mutual_information_recursion( 185 | px: Tensor, 186 | py: Tensor, 187 | boundary: Optional[Tensor] = None, 188 | return_grad: bool = False, 189 | ) -> Union[Tuple[Tensor, Tuple[Tensor, Tensor]], Tensor]: 190 | """A recursion that is useful in computing mutual information between two 191 | sequences of real vectors, but may be useful more generally in 192 | sequence-to-sequence tasks where monotonic alignment between pairs of 193 | sequences is desired. The definitions of the arguments are definitions that 194 | would be used when computing this type of mutual information, but you can 195 | also view them as arbitrary quantities and just make use of the formula 196 | computed by this function. 197 | 198 | Args: 199 | px: 200 | A torch.Tensor of some floating point type, with shape ``[B][S][T+1]``, 201 | where ``B`` is the batch size, ``S`` is the length of the ``x`` sequence 202 | (including representations of ``EOS`` symbols but not ``BOS`` symbols), 203 | and ``T`` is the length of the ``y`` sequence (including representations 204 | of ``EOS`` symbols but not ``BOS`` symbols). In the mutual information 205 | application, ``px[b][s][t]`` would represent the following log odds 206 | ratio; ignoring the b index on the right to make the notation more 207 | compact:: 208 | 209 | px[b][s][t] = log [ p(x_s | x_{0..s-1}, y_{0..t-1}) / p(x_s) ] 210 | 211 | This expression also implicitly includes the log-probability of 212 | choosing to generate an ``x`` value as opposed to a ``y`` value. In 213 | practice it might be computed as ``a + b``, where ``a`` is the log 214 | probability of choosing to extend the sequence of length ``(s,t)`` 215 | with an ``x`` as opposed to a ``y`` value; and ``b`` might in practice 216 | be of the form:: 217 | 218 | log(N exp f(x_s, y_{t-1}) / sum_t' exp f(x_s, y_t')) 219 | 220 | where ``N`` is the number of terms that the sum over ``t'`` included, 221 | which might include some or all of the other sequences as well as this 222 | one. 223 | 224 | Note: 225 | we don't require ``px`` and py to be contiguous, but the 226 | code assumes for optimization purposes that the ``T`` axis has 227 | stride 1. 228 | 229 | py: 230 | A torch.Tensor of the same dtype as ``px``, with shape ``[B][S+1][T]``, 231 | representing:: 232 | 233 | py[b][s][t] = log [ p(y_t | x_{0..s-1}, y_{0..t-1}) / p(y_t) ] 234 | 235 | This function does not treat ``x`` and ``y`` differently; the only 236 | difference is that for optimization purposes we assume the last axis 237 | (the ``t`` axis) has stride of 1; this is true if ``px`` and ``py`` are 238 | contiguous. 239 | 240 | boundary: 241 | If supplied, a torch.LongTensor of shape ``[B][4]``, where each 242 | row contains ``[s_begin, t_begin, s_end, t_end]``, 243 | with ``0 <= s_begin <= s_end < S`` and ``0 <= t_begin <= t_end < T`` 244 | (this implies that empty sequences are allowed). 245 | If not supplied, the values ``[0, 0, S, T]`` will be assumed. 246 | These are the beginning and one-past-the-last positions in the ``x`` and 247 | ``y`` sequences respectively, and can be used if not all sequences are 248 | of the same length. 249 | 250 | return_grad: 251 | Whether to return grads of ``px`` and ``py``, this grad standing for the 252 | occupation probability is the output of the backward with a 253 | ``fake gradient`` the ``fake gradient`` is the same as the gradient 254 | you'd get if you did ``torch.autograd.grad((scores.sum()), [px, py])``. 255 | This is useful to implement the pruned version of rnnt loss. 256 | 257 | Returns: 258 | Returns a torch.Tensor of shape ``[B]``, containing the log of the mutual 259 | information between the b'th pair of sequences. This is defined by 260 | the following recursion on ``p[b,s,t]`` (where ``p`` is of shape 261 | ``[B,S+1,T+1]``), representing a mutual information between sub-sequences 262 | of lengths ``s`` and ``t``:: 263 | 264 | p[b,0,0] = 0.0 265 | if !modified: 266 | p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t], 267 | p[b,s,t-1] + py[b,s,t-1]) 268 | if modified: 269 | p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1], 270 | p[b,s,t-1] + py[b,s,t-1]) 271 | 272 | where we handle edge cases by treating quantities with negative indexes 273 | as **-infinity**. The extension to cases where the boundaries are 274 | specified should be obvious; it just works on shorter sequences with 275 | offsets into ``px`` and ``py``. 276 | """ 277 | assert px.ndim == 3 278 | B, S, T1 = px.shape 279 | T = py.shape[-1] 280 | assert px.shape[-1] in [T, T + 1] # if T, then "modified". 281 | assert py.shape == (B, S + 1, T) 282 | assert px.dtype == py.dtype 283 | if boundary is not None: 284 | assert boundary.dtype == torch.int64 285 | assert boundary.shape == (B, 4) 286 | for s_begin, t_begin, s_end, t_end in boundary.tolist(): 287 | assert 0 <= s_begin <= s_end <= S 288 | assert 0 <= t_begin <= t_end <= T 289 | 290 | # The following statements are for efficiency 291 | px, py = px.contiguous(), py.contiguous() 292 | 293 | pxy_grads = [None, None] 294 | scores = MutualInformationRecursionFunction.apply( 295 | px, py, pxy_grads, boundary, return_grad 296 | ) 297 | px_grad, py_grad = pxy_grads 298 | return (scores, (px_grad, py_grad)) if return_grad else scores 299 | 300 | 301 | def _inner_product(a: Tensor, b: Tensor) -> Tensor: 302 | """ 303 | Does inner product on the last dimension, with expected broadcasting, 304 | i.e. equivalent to (a * b).sum(dim=-1) 305 | without creating a large temporary. 306 | """ 307 | assert a.shape[-1] == b.shape[-1] # The last dim must be equal 308 | a = a.unsqueeze(-2) # (..., 1, K) 309 | b = b.unsqueeze(-1) # (..., K, 1) 310 | c = torch.matmul(a, b) # (..., 1, 1) 311 | return c.squeeze(-1).squeeze(-1) 312 | 313 | 314 | def joint_mutual_information_recursion( 315 | px: Sequence[Tensor], 316 | py: Sequence[Tensor], 317 | boundary: Optional[Tensor] = None, 318 | ) -> Sequence[Tensor]: 319 | """A recursion that is useful for modifications of RNN-T and similar loss 320 | functions, where the recursion probabilities have a number of terms and you 321 | want them reported separately. See mutual_information_recursion() for more 322 | documentation of the basic aspects of this. 323 | 324 | Args: 325 | px: 326 | a sequence of Tensors, each of the same shape [B][S][T+1] 327 | py: 328 | a sequence of Tensor, each of the same shape [B][S+1][T], 329 | the sequence must be the same length as px. 330 | boundary: 331 | optionally, a LongTensor of shape [B][4] containing rows 332 | [s_begin, t_begin, s_end, t_end], with 0 <= s_begin <= s_end < S 333 | and 0 <= t_begin <= t_end < T, defaulting to [0, 0, S, T]. 334 | These are the beginning and one-past-the-last positions in the x 335 | and y sequences respectively, and can be used if not all 336 | sequences are of the same length. 337 | Returns: 338 | a Tensor of shape (len(px), B), 339 | whose sum over dim 0 is the total log-prob of the recursion mentioned 340 | below, per sequence. The first element of the sequence of length len(px) 341 | is "special", in that it has an offset term reflecting the difference 342 | between sum-of-log and log-of-sum; for more interpretable loss values, 343 | the "main" part of your loss function should be first. 344 | 345 | The recursion below applies if boundary == None, when it defaults 346 | to (0, 0, S, T); where px_sum, py_sum are the sums of the elements of px 347 | and py:: 348 | 349 | p = tensor of shape (B, S+1, T+1), containing -infinity 350 | p[b,0,0] = 0.0 351 | # do the following in loop over s and t: 352 | p[b,s,t] = log_add(p[b,s-1,t] + px_sum[b,s-1,t], 353 | p[b,s,t-1] + py_sum[b,s,t-1]) 354 | (if s > 0 or t > 0) 355 | return b[:][S][T] 356 | 357 | This function lets you implement the above recursion efficiently, except 358 | that it gives you a breakdown of the contribution from all the elements of 359 | px and py separately. As noted above, the first element of the 360 | sequence is "special". 361 | """ 362 | N = len(px) 363 | assert len(py) == N and N > 0 364 | B, S, T1 = px[0].shape 365 | T = py[0].shape[2] 366 | assert T1 in [T, T + 1] # T if modified... 367 | assert py[0].shape == (B, S + 1, T) 368 | assert px[0].dtype == py[0].dtype 369 | 370 | px_cat = torch.stack( 371 | px, dim=0 372 | ) # (N, B, S, T+1) if !modified,(N, B, S, T) if modified. 373 | py_cat = torch.stack(py, dim=0) # (N, B, S+1, T) 374 | px_tot = px_cat.sum(dim=0) # (B, S, T+1) 375 | py_tot = py_cat.sum(dim=0) # (B, S+1, T) 376 | 377 | if boundary is not None: 378 | assert boundary.dtype == torch.int64 379 | assert boundary.shape == (B, 4) 380 | for s_begin, t_begin, s_end, t_end in boundary.tolist(): 381 | assert 0 <= s_begin <= s_end <= S 382 | assert 0 <= t_begin <= t_end <= T 383 | 384 | # The following statements are for efficiency 385 | px_tot, py_tot = px_tot.contiguous(), py_tot.contiguous() 386 | 387 | assert px_tot.ndim == 3 388 | assert py_tot.ndim == 3 389 | 390 | p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype) 391 | 392 | # note, tot_probs is without grad. 393 | tot_probs = _fast_rnnt.mutual_information_forward( 394 | px_tot, py_tot, boundary, p 395 | ) 396 | 397 | # this is a kind of "fake gradient" that we use, in effect to compute 398 | # occupation probabilities. The backprop will work regardless of the 399 | # actual derivative w.r.t. the total probs. 400 | ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype) 401 | 402 | (px_grad, py_grad) = _fast_rnnt.mutual_information_backward( 403 | px_tot, py_tot, boundary, p, ans_grad 404 | ) 405 | 406 | px_grad = px_grad.reshape(1, B, -1) 407 | py_grad = py_grad.reshape(1, B, -1) 408 | px_cat = px_cat.reshape(N, B, -1) 409 | py_cat = py_cat.reshape(N, B, -1) 410 | # get rid of -inf, would generate nan on product with 0 411 | px_cat = px_cat.clamp(min=torch.finfo(px_cat.dtype).min) 412 | py_cat = py_cat.clamp(min=torch.finfo(py_cat.dtype).min) 413 | 414 | x_prods = _inner_product(px_grad, px_cat) # (N, B) 415 | y_prods = _inner_product(py_grad, py_cat) # (N, B) 416 | 417 | # If all the occupation counts were exactly 1.0 (i.e. no partial counts), 418 | # "prods" should be equal to "tot_probs"; however, in general, "tot_probs" 419 | # will be more positive due to the difference between log-of-sum and 420 | # sum-of-log 421 | prods = x_prods + y_prods # (N, B) 422 | with torch.no_grad(): 423 | offset = tot_probs - prods.sum(dim=0) # (B,) 424 | prods[0] += offset 425 | return prods # (N, B) 426 | -------------------------------------------------------------------------------- /fast_rnnt/python/tests/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | function(fast_rnnt_add_py_test source) 2 | get_filename_component(name ${source} NAME_WE) 3 | set(name "${name}_py") 4 | 5 | add_test(NAME ${name} 6 | COMMAND 7 | "${PYTHON_EXECUTABLE}" 8 | "${CMAKE_CURRENT_SOURCE_DIR}/${source}" 9 | ) 10 | 11 | get_filename_component(fast_rnnt_path ${CMAKE_CURRENT_LIST_DIR} DIRECTORY) 12 | 13 | set_property(TEST ${name} 14 | PROPERTY ENVIRONMENT "PYTHONPATH=${fast_rnnt_path}:$:$ENV{PYTHONPATH}" 15 | ) 16 | endfunction() 17 | 18 | # please sort the files in alphabetic order 19 | set(py_test_files 20 | mutual_information_test.py 21 | rnnt_loss_test.py 22 | ) 23 | 24 | foreach(source IN LISTS py_test_files) 25 | fast_rnnt_add_py_test(${source}) 26 | endforeach() 27 | -------------------------------------------------------------------------------- /fast_rnnt/python/tests/mutual_information_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2021 Xiaomi Corporation (authors: Daniel Povey, 4 | # Wei Kang) 5 | # 6 | # See ../../../LICENSE for clarification regarding multiple authors 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | # To run this single test, use 21 | # 22 | # ctest --verbose -R mutual_information_test_py 23 | 24 | import random 25 | import unittest 26 | 27 | import fast_rnnt 28 | import torch 29 | 30 | 31 | # Caution: this will fail occasionally due to cutoffs not being quite large 32 | # enough. As long as it passes most of the time, it's OK. 33 | class TestMutualInformation(unittest.TestCase): 34 | @classmethod 35 | def setUpClass(cls): 36 | cls.devices = [torch.device("cpu")] 37 | if torch.cuda.is_available() and fast_rnnt.with_cuda(): 38 | cls.devices.append(torch.device("cuda", 0)) 39 | if torch.cuda.device_count() > 1: 40 | torch.cuda.set_device(1) 41 | cls.devices.append(torch.device("cuda", 1)) 42 | cls.dtypes = [torch.float32, torch.float64] 43 | 44 | def test_mutual_information_basic(self): 45 | for _iter in range(10): 46 | (B, S, T) = ( 47 | random.randint(1, 10), 48 | random.randint(1, 16), 49 | random.randint(1, 500), 50 | ) 51 | random_px = random.random() < 0.2 52 | random_py = random.random() < 0.2 53 | random_boundary = random.random() < 0.7 54 | big_px = random.random() < 0.2 55 | big_py = random.random() < 0.2 56 | 57 | modified = random.random() < 0.5 58 | 59 | if modified and T < S: 60 | T = S + random.randint(0, 30) 61 | 62 | for dtype in self.dtypes: 63 | for device in self.devices: 64 | if random_boundary: 65 | 66 | def get_boundary_row(): 67 | this_S = random.randint( 68 | 0, S 69 | ) # allow empty sequence 70 | this_T = random.randint( 71 | this_S if modified else 1, T 72 | ) 73 | s_begin = random.randint(0, S - this_S) 74 | t_begin = random.randint(0, T - this_T) 75 | s_end = s_begin + this_S 76 | t_end = t_begin + this_T 77 | return [s_begin, t_begin, s_end, t_end] 78 | 79 | if device == torch.device("cpu"): 80 | boundary = torch.tensor( 81 | [get_boundary_row() for _ in range(B)], 82 | dtype=torch.int64, 83 | device=device, 84 | ) 85 | else: 86 | boundary = boundary.to(device) 87 | else: 88 | # Use default boundary, but either specified directly 89 | # or not. 90 | if random.random() < 0.5: 91 | boundary = ( 92 | torch.tensor([0, 0, S, T], dtype=torch.int64) 93 | .unsqueeze(0) 94 | .expand(B, 4) 95 | .to(device) 96 | ) 97 | else: 98 | boundary = None 99 | 100 | if device == torch.device("cpu"): 101 | if random_px: 102 | # log of an odds ratio 103 | px = torch.randn( 104 | B, S, T + (0 if modified else 1), dtype=dtype 105 | ).to(device) 106 | if S > 1 and not random_boundary and not modified: 107 | px[:, :, -1:] = float("-inf") 108 | else: 109 | # log of an odds ratio 110 | px = torch.zeros( 111 | B, S, T + (0 if modified else 1), dtype=dtype 112 | ).to(device) 113 | # px and py get exponentiated, and then multiplied 114 | # together up to 32 times (BLOCK_SIZE in the CUDA code), 115 | # so 15 is actually a big number that could lead to 116 | # overflow. 117 | if big_px: 118 | px += 15.0 119 | if random_py: 120 | # log of an odds ratio 121 | py = torch.randn(B, S + 1, T, dtype=dtype).to( 122 | device 123 | ) 124 | else: 125 | # log of an odds ratio 126 | py = torch.zeros(B, S + 1, T, dtype=dtype).to( 127 | device 128 | ) 129 | if big_py: 130 | py += 15.0 131 | 132 | else: 133 | px = px.to(device).detach() 134 | py = py.to(device).detach() 135 | px.requires_grad = True 136 | py.requires_grad = True 137 | 138 | m = fast_rnnt.mutual_information_recursion(px, py, boundary) 139 | 140 | m2 = fast_rnnt.joint_mutual_information_recursion( 141 | (px,), (py,), boundary 142 | ) 143 | 144 | m3 = fast_rnnt.joint_mutual_information_recursion( 145 | (px * 0.5, px * 0.5), (py * 0.5, py * 0.5), boundary 146 | ) 147 | 148 | # it is supposed to be identical only after 149 | # summing over dim 0, corresponding to the 150 | # sequence dim 151 | m3 = m3.sum(dim=0) 152 | 153 | assert torch.allclose(m, m2) 154 | assert torch.allclose(m, m3) 155 | 156 | # the loop this is in checks that the CPU and CUDA versions 157 | # give the same derivative; 158 | # by randomizing which of m, m2 or m3 we backprop, we also 159 | # ensure that the joint version of the code gives the same 160 | # derivative as the regular version 161 | scale = 3 162 | if random.random() < 0.5: 163 | (m.sum() * scale).backward() 164 | elif random.random() < 0.5: 165 | (m2.sum() * scale).backward() 166 | else: 167 | (m3.sum() * scale).backward() 168 | 169 | if device == torch.device("cpu"): 170 | expected_px_grad = px.grad 171 | expected_py_grad = py.grad 172 | expected_m = m 173 | assert torch.allclose( 174 | px.grad, 175 | expected_px_grad.to(device), 176 | atol=1.0e-02, 177 | rtol=1.0e-02, 178 | ) 179 | assert torch.allclose( 180 | py.grad, 181 | expected_py_grad.to(device), 182 | atol=1.0e-02, 183 | rtol=1.0e-02, 184 | ) 185 | assert torch.allclose( 186 | m, expected_m.to(device), atol=1.0e-02, rtol=1.0e-02 187 | ) 188 | 189 | def test_mutual_information_deriv(self): 190 | for _iter in range(10): 191 | (B, S, T) = ( 192 | random.randint(1, 100), 193 | random.randint(1, 200), 194 | random.randint(1, 200), 195 | ) 196 | random_px = random.random() < 0.2 197 | random_py = random.random() < 0.2 198 | random_boundary = random.random() < 0.7 199 | big_px = random.random() < 0.2 200 | big_py = random.random() < 0.2 201 | 202 | modified = random.random() < 0.5 203 | 204 | if modified and T < S: 205 | T = S + random.randint(0, 30) 206 | 207 | for dtype in self.dtypes: 208 | for device in self.devices: 209 | if random_boundary: 210 | 211 | def get_boundary_row(): 212 | this_S = random.randint(1, S) 213 | this_T = random.randint( 214 | this_S if modified else 1, T 215 | ) 216 | s_begin = random.randint(0, S - this_S) 217 | t_begin = random.randint(0, T - this_T) 218 | s_end = s_begin + this_S 219 | t_end = t_begin + this_T 220 | return [s_begin, t_begin, s_end, t_end] 221 | 222 | if device == torch.device("cpu"): 223 | boundary = torch.tensor( 224 | [get_boundary_row() for _ in range(B)], 225 | dtype=torch.int64, 226 | device=device, 227 | ) 228 | else: 229 | boundary = boundary.to(device) 230 | else: 231 | # Use default boundary, but either specified directly 232 | # or not. 233 | if random.random() < 0.5: 234 | boundary = ( 235 | torch.tensor([0, 0, S, T], dtype=torch.int64) 236 | .unsqueeze(0) 237 | .expand(B, 4) 238 | .to(device) 239 | ) 240 | else: 241 | boundary = None 242 | 243 | T1 = T + (0 if modified else 1) 244 | if device == torch.device("cpu"): 245 | if random_px: 246 | # log of an odds ratio 247 | px = torch.randn(B, S, T1, dtype=dtype).to(device) 248 | else: 249 | # log of an odds ratio 250 | px = torch.zeros(B, S, T1, dtype=dtype).to(device) 251 | # px and py get exponentiated, and then multiplied 252 | # together up to 32 times (BLOCK_SIZE in the CUDA code), 253 | # so 15 is actually a big number that could lead to 254 | # overflow. 255 | if big_px: 256 | px += 15.0 257 | if random_py: 258 | # log of an odds ratio 259 | py = torch.randn(B, S + 1, T, dtype=dtype).to( 260 | device 261 | ) 262 | else: 263 | # log of an odds ratio 264 | py = torch.zeros(B, S + 1, T, dtype=dtype).to( 265 | device 266 | ) 267 | if big_py: 268 | py += 15.0 269 | else: 270 | px = px.to(device).detach() 271 | py = py.to(device).detach() 272 | px.requires_grad = True 273 | py.requires_grad = True 274 | 275 | m = fast_rnnt.mutual_information_recursion(px, py, boundary) 276 | 277 | m_grad = torch.randn(B, dtype=dtype, device=device) 278 | m.backward(gradient=m_grad) 279 | delta = 1.0e-04 280 | delta_px = delta * torch.randn_like(px) 281 | m2 = fast_rnnt.mutual_information_recursion( 282 | px + delta_px, py, boundary 283 | ) 284 | delta_m = m2 - m 285 | observed_delta = (delta_m * m_grad).sum().to("cpu") 286 | predicted_delta = (delta_px * px.grad).sum().to("cpu") 287 | 288 | atol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 289 | rtol = 1.0e-02 if dtype == torch.float32 else 1.0e-04 290 | 291 | assert torch.allclose( 292 | observed_delta, predicted_delta, atol=atol, rtol=rtol 293 | ) 294 | 295 | delta_py = delta * torch.randn_like(py) 296 | m2 = fast_rnnt.mutual_information_recursion( 297 | px, py + delta_py, boundary 298 | ) 299 | delta_m = m2 - m 300 | observed_delta = (delta_m * m_grad).sum().to("cpu") 301 | predicted_delta = (delta_py * py.grad).sum().to("cpu") 302 | 303 | assert torch.allclose( 304 | observed_delta, predicted_delta, atol=atol, rtol=rtol 305 | ) 306 | 307 | 308 | if __name__ == "__main__": 309 | unittest.main() 310 | -------------------------------------------------------------------------------- /fast_rnnt/python/tests/rnnt_loss_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright 2021 Xiaomi Corporation (authors: Daniel Povey, 4 | # Wei Kang) 5 | # 6 | # See ../../../LICENSE for clarification regarding multiple authors 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | # To run this single test, use 21 | # 22 | # ctest --verbose -R rnnt_loss_test_py 23 | 24 | import unittest 25 | 26 | import fast_rnnt 27 | import random 28 | import torch 29 | 30 | 31 | class TestRnntLoss(unittest.TestCase): 32 | @classmethod 33 | def setUpClass(cls): 34 | cls.devices = [torch.device("cpu")] 35 | if torch.cuda.is_available() and fast_rnnt.with_cuda(): 36 | cls.devices.append(torch.device("cuda", 0)) 37 | if torch.cuda.device_count() > 1: 38 | torch.cuda.set_device(1) 39 | cls.devices.append(torch.device("cuda", 1)) 40 | try: 41 | import torchaudio 42 | import torchaudio.functional 43 | 44 | if hasattr(torchaudio.functional, "rnnt_loss"): 45 | cls.has_torch_rnnt_loss = True 46 | else: 47 | cls.has_torch_rnnt_loss = False 48 | print( 49 | f"Current torchaudio version: {torchaudio.__version__}\n" 50 | "Skipping the tests of comparing rnnt loss with torch " 51 | "one, to enable these tests please install a " 52 | "version >= 0.10.0" 53 | ) 54 | except ImportError as e: 55 | cls.has_torch_rnnt_loss = False 56 | print( 57 | f"Import torchaudio error, error message: {e}\n" 58 | "Skipping the tests of comparing rnnt loss with torch " 59 | "one, to enable these tests, please install torchaudio " 60 | "with version >= 0.10.0" 61 | ) 62 | 63 | def test_rnnt_loss_basic(self): 64 | B = 1 65 | S = 3 66 | T = 4 67 | # C = 3 68 | for device in self.devices: 69 | # lm: [B][S+1][C] 70 | lm = torch.tensor( 71 | [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], 72 | dtype=torch.float, 73 | device=device, 74 | ) 75 | # am: [B][T][C] 76 | am = torch.tensor( 77 | [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], 78 | dtype=torch.float, 79 | device=device, 80 | ) 81 | termination_symbol = 2 82 | symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) 83 | 84 | px, py = fast_rnnt.get_rnnt_logprobs( 85 | lm=lm, 86 | am=am, 87 | symbols=symbols, 88 | termination_symbol=termination_symbol, 89 | ) 90 | assert px.shape == (B, S, T + 1) 91 | assert py.shape == (B, S + 1, T) 92 | assert symbols.shape == (B, S) 93 | m = fast_rnnt.mutual_information_recursion( 94 | px=px, py=py, boundary=None 95 | ) 96 | 97 | if device == torch.device("cpu"): 98 | expected = -m 99 | assert torch.allclose(-m, expected.to(device)) 100 | 101 | # test rnnt_loss_simple 102 | m = fast_rnnt.rnnt_loss_simple( 103 | lm=lm, 104 | am=am, 105 | symbols=symbols, 106 | termination_symbol=termination_symbol, 107 | boundary=None, 108 | reduction="none", 109 | ) 110 | assert torch.allclose(m, expected.to(device)) 111 | 112 | # test rnnt_loss_smoothed 113 | m = fast_rnnt.rnnt_loss_smoothed( 114 | lm=lm, 115 | am=am, 116 | symbols=symbols, 117 | termination_symbol=termination_symbol, 118 | lm_only_scale=0.0, 119 | am_only_scale=0.0, 120 | boundary=None, 121 | reduction="none", 122 | ) 123 | assert torch.allclose(m, expected.to(device)) 124 | 125 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 126 | 127 | # test rnnt_loss 128 | m = fast_rnnt.rnnt_loss( 129 | logits=logits, 130 | symbols=symbols, 131 | termination_symbol=termination_symbol, 132 | boundary=None, 133 | reduction="none", 134 | ) 135 | assert torch.allclose(m, expected.to(device)) 136 | 137 | # compare with torchaudio rnnt_loss 138 | if self.has_torch_rnnt_loss: 139 | import torchaudio.functional 140 | 141 | m = torchaudio.functional.rnnt_loss( 142 | logits=logits, 143 | targets=symbols.int(), 144 | logit_lengths=torch.tensor( 145 | [T] * B, dtype=torch.int32, device=device 146 | ), 147 | target_lengths=torch.tensor( 148 | [S] * B, dtype=torch.int32, device=device 149 | ), 150 | blank=termination_symbol, 151 | reduction="none", 152 | ) 153 | assert torch.allclose(m, expected.to(device)) 154 | 155 | # should be invariant to adding a constant for any frame. 156 | lm += torch.randn(B, S + 1, 1, device=device) 157 | am += torch.randn(B, T, 1, device=device) 158 | 159 | m = fast_rnnt.rnnt_loss_simple( 160 | lm=lm, 161 | am=am, 162 | symbols=symbols, 163 | termination_symbol=termination_symbol, 164 | boundary=None, 165 | reduction="none", 166 | ) 167 | assert torch.allclose(m, expected.to(device)) 168 | 169 | m = fast_rnnt.rnnt_loss_smoothed( 170 | lm=lm, 171 | am=am, 172 | symbols=symbols, 173 | termination_symbol=termination_symbol, 174 | lm_only_scale=0.0, 175 | am_only_scale=0.0, 176 | boundary=None, 177 | reduction="none", 178 | ) 179 | assert torch.allclose(m, expected.to(device)) 180 | 181 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 182 | m = fast_rnnt.rnnt_loss( 183 | logits=logits, 184 | symbols=symbols, 185 | termination_symbol=termination_symbol, 186 | boundary=None, 187 | reduction="none", 188 | ) 189 | assert torch.allclose(m, expected.to(device)) 190 | 191 | def test_rnnt_loss_random(self): 192 | B = 5 193 | S = 20 194 | T = 300 195 | C = 100 196 | frames = torch.randint(S, T, (B,)) 197 | seq_length = torch.randint(3, S - 1, (B,)) 198 | T = torch.max(frames) 199 | S = torch.max(seq_length) 200 | 201 | am_ = torch.randn((B, T, C), dtype=torch.float32) 202 | lm_ = torch.randn((B, S + 1, C), dtype=torch.float32) 203 | symbols_ = torch.randint(0, C - 1, (B, S)) 204 | termination_symbol = C - 1 205 | 206 | boundary_ = torch.zeros((B, 4), dtype=torch.int64) 207 | boundary_[:, 2] = seq_length 208 | boundary_[:, 3] = frames 209 | 210 | for rnnt_type in ["regular", "modified", "constrained"]: 211 | for device in self.devices: 212 | # lm: [B][S+1][C] 213 | lm = lm_.to(device) 214 | # am: [B][T][C] 215 | am = am_.to(device) 216 | symbols = symbols_.to(device) 217 | boundary = boundary_.to(device) 218 | 219 | px, py = fast_rnnt.get_rnnt_logprobs( 220 | lm=lm, 221 | am=am, 222 | symbols=symbols, 223 | termination_symbol=termination_symbol, 224 | boundary=boundary, 225 | rnnt_type=rnnt_type, 226 | ) 227 | assert ( 228 | px.shape == (B, S, T) 229 | if rnnt_type != "regular" 230 | else (B, S, T + 1) 231 | ) 232 | assert py.shape == (B, S + 1, T) 233 | assert symbols.shape == (B, S) 234 | m = fast_rnnt.mutual_information_recursion( 235 | px=px, py=py, boundary=boundary 236 | ) 237 | 238 | if device == torch.device("cpu"): 239 | expected = -torch.mean(m) 240 | assert torch.allclose(-torch.mean(m), expected.to(device)) 241 | 242 | m = fast_rnnt.rnnt_loss_simple( 243 | lm=lm, 244 | am=am, 245 | symbols=symbols, 246 | termination_symbol=termination_symbol, 247 | boundary=boundary, 248 | rnnt_type=rnnt_type, 249 | ) 250 | assert torch.allclose(m, expected.to(device)) 251 | 252 | m = fast_rnnt.rnnt_loss_smoothed( 253 | lm=lm, 254 | am=am, 255 | symbols=symbols, 256 | termination_symbol=termination_symbol, 257 | lm_only_scale=0.0, 258 | am_only_scale=0.0, 259 | boundary=boundary, 260 | rnnt_type=rnnt_type, 261 | ) 262 | assert torch.allclose(m, expected.to(device)) 263 | 264 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 265 | m = fast_rnnt.rnnt_loss( 266 | logits=logits, 267 | symbols=symbols, 268 | termination_symbol=termination_symbol, 269 | boundary=boundary, 270 | rnnt_type=rnnt_type, 271 | ) 272 | assert torch.allclose(m, expected.to(device)) 273 | 274 | # compare with torchaudio rnnt_loss 275 | if self.has_torch_rnnt_loss and rnnt_type == "regular": 276 | import torchaudio.functional 277 | 278 | m = torchaudio.functional.rnnt_loss( 279 | logits=logits, 280 | targets=symbols.int(), 281 | logit_lengths=boundary[:, 3].int(), 282 | target_lengths=boundary[:, 2].int(), 283 | blank=termination_symbol, 284 | ) 285 | assert torch.allclose(m, expected.to(device)) 286 | 287 | # should be invariant to adding a constant for any frame. 288 | lm += torch.randn(B, S + 1, 1, device=device) 289 | am += torch.randn(B, T, 1, device=device) 290 | 291 | m = fast_rnnt.rnnt_loss_simple( 292 | lm=lm, 293 | am=am, 294 | symbols=symbols, 295 | termination_symbol=termination_symbol, 296 | boundary=boundary, 297 | rnnt_type=rnnt_type, 298 | ) 299 | assert torch.allclose(m, expected.to(device)) 300 | 301 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 302 | m = fast_rnnt.rnnt_loss( 303 | logits=logits, 304 | symbols=symbols, 305 | termination_symbol=termination_symbol, 306 | boundary=boundary, 307 | rnnt_type=rnnt_type, 308 | ) 309 | assert torch.allclose(m, expected.to(device)) 310 | 311 | m = fast_rnnt.rnnt_loss_smoothed( 312 | lm=lm, 313 | am=am, 314 | symbols=symbols, 315 | termination_symbol=termination_symbol, 316 | lm_only_scale=0.0, 317 | am_only_scale=0.0, 318 | boundary=boundary, 319 | rnnt_type=rnnt_type, 320 | ) 321 | assert torch.allclose(m, expected.to(device)) 322 | 323 | def test_rnnt_loss_gradient(self): 324 | if self.has_torch_rnnt_loss: 325 | import torchaudio.functional 326 | 327 | B = 5 328 | S = 20 329 | T = 300 330 | C = 100 331 | frames = torch.randint(S, T, (B,)) 332 | seq_length = torch.randint(3, S - 1, (B,)) 333 | T = torch.max(frames) 334 | S = torch.max(seq_length) 335 | 336 | am_ = torch.randn((B, T, C), dtype=torch.float32) 337 | lm_ = torch.randn((B, S + 1, C), dtype=torch.float32) 338 | symbols_ = torch.randint(0, C - 1, (B, S)) 339 | termination_symbol = C - 1 340 | 341 | boundary_ = torch.zeros((B, 4), dtype=torch.int64) 342 | boundary_[:, 2] = seq_length 343 | boundary_[:, 3] = frames 344 | 345 | for device in self.devices: 346 | # lm: [B][S+1][C] 347 | lm = lm_.to(device) 348 | # am: [B][T][C] 349 | am = am_.to(device) 350 | symbols = symbols_.to(device) 351 | boundary = boundary_.to(device) 352 | 353 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 354 | logits.requires_grad_() 355 | fast_loss = fast_rnnt.rnnt_loss( 356 | logits=logits, 357 | symbols=symbols, 358 | termination_symbol=termination_symbol, 359 | boundary=boundary, 360 | ) 361 | fast_grad = torch.autograd.grad(fast_loss, logits) 362 | fast_grad = fast_grad[0] 363 | 364 | logits2 = logits.detach().clone().float() 365 | logits2.requires_grad_() 366 | torch_loss = torchaudio.functional.rnnt_loss( 367 | logits=logits2, 368 | targets=symbols.int(), 369 | logit_lengths=boundary[:, 3].int(), 370 | target_lengths=boundary[:, 2].int(), 371 | blank=termination_symbol, 372 | ) 373 | torch_grad = torch.autograd.grad(torch_loss, logits2) 374 | torch_grad = torch_grad[0] 375 | 376 | assert torch.allclose( 377 | fast_loss, torch_loss, atol=1e-2, rtol=1e-2 378 | ) 379 | 380 | assert torch.allclose( 381 | fast_grad, torch_grad, atol=1e-2, rtol=1e-2 382 | ) 383 | 384 | def test_rnnt_loss_smoothed(self): 385 | B = 1 386 | S = 3 387 | T = 4 388 | # C = 3 389 | for device in self.devices: 390 | # lm: [B][S+1][C] 391 | lm = torch.tensor( 392 | [[[0, 0, 1], [0, 1, 1], [1, 0, 1], [2, 2, 0]]], 393 | dtype=torch.float, 394 | device=device, 395 | ) 396 | # am: [B][T][C] 397 | am = torch.tensor( 398 | [[[0, 1, 2], [0, 0, 0], [0, 2, 4], [0, 3, 3]]], 399 | dtype=torch.float, 400 | device=device, 401 | ) 402 | 403 | termination_symbol = 2 404 | symbols = torch.tensor([[0, 1, 0]], dtype=torch.long, device=device) 405 | 406 | m = fast_rnnt.rnnt_loss_smoothed( 407 | lm=lm, 408 | am=am, 409 | symbols=symbols, 410 | termination_symbol=termination_symbol, 411 | lm_only_scale=0.0, 412 | am_only_scale=0.333, 413 | boundary=None, 414 | ) 415 | 416 | if device == torch.device("cpu"): 417 | expected = m 418 | assert torch.allclose(m, expected.to(device)) 419 | 420 | # should be invariant to adding a constant for any frame. 421 | lm += torch.randn(B, S + 1, 1, device=device) 422 | am += torch.randn(B, T, 1, device=device) 423 | 424 | m = fast_rnnt.rnnt_loss_smoothed( 425 | lm=lm, 426 | am=am, 427 | symbols=symbols, 428 | termination_symbol=termination_symbol, 429 | lm_only_scale=0.0, 430 | am_only_scale=0.333, 431 | boundary=None, 432 | ) 433 | assert torch.allclose(m, expected.to(device)) 434 | 435 | def test_rnnt_loss_pruned(self): 436 | B = 4 437 | T = 300 438 | S = 50 439 | C = 10 440 | 441 | frames = torch.randint(S, T, (B,)) 442 | seq_length = torch.randint(3, S - 1, (B,)) 443 | T = torch.max(frames) 444 | S = torch.max(seq_length) 445 | 446 | am_ = torch.randn((B, T, C), dtype=torch.float64) 447 | lm_ = torch.randn((B, S + 1, C), dtype=torch.float64) 448 | symbols_ = torch.randint(0, C - 1, (B, S)) 449 | terminal_symbol = C - 1 450 | 451 | boundary_ = torch.zeros((B, 4), dtype=torch.int64) 452 | boundary_[:, 2] = seq_length 453 | boundary_[:, 3] = frames 454 | 455 | for rnnt_type in ["regular", "modified", "constrained"]: 456 | for device in self.devices: 457 | # normal rnnt 458 | am = am_.to(device) 459 | lm = lm_.to(device) 460 | symbols = symbols_.to(device) 461 | boundary = boundary_.to(device) 462 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 463 | logits = logits.float() 464 | 465 | # nonlinear transform 466 | logits = torch.sigmoid(logits) 467 | fast_loss = fast_rnnt.rnnt_loss( 468 | logits=logits, 469 | symbols=symbols, 470 | termination_symbol=terminal_symbol, 471 | boundary=boundary, 472 | rnnt_type=rnnt_type, 473 | ) 474 | 475 | print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {fast_loss}") 476 | 477 | # pruning 478 | simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( 479 | lm=lm, 480 | am=am, 481 | symbols=symbols, 482 | termination_symbol=terminal_symbol, 483 | boundary=boundary, 484 | rnnt_type=rnnt_type, 485 | return_grad=True, 486 | reduction="none", 487 | ) 488 | 489 | for r in range(2, 50, 5): 490 | ranges = fast_rnnt.get_rnnt_prune_ranges( 491 | px_grad=px_grad, 492 | py_grad=py_grad, 493 | boundary=boundary, 494 | s_range=r, 495 | ) 496 | # (B, T, r, C) 497 | pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning( 498 | am=am, lm=lm, ranges=ranges 499 | ) 500 | 501 | logits = pruned_am + pruned_lm 502 | 503 | # nonlinear transform 504 | logits = torch.sigmoid(logits) 505 | 506 | pruned_loss = fast_rnnt.rnnt_loss_pruned( 507 | logits=logits, 508 | symbols=symbols, 509 | ranges=ranges, 510 | termination_symbol=terminal_symbol, 511 | boundary=boundary, 512 | rnnt_type=rnnt_type, 513 | reduction="none", 514 | ) 515 | print(f"Pruning loss with range {r} : {pruned_loss}") 516 | 517 | # Test the sequences that only have small number of symbols, 518 | # at this circumstance, the s_range would be greater than S, which will 519 | # raise errors (like, nan or inf loss) in our previous versions. 520 | def test_rnnt_loss_pruned_small_symbols_number(self): 521 | B = 2 522 | T = 20 523 | S = 3 524 | C = 10 525 | 526 | frames = torch.randint(S + 1, T, (B,)) 527 | seq_lengths = torch.randint(1, S, (B,)) 528 | T = torch.max(frames) 529 | S = torch.max(seq_lengths) 530 | 531 | am_ = torch.randn((B, T, C), dtype=torch.float64) 532 | lm_ = torch.randn((B, S + 1, C), dtype=torch.float64) 533 | symbols_ = torch.randint(0, C, (B, S)) 534 | terminal_symbol = C - 1 535 | 536 | boundary_ = torch.zeros((B, 4), dtype=torch.int64) 537 | boundary_[:, 2] = seq_lengths 538 | boundary_[:, 3] = frames 539 | 540 | print(f"B = {B}, T = {T}, S = {S}, C = {C}") 541 | 542 | for rnnt_type in ["regular", "modified", "constrained"]: 543 | for device in self.devices: 544 | # normal rnnt 545 | am = am_.to(device) 546 | lm = lm_.to(device) 547 | symbols = symbols_.to(device) 548 | boundary = boundary_.to(device) 549 | 550 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 551 | logits = logits.float() 552 | 553 | # nonlinear transform 554 | logits = torch.sigmoid(logits) 555 | 556 | loss = fast_rnnt.rnnt_loss( 557 | logits=logits, 558 | symbols=symbols, 559 | termination_symbol=terminal_symbol, 560 | boundary=boundary, 561 | rnnt_type=rnnt_type, 562 | reduction="none", 563 | ) 564 | 565 | print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}") 566 | 567 | # pruning 568 | simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( 569 | lm=lm, 570 | am=am, 571 | symbols=symbols, 572 | termination_symbol=terminal_symbol, 573 | boundary=boundary, 574 | rnnt_type=rnnt_type, 575 | return_grad=True, 576 | reduction="none", 577 | ) 578 | 579 | S0 = 2 580 | if rnnt_type != "regular": 581 | S0 = 1 582 | 583 | for r in range(S0, S + 2): 584 | ranges = fast_rnnt.get_rnnt_prune_ranges( 585 | px_grad=px_grad, 586 | py_grad=py_grad, 587 | boundary=boundary, 588 | s_range=r, 589 | ) 590 | # (B, T, r, C) 591 | pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning( 592 | am=am, lm=lm, ranges=ranges 593 | ) 594 | 595 | logits = pruned_am + pruned_lm 596 | 597 | # nonlinear transform 598 | logits = torch.sigmoid(logits) 599 | 600 | pruned_loss = fast_rnnt.rnnt_loss_pruned( 601 | logits=logits, 602 | symbols=symbols, 603 | ranges=ranges, 604 | termination_symbol=terminal_symbol, 605 | boundary=boundary, 606 | rnnt_type=rnnt_type, 607 | reduction="none", 608 | ) 609 | print(f"Pruned loss with range {r} : {pruned_loss}") 610 | 611 | # Test low s_range values with large S and small T, 612 | # at this circumstance, the s_range would not be enough 613 | # to cover the whole sequence length (in regular rnnt mode) 614 | # and would result in inf loss 615 | def test_rnnt_loss_pruned_small_s_range(self): 616 | B = 2 617 | T = 2 618 | S = 10 619 | C = 10 620 | 621 | frames = torch.randint(1, T, (B,)) 622 | seq_lengths = torch.randint(1, S, (B,)) 623 | T = torch.max(frames) 624 | S = torch.max(seq_lengths) 625 | 626 | am_ = torch.randn((B, T, C), dtype=torch.float64) 627 | lm_ = torch.randn((B, S + 1, C), dtype=torch.float64) 628 | symbols_ = torch.randint(0, C, (B, S)) 629 | terminal_symbol = C - 1 630 | 631 | boundary_ = torch.zeros((B, 4), dtype=torch.int64) 632 | boundary_[:, 2] = seq_lengths 633 | boundary_[:, 3] = frames 634 | 635 | print(f"B = {B}, T = {T}, S = {S}, C = {C}") 636 | 637 | for rnnt_type in ["regular"]: 638 | for device in self.devices: 639 | # normal rnnt 640 | am = am_.to(device) 641 | lm = lm_.to(device) 642 | symbols = symbols_.to(device) 643 | boundary = boundary_.to(device) 644 | 645 | logits = am.unsqueeze(2) + lm.unsqueeze(1) 646 | logits = logits.float() 647 | 648 | # nonlinear transform 649 | logits = torch.sigmoid(logits) 650 | 651 | loss = fast_rnnt.rnnt_loss( 652 | logits=logits, 653 | symbols=symbols, 654 | termination_symbol=terminal_symbol, 655 | boundary=boundary, 656 | rnnt_type=rnnt_type, 657 | reduction="none", 658 | ) 659 | 660 | print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {loss}") 661 | 662 | # pruning 663 | simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( 664 | lm=lm, 665 | am=am, 666 | symbols=symbols, 667 | termination_symbol=terminal_symbol, 668 | boundary=boundary, 669 | rnnt_type=rnnt_type, 670 | return_grad=True, 671 | reduction="none", 672 | ) 673 | 674 | S0 = 2 675 | 676 | for r in range(S0, S + 2): 677 | ranges = fast_rnnt.get_rnnt_prune_ranges( 678 | px_grad=px_grad, 679 | py_grad=py_grad, 680 | boundary=boundary, 681 | s_range=r, 682 | ) 683 | # (B, T, r, C) 684 | pruned_am, pruned_lm = fast_rnnt.do_rnnt_pruning( 685 | am=am, lm=lm, ranges=ranges 686 | ) 687 | 688 | logits = pruned_am + pruned_lm 689 | 690 | # nonlinear transform 691 | logits = torch.sigmoid(logits) 692 | 693 | pruned_loss = fast_rnnt.rnnt_loss_pruned( 694 | logits=logits, 695 | symbols=symbols, 696 | ranges=ranges, 697 | termination_symbol=terminal_symbol, 698 | boundary=boundary, 699 | rnnt_type=rnnt_type, 700 | reduction="none", 701 | ) 702 | assert ( 703 | not pruned_loss.isinf().any() 704 | ), f"Pruned loss is inf for r={r}, S={S}, T={T}: {pruned_loss}" 705 | print(f"Pruned loss with range {r} : {pruned_loss}") 706 | 707 | 708 | if __name__ == "__main__": 709 | unittest.main() 710 | -------------------------------------------------------------------------------- /package.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ROOT=$(realpath $(dirname $0)) 4 | cd ${ROOT} 5 | 6 | python3 -m pip install --upgrade setuptools wheel twine 7 | python3 setup.py sdist 8 | echo "Inspect dist and upload with \"python3 -m twine upload dist/*\"" 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # Copyright (c) 2022 Xiaomi Corporation (author: Wei Kang) 4 | 5 | import glob 6 | import os 7 | import re 8 | import shutil 9 | import sys 10 | 11 | import setuptools 12 | from setuptools.command.build_ext import build_ext 13 | 14 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 15 | 16 | 17 | def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: 18 | kwargs["language"] = "c++" 19 | sources = [] 20 | return setuptools.Extension(name, sources, *args, **kwargs) 21 | 22 | 23 | class BuildExtension(build_ext): 24 | def build_extension(self, ext: setuptools.extension.Extension): 25 | # build/temp.linux-x86_64-3.8 26 | build_dir = self.build_temp 27 | os.makedirs(build_dir, exist_ok=True) 28 | 29 | # build/lib.linux-x86_64-3.8 30 | os.makedirs(self.build_lib, exist_ok=True) 31 | 32 | ft_dir = os.path.dirname(os.path.abspath(__file__)) 33 | 34 | cmake_args = os.environ.get("FT_CMAKE_ARGS", "") 35 | make_args = os.environ.get("FT_MAKE_ARGS", "") 36 | system_make_args = os.environ.get("MAKEFLAGS", "") 37 | 38 | if cmake_args == "": 39 | cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF" 40 | 41 | if make_args == "" and system_make_args == "": 42 | make_args = " -j " 43 | 44 | if "PYTHON_EXECUTABLE" not in cmake_args: 45 | print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") 46 | cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" 47 | 48 | build_cmd = f""" 49 | cd {self.build_temp} 50 | 51 | cmake {cmake_args} {ft_dir} 52 | 53 | make {make_args} _fast_rnnt 54 | """ 55 | print(f"build command is:\n{build_cmd}") 56 | 57 | ret = os.system(build_cmd) 58 | if ret != 0: 59 | raise Exception( 60 | "\nBuild fast_rnnt failed. Please check the error " 61 | "message.\n" 62 | "You can ask for help by creating an issue on GitHub.\n" 63 | "\nClick:\n" 64 | "\thttps://github.com/danpovey/fast_rnnt/issues/new\n" # noqa 65 | ) 66 | lib_so = glob.glob(f"{build_dir}/lib/*.so*") 67 | for so in lib_so: 68 | print(f"Copying {so} to {self.build_lib}/") 69 | shutil.copy(f"{so}", f"{self.build_lib}/") 70 | 71 | # macos 72 | lib_so = glob.glob(f"{build_dir}/lib/*.dylib*") 73 | for so in lib_so: 74 | print(f"Copying {so} to {self.build_lib}/") 75 | shutil.copy(f"{so}", f"{self.build_lib}/") 76 | 77 | 78 | def read_long_description(): 79 | with open("README.md", encoding="utf8") as f: 80 | readme = f.read() 81 | return readme 82 | 83 | 84 | def get_package_version(): 85 | with open("CMakeLists.txt") as f: 86 | content = f.read() 87 | 88 | latest_version = re.search(r"set\(FT_VERSION (.*)\)", content).group(1) 89 | latest_version = latest_version.strip('"') 90 | return latest_version 91 | 92 | 93 | def get_requirements(): 94 | with open("requirements.txt", encoding="utf8") as f: 95 | requirements = f.read().splitlines() 96 | 97 | return requirements 98 | 99 | 100 | package_name = "fast_rnnt" 101 | 102 | with open("fast_rnnt/python/fast_rnnt/__init__.py", "a") as f: 103 | f.write(f"__version__ = '{get_package_version()}'\n") 104 | 105 | setuptools.setup( 106 | name=package_name, 107 | version=get_package_version(), 108 | author="Dan Povey", 109 | author_email="dpovey@gmail.com", 110 | package_dir={ 111 | package_name: "fast_rnnt/python/fast_rnnt", 112 | }, 113 | packages=[package_name], 114 | url="https://github.com/danpovey/fast_rnnt", 115 | description="Fast and memory-efficient RNN-T loss.", 116 | long_description=read_long_description(), 117 | long_description_content_type="text/markdown", 118 | install_requires=get_requirements(), 119 | ext_modules=[cmake_extension("_fast_rnnt")], 120 | cmdclass={"build_ext": BuildExtension}, 121 | zip_safe=False, 122 | classifiers=[ 123 | "Programming Language :: C++", 124 | "Programming Language :: Python", 125 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 126 | ], 127 | license="Apache licensed, as found in the LICENSE file", 128 | ) 129 | --------------------------------------------------------------------------------