├── requirements.txt ├── .gitignore ├── concept.png ├── torch_cif ├── __init__.py └── cif.py ├── setup.cfg ├── LICENSE ├── setup.py ├── README.md ├── benchmark └── benchmark.py └── tests └── test_cif.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .vscode/ 3 | *.egg-info/ 4 | dist/ 5 | build/ -------------------------------------------------------------------------------- /concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/George0828Zhang/torch_cif/HEAD/concept.png -------------------------------------------------------------------------------- /torch_cif/__init__.py: -------------------------------------------------------------------------------- 1 | from torch_cif.cif import cif_function 2 | 3 | __version__ = "0.2.0" 4 | __all__ = [ 5 | "cif_function", 6 | "__version__", 7 | ] 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | long_description=file: README.md 4 | long_description_content_type=text/markdown 5 | 6 | [flake8] 7 | max-line-length = 100 8 | ignore = 9 | E203, 10 | W503, 11 | 12 | [isort] 13 | profile=black 14 | lines_between_types=1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Chih-Chiang Chang 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import find_packages, setup
3 | 
4 | 
5 | base_dir = os.path.dirname(os.path.abspath(__file__))
6 | 
7 | 
8 | def get_version():
9 |     version_path = os.path.join(base_dir, "torch_cif", "__init__.py")
10 |     version = {}
11 |     with open(version_path, encoding="utf-8") as fp:
12 |         exec(fp.read(), version)
13 |     return version["__version__"]
14 | 
15 | 
16 | setup(
17 |     name='torch_cif',
18 |     packages=find_packages(),
19 |     version=get_version(),
20 |     license='MIT',
21 |     description='A fast parallel implementation of continuous'
22 |                 ' integrate-and-fire (CIF) https://arxiv.org/abs/1905.11235',
23 |     author='Chih-Chiang Chang',
24 |     author_email='cc.chang0828@gmail.com',
25 |     url='https://github.com/George0828Zhang/torch_cif',
26 |     keywords=" ".join([
27 |         'speech', 'speech-recognition', 'asr', 'automatic-speech-recognition',
28 |         'speech-to-text', 'speech-translation',
29 |         'continuous-integrate-and-fire', 'cif',
30 |         'monotonic', 'alignment', 'torch', 'pytorch'
31 |     ]),
32 |     classifiers=[
33 |         "Development Status :: 4 - Beta",
34 |         "Intended Audience :: Developers",
35 |         "Intended Audience :: Science/Research",
36 |         "License :: OSI Approved :: MIT License",
37 |         "Programming Language :: Python :: 3",
38 |         "Programming Language :: Python :: 3 :: Only",
39 |         "Programming Language :: Python :: 3.6",
40 |         "Programming Language :: Python :: 3.7",
41 |         "Programming Language :: Python :: 3.8",
42 |         "Programming Language :: Python :: 3.9",
43 |         "Programming Language :: Python :: 3.10",
44 |         "Programming Language :: Python :: 3.11",
45 |         "Programming Language :: Python :: 3.12",
46 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
47 |     ],
48 |     python_requires=">=3.6",
49 |     install_requires=["torch"],
50 |     extras_require={
51 |         "test": [
52 |             "hypothesis",
53 |             "expecttest"
54 |         ],
55 |     },
56 |     include_package_data=True,
57 | )
58 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # torch-cif
2 | 
3 | A fast, parallel, pure PyTorch implementation of *"CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition"* https://arxiv.org/abs/1905.11235. A minimal usage example is included near the end of this README.
4 | 
5 | ## Installation
6 | ### PyPI
7 | ```bash
8 | pip install torch-cif
9 | ```
10 | ### Locally
11 | ```bash
12 | git clone https://github.com/George0828Zhang/torch_cif
13 | cd torch_cif
14 | python setup.py install
15 | ```
16 | 
17 | ## Usage
18 | ```python
19 | def cif_function(
20 |     inputs: Tensor,
21 |     alpha: Tensor,
22 |     beta: float = 1.0,
23 |     tail_thres: float = 0.5,
24 |     padding_mask: Optional[Tensor] = None,
25 |     target_lengths: Optional[Tensor] = None,
26 |     eps: float = 1e-4,
27 |     unbound_alpha: bool = False
28 | ) -> Dict[str, List[Tensor]]:
29 |     r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
30 |     https://arxiv.org/abs/1905.11235
31 | 
32 |     Shapes:
33 |         N: batch size
34 |         S: source (encoder) sequence length
35 |         C: source feature dimension
36 |         T: target sequence length
37 | 
38 |     Args:
39 |         inputs (Tensor): (N, S, C) Input features to be integrated.
40 |         alpha (Tensor): (N, S) Weights corresponding to each element in the
41 |             inputs. It is expected to be the output of a sigmoid function.
42 |         beta (float): the threshold used to determine firing.
43 |         tail_thres (float): the threshold used to determine firing for tail handling.
44 |         padding_mask (Tensor, optional): (N, S) A binary mask representing
45 |             padded elements in the inputs. 1 is padding, 0 is not.
46 |         target_lengths (Tensor, optional): (N,) Desired length of the targets
47 |             for each sample in the minibatch.
48 |         eps (float, optional): Epsilon to prevent underflow for divisions.
49 |             Default: 1e-4
50 |         unbound_alpha (bool, optional): If True, skip checking that 0 <= alpha <= 1.
51 | 
52 |     Returns -> Dict[str, List[Tensor]]: Key/values described below.
53 |         cif_out: (N, T, C) The output integrated from the source.
54 |         cif_lengths: (N,) The output length for each element in batch.
55 |         alpha_sum: (N,) The sum of alpha for each element in batch.
56 |             Can be used to compute the quantity loss.
57 |         delays: (N, T) The expected delay (in terms of source tokens) for
58 |             each target token in the batch.
59 |         tail_weights: (N,) During inference, the residual weight accumulated in the tail.
60 |         scaled_alpha: (N, S) alpha after applying weight scaling.
61 |         cumsum_alpha: (N, S) cumsum of alpha after scaling.
62 |         right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
63 |         right_weights: (N, S) right scatter weights.
64 |         left_indices: (N, S) left scatter indices.
65 |         left_weights: (N, S) left scatter weights.
66 |     """
67 | ```
68 | 
69 | ## Note
70 | - This implementation uses `cumsum` and `floor` to determine the firing positions, and uses `scatter` to merge the weighted source features. The figure below demonstrates this concept using the *scaled* weight sequence `(0.4, 1.8, 1.2, 1.2, 1.4)`. A worked numerical example of this indexing is given at the end of this README.
71 | 
72 | ![concept](concept.png)
73 | 
74 | - Running the tests requires `pip install hypothesis expecttest`.
75 | - If `beta != 1`, our implementation differs slightly from Algorithm 1 in the paper [[1]](#references):
76 |     - When a boundary is located, the original algorithm adds the last feature to the current integration with weight `1 - accumulation` (line 11 in Algorithm 1), which causes negative weights in the next integration when `alpha < 1 - accumulation`.
77 |     - We use `beta - accumulation` instead, so the weight carried into the next integration, `alpha - (beta - accumulation)`, is never negative.
78 | - Feel free to contact me if there are bugs in the code.
79 | 
80 | ## References
81 | 1. [CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition](https://arxiv.org/abs/1905.11235)
82 | 2. [Exploring Continuous Integrate-and-Fire for Adaptive Simultaneous Speech Translation](https://www.isca-archive.org/interspeech_2022/chang22f_interspeech.html)
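83 | 
84 | ## Quick example
85 | 
86 | A minimal end-to-end sketch of calling `cif_function` (the tensor shapes, the random stand-in for a learned sigmoid weight head, and the quantity-loss formulation below are illustrative, not prescriptive):
87 | 
88 | ```python
89 | import torch
90 | from torch_cif import cif_function
91 | 
92 | N, S, C = 4, 50, 256  # batch size, source length, feature dimension
93 | encoder_out = torch.rand(N, S, C)
94 | alpha = torch.rand(N, S)  # firing weights in [0, 1], e.g. from a sigmoid head
95 | source_lengths = torch.tensor([50, 47, 32, 20])
96 | padding_mask = torch.arange(S).unsqueeze(0) >= source_lengths.unsqueeze(1)
97 | 
98 | # Training: alpha is rescaled internally so that each sample fires exactly
99 | # target_lengths times.
100 | target_lengths = torch.tensor([12, 11, 8, 5])
101 | out = cif_function(encoder_out, alpha, beta=1.0,
102 |                    padding_mask=padding_mask, target_lengths=target_lengths)
103 | cif_out = out["cif_out"][0]      # (N, T, C), T = target_lengths.max()
104 | alpha_sum = out["alpha_sum"][0]  # (N,) unscaled sum of alpha per sample
105 | # One common quantity loss (with beta = 1): push the alpha sum towards the target length.
106 | quantity_loss = torch.abs(alpha_sum - target_lengths.float()).mean()
107 | 
108 | # Inference: output lengths are determined by alpha itself (plus tail handling).
109 | out = cif_function(encoder_out, alpha, beta=1.0, padding_mask=padding_mask)
110 | cif_out = out["cif_out"][0]          # (N, T, C)
111 | cif_lengths = out["cif_lengths"][0]  # (N,) number of fired outputs per sample
112 | ```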
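113 | 
114 | ## Worked example of the firing positions
115 | 
116 | To make the indexing in the figure above concrete, the table below walks the same *scaled* weight sequence `(0.4, 1.8, 1.2, 1.2, 1.4)` through the computation with `beta = 1` (exact arithmetic, 0-based output slots). Up to floating-point rounding, these columns correspond to the `cumsum_alpha`, `right_indices`, `right_weights`, `left_indices` and `left_weights` entries returned by `cif_function`.
117 | 
118 | | source step | alpha | cumsum | right index = floor(cumsum / beta) | left index | weight to right slot | weight to left slot | extra full-beta fires |
119 | |---|---|---|---|---|---|---|---|
120 | | 1 | 0.4 | 0.4 | 0 | 0 | 0.0 | 0.4 | 0 |
121 | | 2 | 1.8 | 2.2 | 2 | 0 | 0.2 | 0.6 | 1 (to slot 1) |
122 | | 3 | 1.2 | 3.4 | 3 | 2 | 0.4 | 0.8 | 0 |
123 | | 4 | 1.2 | 4.6 | 4 | 3 | 0.6 | 0.6 | 0 |
124 | | 5 | 1.4 | 6.0 | 6 | 4 | 0.0 | 0.4 | 1 (to slot 5) |
125 | 
126 | Each of the six output slots therefore integrates exactly `beta = 1.0` of weight: slot 0 gets `0.4 + 0.6`, slot 1 gets one full extra fire, slot 2 gets `0.2 + 0.8`, slot 3 gets `0.4 + 0.6`, slot 4 gets `0.6 + 0.4`, and slot 5 gets the other extra fire; the `0.0` scattered to slot 6 belongs to the (here empty) tail.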
--------------------------------------------------------------------------------
/benchmark/benchmark.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 | from torch import Tensor
4 | 
5 | import torch.utils.benchmark as benchmark
6 | 
7 | 
8 | def lengths_to_padding_mask(lens):
9 |     bsz, max_lens = lens.size(0), torch.max(lens).item()
10 |     mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
11 |     mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
12 |     return mask
13 | 
14 | 
15 | def cif_sequential_ref(
16 |     input: Tensor,
17 |     alpha: Tensor,
18 |     beta: float = 1.0,
19 |     tail_thres: float = 0.5,
20 |     padding_mask: Optional[Tensor] = None,
21 |     target_lengths: Optional[Tensor] = None,
22 |     eps: float = 1e-4,
23 | ) -> Tensor:
24 |     B, S, C = input.size()
25 | 
26 |     if padding_mask is not None:
27 |         alpha = alpha.masked_fill(padding_mask, 0)
28 | 
29 |     if target_lengths is not None:
30 |         feat_lengths = target_lengths.long()
31 |         desired_sum = beta * target_lengths.type_as(input) + eps
32 |         alpha_sum = alpha.sum(1)
33 |         alpha = alpha * (desired_sum / alpha_sum).unsqueeze(1)
34 |         T = feat_lengths.max()
35 |     else:
36 |         alpha_sum = alpha.sum(1)
37 |         feat_lengths = (alpha_sum / beta).floor().long()
38 |         T = feat_lengths.max()
39 | 
40 |     output = input.new_zeros((B, T + 1, C))
41 |     delay = input.new_zeros((B, T + 1))
42 | 
43 |     if padding_mask is not None:
44 |         source_lengths = (~padding_mask).sum(-1).long()
45 |     else:
46 |         source_lengths = input.new_full((B,), S, dtype=torch.long)
47 | 
48 |     # for b in range(B):
49 |     assert B == 1
50 |     b = 0
51 | 
52 |     csum = 0
53 |     src_idx = 0
54 |     dst_idx = 0
55 |     tail_idx = 0
56 |     while src_idx < source_lengths[b]:
57 |         if csum + alpha[b, src_idx] < beta:
58 |             csum += alpha[b, src_idx]
59 |             output[b, dst_idx] += alpha[b, src_idx] * input[b, src_idx]
60 |             delay[b, dst_idx] += alpha[b, src_idx] * (1 + src_idx) / beta
61 |             tail_idx = dst_idx
62 |             alpha[b, src_idx] = 0
63 |             src_idx += 1
64 |         else:
65 |             fire_w = beta - csum
66 |             alpha[b, src_idx] -= fire_w
67 |             output[b, dst_idx] += fire_w * input[b, src_idx]
68 |             delay[b, dst_idx] += fire_w * (1 + src_idx) / beta
69 |             tail_idx = dst_idx
70 |             csum = 0
71 |             dst_idx += 1
72 | 
73 |     if csum >= tail_thres:
74 |         output[b, tail_idx] *= beta / csum
75 |     else:
76 |         output[b, tail_idx:] = 0
77 | 
78 |     # tail handling
79 |     if (target_lengths is not None) or output[:, T, :].eq(0).all():
80 |         # training time -> ignore tail
81 |         output = output[:, :T, :]
82 |         delay = delay[:, :T]
83 | 
84 |     return output, delay
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     # B, S, T, C = 256, 3072, 512, 256
89 |     B, S, T, C = 1, 1024, 512, 256
90 |     beta = 0.5
91 | 
92 |     # device
93 |     device = torch.device("cuda:0")
94 |     # inputs
95 |     input = torch.rand(B, S, C, device=device)
96 |     alpha = torch.randn((B, S), device=device).sigmoid_()
97 |     source_lengths = torch.full((B,), S, device=device)
98 |     target_lengths = torch.full((B,), T, device=device)
99 | 
100 |     padding_mask = lengths_to_padding_mask(source_lengths)
101 | 
102 |     globals = {
103 |         'input': input,
104 |         'alpha': alpha,
105 |         'beta': beta,
106 |         'padding_mask': padding_mask,
107 |         'target_lengths': target_lengths,
108 |     }
109 | 
110 |     num_threads = torch.get_num_threads()
111 |     print(f'Benchmarking on {num_threads} threads')
112 | 
113
| t1 = benchmark.Timer( 114 | stmt='cif_function(input,alpha,beta,padding_mask=padding_mask,target_lengths=target_lengths)', 115 | setup='from torch_cif import cif_function', 116 | globals=globals, 117 | num_threads=num_threads 118 | ) 119 | 120 | t0 = benchmark.Timer( 121 | stmt='cif_sequential_ref(input,alpha,beta,padding_mask=padding_mask,target_lengths=target_lengths)', 122 | setup='from __main__ import cif_sequential_ref', 123 | globals=globals, 124 | num_threads=num_threads 125 | ) 126 | 127 | print(t1.timeit(1000)) 128 | print(t0.timeit(1000)) 129 | -------------------------------------------------------------------------------- /tests/test_cif.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from typing import Optional 4 | from torch import Tensor 5 | from torch_cif import cif_function 6 | 7 | import hypothesis.strategies as st 8 | from hypothesis import assume, given, settings 9 | from torch.testing._internal.common_utils import TestCase 10 | 11 | TEST_CUDA = torch.cuda.is_available() 12 | 13 | 14 | def lengths_to_padding_mask(lens): 15 | bsz, max_lens = lens.size(0), torch.max(lens).item() 16 | mask = torch.arange(max_lens).to(lens.device).view(1, max_lens) 17 | mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens) 18 | return mask 19 | 20 | 21 | class TestCIF(TestCase): 22 | def _test_cif_ref( 23 | self, 24 | input: Tensor, 25 | alpha: Tensor, 26 | beta: float = 1.0, 27 | tail_thres: float = 0.5, 28 | padding_mask: Optional[Tensor] = None, 29 | target_lengths: Optional[Tensor] = None, 30 | eps: float = 1e-4, 31 | ) -> Tensor: 32 | B, S, C = input.size() 33 | 34 | if padding_mask is not None: 35 | alpha = alpha.masked_fill(padding_mask, 0) 36 | 37 | if target_lengths is not None: 38 | feat_lengths = target_lengths.long() 39 | desired_sum = beta * target_lengths.type_as(input) + eps 40 | alpha_sum = alpha.sum(1) 41 | alpha = alpha * (desired_sum / alpha_sum).unsqueeze(1) 42 | T = feat_lengths.max() 43 | else: 44 | alpha_sum = alpha.sum(1) 45 | feat_lengths = (alpha_sum / beta).floor().long() 46 | T = feat_lengths.max() 47 | 48 | output = input.new_zeros((B, T + 1, C)) 49 | delay = input.new_zeros((B, T + 1)) 50 | 51 | if padding_mask is not None: 52 | source_lengths = (~padding_mask).sum(-1).long() 53 | else: 54 | source_lengths = input.new_full((B,), S, dtype=torch.long) 55 | 56 | for b in range(B): 57 | csum = 0 58 | src_idx = 0 59 | dst_idx = 0 60 | tail_idx = 0 61 | while src_idx < source_lengths[b]: 62 | if csum + alpha[b, src_idx] < beta: 63 | csum += alpha[b, src_idx] 64 | output[b, dst_idx] += alpha[b, src_idx] * input[b, src_idx] 65 | delay[b, dst_idx] += alpha[b, src_idx] * (1 + src_idx) / beta 66 | tail_idx = dst_idx 67 | alpha[b, src_idx] = 0 68 | src_idx += 1 69 | else: 70 | fire_w = beta - csum 71 | alpha[b, src_idx] -= fire_w 72 | output[b, dst_idx] += fire_w * input[b, src_idx] 73 | delay[b, dst_idx] += fire_w * (1 + src_idx) / beta 74 | tail_idx = dst_idx 75 | csum = 0 76 | dst_idx += 1 77 | 78 | if csum >= tail_thres: 79 | output[b, tail_idx] *= beta / csum 80 | else: 81 | output[b, tail_idx:] = 0 82 | 83 | # tail handling 84 | if (target_lengths is not None) or output[:, T, :].eq(0).all(): 85 | # training time -> ignore tail 86 | output = output[:, :T, :] 87 | delay = delay[:, :T] 88 | 89 | return output, delay 90 | 91 | def _test_custom_cif_impl( 92 | self, *args, **kwargs 93 | ): 94 | return cif_function(*args, **kwargs) 95 | 96 | @settings(deadline=None) 97 | @given( 
98 |         B=st.integers(1, 10),
99 |         T=st.integers(1, 20),
100 |         S=st.integers(1, 200),
101 |         C=st.integers(1, 20),
102 |         beta=st.floats(0.5, 1.5),
103 |         device=st.sampled_from(["cpu", "cuda"]),
104 |     )
105 |     def test_cif(self, B, T, S, C, beta, device):
106 | 
107 |         assume(device == "cpu" or TEST_CUDA)
108 | 
109 |         # device
110 |         device = torch.device(device)
111 |         # inputs
112 |         input = torch.rand(B, S, C, device=device)
113 |         alpha = torch.randn((B, S), device=device).sigmoid_()
114 |         # source_lengths = torch.full((B,), S, device=device)
115 |         # target_lengths = torch.full((B,), T, device=device)
116 |         source_lengths = torch.randint(1, S + 1, (B,), device=device)
117 |         target_lengths = torch.randint(1, T + 1, (B,), device=device)
118 | 
119 |         source_lengths = (source_lengths * S / source_lengths.max()).long()
120 |         target_lengths = (target_lengths * T / target_lengths.max()).long()
121 | 
122 |         padding_mask = lengths_to_padding_mask(source_lengths)
123 | 
124 |         # train
125 |         y, dy = self._test_cif_ref(
126 |             input,
127 |             alpha,
128 |             beta,
129 |             padding_mask=padding_mask,
130 |             target_lengths=target_lengths
131 |         )
132 | 
133 |         x_out = self._test_custom_cif_impl(
134 |             input,
135 |             alpha,
136 |             beta,
137 |             padding_mask=padding_mask,
138 |             target_lengths=target_lengths
139 |         )
140 |         x = x_out['cif_out'][0]
141 |         dx = x_out['delays'][0]
142 |         torch.testing.assert_close(
143 |             x,
144 |             y,
145 |             atol=1e-3,
146 |             rtol=1e-3,
147 |         )
148 |         torch.testing.assert_close(
149 |             dx,
150 |             dy,
151 |             atol=1e-3,
152 |             rtol=1e-3,
153 |         )
154 | 
155 |         # test
156 |         y2, dy2 = self._test_cif_ref(
157 |             input,
158 |             alpha,
159 |             beta,
160 |             padding_mask=padding_mask
161 |         )
162 | 
163 |         x2_out = self._test_custom_cif_impl(
164 |             input,
165 |             alpha,
166 |             beta,
167 |             padding_mask=padding_mask
168 |         )
169 |         x2 = x2_out['cif_out'][0]
170 |         dx2 = x2_out['delays'][0]
171 |         torch.testing.assert_close(
172 |             x2,
173 |             y2,
174 |             atol=1e-3,
175 |             rtol=1e-3,
176 |         )
177 |         torch.testing.assert_close(
178 |             dx2,
179 |             dy2,
180 |             atol=1e-3,
181 |             rtol=1e-3,
182 |         )
183 | 
184 | 
185 | if __name__ == "__main__":
186 |     unittest.main()
187 | 
--------------------------------------------------------------------------------
/torch_cif/cif.py:
--------------------------------------------------------------------------------
1 | """
2 | This module provides the cif_function(...) method that implements parallel CIF,
3 | including training (with target lengths) and inference.
4 | - Author: Chih-Chiang Chang (github: George0828Zhang)
5 | """
6 | import torch
7 | from typing import Optional, Dict, List
8 | from torch import Tensor
9 | 
10 | 
11 | def cif_function(
12 |     inputs: Tensor,
13 |     alpha: Tensor,
14 |     beta: float = 1.0,
15 |     tail_thres: float = 0.5,
16 |     padding_mask: Optional[Tensor] = None,
17 |     target_lengths: Optional[Tensor] = None,
18 |     eps: float = 1e-4,
19 |     unbound_alpha: bool = False
20 | ) -> Dict[str, List[Tensor]]:
21 |     r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
22 |     https://arxiv.org/abs/1905.11235
23 | 
24 |     Shapes:
25 |         N: batch size
26 |         S: source (encoder) sequence length
27 |         C: source feature dimension
28 |         T: target sequence length
29 | 
30 |     Args:
31 |         inputs (Tensor): (N, S, C) Input features to be integrated.
32 |         alpha (Tensor): (N, S) Weights corresponding to each element in the
33 |             inputs. It is expected to be the output of a sigmoid function.
34 |         beta (float): the threshold used to determine firing.
35 |         tail_thres (float): the threshold used to determine firing for tail handling.
36 |         padding_mask (Tensor, optional): (N, S) A binary mask representing
37 |             padded elements in the inputs. 1 is padding, 0 is not.
38 |         target_lengths (Tensor, optional): (N,) Desired length of the targets
39 |             for each sample in the minibatch.
40 |         eps (float, optional): Epsilon to prevent underflow for divisions.
41 |             Default: 1e-4
42 |         unbound_alpha (bool, optional): If True, skip checking that 0 <= alpha <= 1.
43 | 
44 |     Returns -> Dict[str, List[Tensor]]: Key/values described below.
45 |         cif_out: (N, T, C) The output integrated from the source.
46 |         cif_lengths: (N,) The output length for each element in batch.
47 |         alpha_sum: (N,) The sum of alpha for each element in batch.
48 |             Can be used to compute the quantity loss.
49 |         delays: (N, T) The expected delay (in terms of source tokens) for
50 |             each target token in the batch.
51 |         tail_weights: (N,) During inference, the residual weight accumulated in the tail.
52 |         scaled_alpha: (N, S) alpha after applying weight scaling.
53 |         cumsum_alpha: (N, S) cumsum of alpha after scaling.
54 |         right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
55 |         right_weights: (N, S) right scatter weights.
56 |         left_indices: (N, S) left scatter indices.
57 |         left_weights: (N, S) left scatter weights.
58 |     """
59 |     B, S, C = inputs.size()
60 |     assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
61 |     assert not torch.isnan(alpha).any(), "Nan in alpha tensor."
62 |     assert unbound_alpha or (alpha.le(1.0 + eps).all() and alpha.ge(0.0 - eps).all()), (
63 |         "Incorrect values in alpha tensor"
64 |         ", 0.0 <= tensor <= 1.0"
65 |     )
66 | 
67 |     dtype = alpha.dtype
68 |     alpha = alpha.float()
69 |     if padding_mask is not None:
70 |         padding_mask = padding_mask.bool()
71 |         assert not padding_mask[:, 0].any(), "Expected right-padded inputs."
72 | alpha = alpha.masked_fill(padding_mask, 0) 73 | 74 | if target_lengths is not None: 75 | assert target_lengths.size() == (B,) 76 | feat_lengths = target_lengths.long() 77 | desired_sum = beta * target_lengths.type_as(inputs) + eps 78 | alpha_sum = alpha.sum(1) 79 | alpha = alpha * (desired_sum / alpha_sum).unsqueeze(1) 80 | T = feat_lengths.max() 81 | else: 82 | alpha_sum = alpha.sum(1) 83 | feat_lengths = (alpha_sum / beta).floor().long() 84 | T = feat_lengths.max() 85 | 86 | # aggregate and integrate 87 | csum = alpha.cumsum(-1) 88 | with torch.no_grad(): 89 | # indices used for scattering 90 | right_idx = (csum / beta).floor().long().clip(max=T) 91 | left_idx = right_idx.roll(1, dims=1) 92 | left_idx[:, 0] = 0 93 | 94 | # count # of fires from each source 95 | fire_num = right_idx - left_idx 96 | extra_weights = (fire_num - 1).clip(min=0) 97 | 98 | # The extra entry in last dim is for tail 99 | output = inputs.new_zeros((B, T + 1, C)) 100 | delay = inputs.new_zeros((B, T + 1)) 101 | source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(inputs) 102 | zero = alpha.new_zeros((1,)) 103 | 104 | # right scatter 105 | fire_mask = fire_num > 0 106 | right_weight = torch.where( 107 | fire_mask, 108 | csum - right_idx.type_as(alpha) * beta, 109 | zero 110 | ).type_as(inputs) 111 | output.scatter_add_( 112 | 1, 113 | right_idx.unsqueeze(-1).expand(-1, -1, C), 114 | right_weight.unsqueeze(-1) * inputs 115 | ) 116 | delay.scatter_add_( 117 | 1, 118 | right_idx, 119 | right_weight * source_range / beta 120 | ) 121 | 122 | # left scatter 123 | left_weight = ( 124 | alpha - right_weight - extra_weights.type_as(alpha) * beta 125 | ).type_as(inputs) 126 | output.scatter_add_( 127 | 1, 128 | left_idx.unsqueeze(-1).expand(-1, -1, C), 129 | left_weight.unsqueeze(-1) * inputs 130 | ) 131 | delay.scatter_add_( 132 | 1, 133 | left_idx, 134 | left_weight * source_range / beta 135 | ) 136 | 137 | # extra scatters 138 | if extra_weights.ge(0).any(): 139 | extra_steps = extra_weights.max().item() 140 | tgt_idx = left_idx 141 | src_feats = inputs * beta 142 | for _ in range(extra_steps): 143 | tgt_idx = (tgt_idx + 1).clip(max=T) 144 | # (B, S, 1) 145 | src_mask = (extra_weights > 0) 146 | output.scatter_add_( 147 | 1, 148 | tgt_idx.unsqueeze(-1).expand(-1, -1, C), 149 | src_feats * src_mask.unsqueeze(2) 150 | ) 151 | delay.scatter_add_( 152 | 1, 153 | tgt_idx, 154 | source_range * src_mask 155 | ) 156 | extra_weights -= 1 157 | 158 | # tail handling 159 | if target_lengths is not None: 160 | # training time -> ignore tail 161 | output = output[:, :T, :] 162 | delay = delay[:, :T] 163 | else: 164 | # find out contribution to output tail 165 | # note: w/o scaling, extra weight is all 0 166 | zero = right_weight.new_zeros((1,)) 167 | r_mask = right_idx == feat_lengths.unsqueeze(1) 168 | tail_weights = torch.where(r_mask, right_weight, zero).sum(-1) 169 | l_mask = left_idx == feat_lengths.unsqueeze(1) 170 | tail_weights += torch.where(l_mask, left_weight, zero).sum(-1) 171 | 172 | # a size (B,) mask that extends position that passed threshold. 
173 | extend_mask = tail_weights >= tail_thres 174 | 175 | # extend 1 fire and upscale the weights 176 | if extend_mask.any(): 177 | # (B, T, C), may have infs so need the mask 178 | upscale = ( 179 | torch.ones_like(output) 180 | .scatter( 181 | 1, 182 | feat_lengths.view(B, 1, 1).expand(-1, -1, C), 183 | beta / ( 184 | tail_weights 185 | .masked_fill(~extend_mask, beta) 186 | .view(B, 1, 1) 187 | .expand(-1, -1, C)), 188 | ) 189 | .detach() 190 | ) 191 | output *= upscale 192 | feat_lengths += extend_mask.long() 193 | T = feat_lengths.max() 194 | output = output[:, :T, :] 195 | delay = delay[:, :T] 196 | 197 | # a size (B, T) mask to erase weights 198 | tail_mask = torch.arange(T, device=output.device).unsqueeze(0) >= feat_lengths.unsqueeze(1) 199 | output[tail_mask] = 0 200 | 201 | return { 202 | "cif_out": [output], 203 | "cif_lengths": [feat_lengths], 204 | "alpha_sum": [alpha_sum.to(dtype)], 205 | "delays": [delay], 206 | "tail_weights": [tail_weights] if target_lengths is None else [], 207 | "scaled_alpha": [alpha], 208 | "cumsum_alpha": [csum], 209 | "right_indices": [right_idx], 210 | "right_weights": [right_weight], 211 | "left_indices": [left_idx], 212 | "left_weights": [left_weight], 213 | } 214 | --------------------------------------------------------------------------------