├── LICENSE
├── README.md
├── assets
    ├── MAS.csv
    ├── MAS.png
    ├── MAS_log.png
    └── memory_read_write.png
├── cython_monotonic_align
    ├── __init__.py
    ├── core.pyx
    └── setup.py
├── jit_monotonic_align
    └── __init__.py
├── setup.py
├── super_monotonic_align
    ├── __init__.py
    └── core.py
└── test.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Supertone Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Super-Monotonic-Alignment-Search

[![TechnicalReport](https://img.shields.io/badge/TechnicalReport-2409.07704-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2409.07704)

This repo contains [Triton-Lang](https://github.com/triton-lang/triton) and PyTorch implementations of monotonic alignment search (MAS), originally proposed in [Glow-TTS](https://arxiv.org/abs/2005.11129).
MAS is an effective algorithm for estimating the alignment between paired speech and text in a self-supervised manner.

![Image0](./assets/memory_read_write.png)


The authors of Glow-TTS noted:
> "The time complexity of the algorithm is O(T_{text} × T_{mel}). Even though the algorithm is difficult to parallelize, it runs efficiently on CPU without the need for GPU executions. In our experiments, it spends less than 20 ms on each iteration, which amounts to less than 2% of the total training time. Furthermore, we do not need MAS during inference, as the duration predictor is used to estimate the alignment."

However, we found three issues while using MAS.
1. MAS can be parallelized along the text-length dimension, whereas the original implementation uses nested loops.
2. CPU execution takes an inordinate amount of time for large inputs because large tensors must be copied between CPU and GPU.
3. The hard-coded `max_neg_val` of -1e9 is not negative enough to prevent alignment mismatches in the upper-diagonal region.

Therefore, we implemented a Triton kernel, `super_monotonic_align`, and PyTorch code, `jit_monotonic_align`, to accelerate MAS on GPU and avoid inter-device copies (the JIT v2 variant still backtracks on CPU).

# Requirements
1. PyTorch (tested with `torch==2.3.0+cu121`)
2. Triton-Lang (tested with `triton==2.3.0`)
3. Cython (optional, only needed for the benchmark; tested with `Cython==0.29.36`)

Please ensure you have these packages installed to run the code in this repository, as version checks are not enforced.

# How to use
1. Install super-monotonic-align
```
git clone git@github.com:supertone-inc/super-monotonic-align.git
cd super-monotonic-align; pip install -e ./
```
or
```
pip install git+https://github.com/supertone-inc/super-monotonic-align.git
```
2. Import `super_monotonic_align` and use it!
```python
import torch
from super_monotonic_align import maximum_path
...
# Note: the Triton kernel modifies `value` in place.
# If you need to keep the original `value`, clone it before calling maximum_path.
# B: batch_size, T: text_length, S: audio_length
value = torch.randn((B, T, S), dtype=torch.float32, device='cuda')
attn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')
# path: [B, T, S] tensor; its dtype is set by `dtype` (default: torch.float32)
path = maximum_path(value, attn_mask, dtype=torch.bool)
```

## Warning

Please **check your input shape** before use.

Thanks to [codeghees](https://github.com/codeghees) for raising this issue: our implementation uses the shape \[B, T, S\], identical to the Glow-TTS version, while the VITS implementation uses \[B, S, T\].

For now, we recommend transposing \[B, S, T\]-shaped inputs to \[B, T, S\]; we will soon release an option that supports \[B, S, T\] as well.
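
A minimal sketch of that transposition is shown below. So that it runs without a GPU, it does not call the Triton kernel; instead it uses a hypothetical NumPy reference, `maximum_path_ref`, which follows the same dynamic program as `cython_monotonic_align` (forward accumulation with `max_neg_val`, then backtracking from the last text token) and, like the Triton kernel, updates all text positions of one audio frame at once. `maximum_path_ref` and `maximum_path_bst` are illustrative names, not part of this package.

```python
import numpy as np


def maximum_path_ref(value, mask, max_neg_val=-1e32):
    """Hypothetical NumPy reference MAS (illustration only, not this package's API).

    value, mask: [B, T, S] arrays (text-major layout, as in Glow-TTS).
    Returns a {0, 1} int32 path of shape [B, T, S].
    """
    value = (value * mask).astype(np.float64)
    B, T, S = value.shape
    path = np.zeros((B, T, S), dtype=np.int32)
    t_x = mask.sum(1)[:, 0].astype(int)  # valid text length per item
    t_y = mask.sum(2)[:, 0].astype(int)  # valid audio length per item
    for b in range(B):
        tx, ty = t_x[b], t_y[b]
        v = value[b, :tx, :ty].copy()
        # First frame: only the start cell (0, 0) is reachable.
        v[1:, 0] = max_neg_val
        # Forward pass: one vectorized update over all text positions per frame.
        for j in range(1, ty):
            stay = v[:, j - 1]                                      # from (x, j-1)
            move = np.concatenate(([max_neg_val], v[:-1, j - 1]))   # from (x-1, j-1)
            v[:, j] += np.maximum(stay, move)
        # Backtracking from the last text token at the last valid frame.
        index = tx - 1
        for j in range(ty - 1, -1, -1):
            path[b, index, j] = 1
            if index > 0 and j > 0 and v[index, j - 1] < v[index - 1, j - 1]:
                index -= 1
    return path


def maximum_path_bst(value_bst, mask_bst):
    """Wrapper for VITS-style [B, S, T] inputs: transpose in, transpose the path back."""
    path_bts = maximum_path_ref(value_bst.transpose(0, 2, 1), mask_bst.transpose(0, 2, 1))
    return path_bts.transpose(0, 2, 1)  # back to [B, S, T]
```

With the package itself, the same idea would be `maximum_path(value.transpose(1, 2), attn_mask.transpose(1, 2))` followed by `path.transpose(1, 2)`. Because `maximum_path` calls `.contiguous()` on `value`, a transposed view is copied first, so the caller's tensor is left unmodified in that case.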

# Benchmark
Measured with batch size B = 16 and audio length S = 4T (see `test.py`); all times are in milliseconds.
```
MAS in ms:
         T      Triton       JIT_v1       JIT_v2       Cython
0    128.0    0.447488    83.742203    53.222176     8.819136
1    256.0    1.616896   155.424774   104.632477    43.533665
2    384.0    3.430400   325.307404   237.820435   136.257538
3    512.0    5.838848   439.984131   344.654236   304.981201
4    640.0    9.070592   532.910095   452.141907   462.405304
5    768.0   12.249088   655.960083   587.169739   488.272858
6    896.0   15.203328   557.997070   620.148315   863.919067
7   1024.0   19.778561   627.986450   815.933167  1299.567871
8   1152.0   33.276928   706.022400   968.533813  1467.056885
9   1280.0   39.800835   792.861694  1215.021240  1930.171509
10  1408.0   47.456257   903.750671  1289.656250  2231.598145
11  1536.0   59.238914   953.907227  1523.870972  2959.377930
12  1664.0   70.068741  1031.818237  2004.299438  3073.532471
13  1792.0   82.205696  1558.200317  2359.347900  3930.776367
14  1920.0   99.634689  1183.214600  2512.063477  4374.311035
15  2048.0  107.218948  1261.682739  2889.841797  7792.640137
```

Across the tested sizes, the Triton MAS implementation is 19 to 72 times faster than the Cython implementation. For large inputs (T ≥ 896 here), both PyTorch JIT implementations are also faster than Cython, especially v1, which involves no inter-device copying.
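
The speedup range and the point where the JIT versions overtake Cython can be recomputed from the table. This is a minimal sketch with the numbers copied from the table above; `ROWS` and `crossover` are illustrative names, not part of this repository.

```python
# Benchmark rows copied from the table above: (T, Triton, JIT_v1, JIT_v2, Cython), in ms.
ROWS = [
    (128, 0.447488, 83.742203, 53.222176, 8.819136),
    (256, 1.616896, 155.424774, 104.632477, 43.533665),
    (384, 3.430400, 325.307404, 237.820435, 136.257538),
    (512, 5.838848, 439.984131, 344.654236, 304.981201),
    (640, 9.070592, 532.910095, 452.141907, 462.405304),
    (768, 12.249088, 655.960083, 587.169739, 488.272858),
    (896, 15.203328, 557.997070, 620.148315, 863.919067),
    (1024, 19.778561, 627.986450, 815.933167, 1299.567871),
    (1152, 33.276928, 706.022400, 968.533813, 1467.056885),
    (1280, 39.800835, 792.861694, 1215.021240, 1930.171509),
    (1408, 47.456257, 903.750671, 1289.656250, 2231.598145),
    (1536, 59.238914, 953.907227, 1523.870972, 2959.377930),
    (1664, 70.068741, 1031.818237, 2004.299438, 3073.532471),
    (1792, 82.205696, 1558.200317, 2359.347900, 3930.776367),
    (1920, 99.634689, 1183.214600, 2512.063477, 4374.311035),
    (2048, 107.218948, 1261.682739, 2889.841797, 7792.640137),
]
COL_T, COL_TRITON, COL_JIT_V1, COL_JIT_V2, COL_CYTHON = range(5)  # column indices

# Speedup of Triton over Cython for each text length T.
speedups = {row[COL_T]: row[COL_CYTHON] / row[COL_TRITON] for row in ROWS}
t_min = min(speedups, key=speedups.get)
t_max = max(speedups, key=speedups.get)


def crossover(col):
    """Smallest T such that column `col` beats Cython at T and at every larger T."""
    t_cross = None
    for row in reversed(ROWS):  # walk from the largest T downwards
        if row[col] >= row[COL_CYTHON]:
            break
        t_cross = row[COL_T]
    return t_cross


# prints: Triton vs Cython: 19.7x (T=128) to 72.7x (T=2048)
print(f"Triton vs Cython: {speedups[t_min]:.1f}x (T={t_min}) to {speedups[t_max]:.1f}x (T={t_max})")
# prints: JIT_v1 beats Cython from T=896, JIT_v2 from T=896
print(f"JIT_v1 beats Cython from T={crossover(COL_JIT_V1)}, JIT_v2 from T={crossover(COL_JIT_V2)}")
```

The crossover is computed from the largest size downwards because JIT_v2 briefly beats Cython at T = 640 but falls behind again at T = 768.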

| ms in linear scale | ms in log scale |
|----------|----------|
| ![Image 1](./assets/MAS.png) | ![Image 2](./assets/MAS_log.png) |

## How to run benchmark
```bash
cd cython_monotonic_align; mkdir cython_monotonic_align; python setup.py build_ext --inplace
cd ..; pip install -e ./
python test.py
```

# References
This implementation uses code from the following repositories:
- [jaywalnut310's Official Glow-TTS Implementation](https://github.com/jaywalnut310/glow-tts)
- [OpenAI's Triton-Lang Tutorials](https://github.com/triton-lang/triton)
- [Tri Dao's FlashAttention (memory hierarchy)](https://github.com/Dao-AILab/flash-attention)

# Acknowledgement
This work is supported by Supertone Inc. and HYBE Corp.
We thank Jinhyeok Yang, Juheon Lee, Yechan Yu, Seunghoon Ji, Jacob Morton, Seungu Han, Sungho Lee, Joon Byun, and Hoon Heo of the Supertone research team, and Hyeong-Seok Choi of ElevenLabs.


# Authors
- Junhyeok Lee ([jlee843@jhu.edu](mailto:jlee843@jhu.edu))
- Hyeongju Kim ([hyeongju@supertone.ai](mailto:hyeongju@supertone.ai))

If this repository is useful for your research, please consider citing it (along with Glow-TTS or VITS)!
```bib
@article{supermas,
  title={{Super Monotonic Alignment Search}},
  author={Lee, Junhyeok and Kim, Hyeongju},
  journal={arXiv preprint arXiv:2409.07704},
  year={2024}
}
```

Feel free to create an issue if you encounter any problems or have any questions.

Additionally, [Supertone](https://supertone.ai) is hiring TTS researchers.
If you are interested, please check out our career opportunities!

--------------------------------------------------------------------------------
/assets/MAS.csv:
--------------------------------------------------------------------------------
T,Triton,JIT_v1,JIT_v2,Cython
128.0,0.4,83.7,53.2,8.8
256.0,1.6,155.4,104.6,43.5
384.0,3.4,325.3,237.8,136.3
512.0,5.8,440.0,344.7,305.0
640.0,9.1,532.9,452.1,462.4
768.0,12.2,656.0,587.2,488.3
896.0,15.2,558.0,620.1,863.9
1024.0,19.8,628.0,815.9,1299.6
1152.0,33.3,706.0,968.5,1467.1
1280.0,39.8,792.9,1215.0,1930.2
1408.0,47.5,903.8,1289.7,2231.6
1536.0,59.2,953.9,1523.9,2959.4
1664.0,70.1,1031.8,2004.3,3073.5
1792.0,82.2,1558.2,2359.3,3930.8
1920.0,99.6,1183.2,2512.1,4374.3
2048.0,107.2,1261.7,2889.8,7792.6

--------------------------------------------------------------------------------
/assets/MAS.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/supertone-inc/super-monotonic-align/9bb1cb3a6fbab27bbe6e566827b9548010c5dd51/assets/MAS.png
--------------------------------------------------------------------------------
/assets/MAS_log.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/supertone-inc/super-monotonic-align/9bb1cb3a6fbab27bbe6e566827b9548010c5dd51/assets/MAS_log.png
--------------------------------------------------------------------------------
/assets/memory_read_write.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/supertone-inc/super-monotonic-align/9bb1cb3a6fbab27bbe6e566827b9548010c5dd51/assets/memory_read_write.png
--------------------------------------------------------------------------------
/cython_monotonic_align/__init__.py:
--------------------------------------------------------------------------------
# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/__init__.py
import numpy as np
import torch
from .cython_monotonic_align.core import maximum_path_c


def maximum_path(value, mask):
  """ Cython optimised version.
  value: [b, t_x, t_y]
  mask: [b, t_x, t_y]
  """
  value = value * mask
  device = value.device
  dtype = value.dtype
  value = value.data.cpu().numpy().astype(np.float32)
  path = np.zeros_like(value).astype(np.int32)
  mask = mask.data.cpu().numpy()

  t_x_max = mask.sum(1)[:, 0].astype(np.int32)
  t_y_max = mask.sum(2)[:, 0].astype(np.int32)
  maximum_path_c(path, value, t_x_max, t_y_max)
  return torch.from_numpy(path).to(device=device, dtype=dtype)
--------------------------------------------------------------------------------
/cython_monotonic_align/core.pyx:
--------------------------------------------------------------------------------
# copied from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/core.pyx
import numpy as np
cimport numpy as np
cimport cython
from cython.parallel import prange


@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
  cdef int x
  cdef int y
  cdef float v_prev
  cdef float v_cur
  cdef float tmp
  cdef int index = t_x - 1

  for y in range(t_y):
    for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
      if x == y:
        v_cur = max_neg_val
      else:
        v_cur = value[x, y-1]
      if x == 0:
        if y == 0:
          v_prev = 0.
        else:
          v_prev = max_neg_val
      else:
        v_prev = value[x-1, y-1]
      value[x, y] = max(v_cur, v_prev) + value[x, y]

  for y in range(t_y - 1, -1, -1):
    path[index, y] = 1
    if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
      index = index - 1


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e32) nogil:
  cdef int b = values.shape[0]

  cdef int i
  for i in prange(b, nogil=True):
    maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)

--------------------------------------------------------------------------------
/cython_monotonic_align/setup.py:
--------------------------------------------------------------------------------
# modified from https://github.com/jaywalnut310/glow-tts/blob/master/monotonic_align/setup.py

from distutils.core import setup
from Cython.Build import cythonize
import numpy

setup(
  name = 'cython_monotonic_align',
  ext_modules = cythonize("core.pyx"),
  include_dirs=[numpy.get_include()]
)
--------------------------------------------------------------------------------
/jit_monotonic_align/__init__.py:
--------------------------------------------------------------------------------
import torch


@torch.no_grad()
@torch.jit.script
def maximum_path1(logp: torch.Tensor, attn_mask: torch.Tensor):
    # logp: [B, Tx, Ty], attn_mask: [B, Tx, Ty]
    B, Tx, Ty = logp.size()
    device = logp.device
    logp = logp * attn_mask # [B, Tx, Ty]
    path = torch.zeros_like(logp) # [B, Tx, Ty]
    max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)

    x_len = attn_mask[:, :, 0].sum(dim=1).long() # [B]
    y_len = attn_mask[:, 0, :].sum(dim=1).long() # [B]

    for b in range(B):
        path[b, x_len[b] - 1, y_len[b] - 1] = 1

    # logp to cumulative logp
    logp[:, 1:, 0] = max_neg_val

    for ty in range(1, Ty):
        logp_prev_frame_1 = logp[:, :, ty - 1] # [B, Tx]
        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1) # [B, Tx]
        logp_prev_frame_2[:, 0] = max_neg_val
        logp_prev_frame_max = torch.where(logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1, logp_prev_frame_2)
        logp[:, :, ty] += logp_prev_frame_max

    ids = torch.ones_like(x_len, device=device) * (x_len - 1) # [B]
    arange = torch.arange(B, device=device)
    path = path.permute(2, 0, 1).contiguous() # [Ty, B, Tx]
    attn_mask = attn_mask.permute(2, 0, 1).contiguous() # [Ty, B, Tx]
    y_len_minus_1 = y_len - 1 # [B]
    for ty in range(Ty - 1, 0, -1):
        logp_prev_frame_1 = logp[:, :, ty - 1] # [B, Tx]
        logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1) # [B, Tx]
        logp_prev_frame_2[:, 0] = max_neg_val
        direction = torch.where(logp_prev_frame_1 > logp_prev_frame_2, 0, -1) # [B, Tx]
        gathered_dir = torch.gather(direction, 1, ids.view(-1, 1)).view(-1) # [B]
        gathered_dir.masked_fill_(ty > y_len_minus_1, 0)
        ids.add_(gathered_dir)
        path[ty - 1, arange, ids] = 1
    path *= attn_mask
    path = path.permute(1, 2, 0) # [B, Tx, Ty]
    return path


@torch.no_grad()
def maximum_path2(logp: torch.Tensor, attn_mask: torch.Tensor):
    @torch.jit.script
    def cumulative_logp(logp, attn_mask):
        B, Tx, Ty = logp.size()
        device = logp.device
        logp = logp * attn_mask # [B, Tx, Ty]
        path = torch.zeros_like(logp) # [B, Tx, Ty]
        max_neg_val = torch.tensor(-1e32, dtype=logp.dtype, device=device)

        x_len = attn_mask[:, :, 0].sum(dim=1).long() # [B]
        y_len = attn_mask[:, 0, :].sum(dim=1).long() # [B]

        for b in range(B):
            path[b, x_len[b] - 1, y_len[b] - 1] = 1

        # logp to cumulative logp
        logp[:, 1:, 0] = max_neg_val

        for ty in range(1, Ty):
            logp_prev_frame_1 = logp[:, :, ty - 1] # [B, Tx]
            logp_prev_frame_2 = torch.roll(logp_prev_frame_1, shifts=1, dims=1) # [B, Tx]
            logp_prev_frame_2[:, 0] = max_neg_val
            logp_prev_frame_max = torch.where(
                logp_prev_frame_1 > logp_prev_frame_2, logp_prev_frame_1, logp_prev_frame_2
            )
            logp[:, :, ty] += logp_prev_frame_max
        return logp, x_len, y_len, path

    device = logp.device
    logp, x_len, y_len, path = cumulative_logp(logp, attn_mask)
    B, Tx, Ty = logp.size()
    logp = logp.detach().cpu().numpy()
    x_len = x_len.detach().cpu().numpy()
    y_len = y_len.detach().cpu().numpy()
    path = path.detach().cpu().numpy()
    # backtracking (naive)
    for b in range(B):
        idx = x_len[b] - 1
        path[b, x_len[b] - 1, y_len[b] - 1] = 1
        for ty in range(y_len[b] - 1, 0, -1):
            if idx != 0 and logp[b, idx - 1, ty - 1] > logp[b, idx, ty - 1]:
                idx = idx - 1
            path[b, idx, ty - 1] = 1
    path = torch.from_numpy(path).to(device)
    return path

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
    name='super-monotonic-align',
    version='1.0.0',
    packages=find_packages(include=['super_monotonic_align', 'super_monotonic_align.*'])
)

--------------------------------------------------------------------------------
/super_monotonic_align/__init__.py:
--------------------------------------------------------------------------------
import torch
from super_monotonic_align.core import maximum_path_triton

@torch.no_grad()
def maximum_path(value, mask, dtype=torch.float32):
    """ Triton optimized version.
    value: [b, t_x, t_y]; may be modified in place by the kernel
    mask: [b, t_x, t_y]
    dtype: dtype of the returned path (default: torch.float32)
    """
    # make value contiguous (this copies only if it is not contiguous already)
    value = value.contiguous()
    # Use masked_fill_ to avoid new tensor creation
    value = value.masked_fill_(mask.logical_not(), 0)
    path = torch.zeros_like(value, dtype=dtype)
    t_x_max = mask.sum(1)[:, 0].to(torch.int32)
    t_y_max = mask.sum(2)[:, 0].to(torch.int32)
    path = maximum_path_triton(path, value, t_x_max, t_y_max)
    return path
--------------------------------------------------------------------------------
/super_monotonic_align/core.py:
--------------------------------------------------------------------------------
import torch

import triton
import triton.language as tl

@triton.jit
def maximum_path(
    path, value, t_x, t_y,
    B, T, S,
    max_neg_val,
    BLOCK_SIZE_X: tl.constexpr
):
    batch = tl.program_id(axis=0)
    path += batch * T * S
    value += batch * T * S
    x_length = tl.load(t_x + batch)
    y_length = tl.load(t_y + batch)
    offs_prev = tl.arange(0, BLOCK_SIZE_X)
    init = tl.where(offs_prev == 0, tl.load(value), max_neg_val)
    # for j in range(0,1,1): # set the first column to max_neg_val, except the start point (0, 0)
    tl.store(value + offs_prev * S, init, mask=offs_prev < x_length)
    for j in range(1, y_length, 1):
        v_cur = tl.load(value + (offs_prev) * S + (j-1), mask=(offs_prev < x_length), other=max_neg_val)
        v_prev = tl.load(value + (offs_prev-1) * S + (j-1), mask=(0 < offs_prev) & (offs_prev < x_length), other=max_neg_val)
        # compare v_cur and v_prev, and update v with larger value
        v = (tl.maximum(v_cur, v_prev) + tl.load(value + (offs_prev) * S + j, mask=(offs_prev < x_length)))
        tl.store(value + (offs_prev) * S + j, v, mask=(offs_prev < x_length))

    index = x_length-1
    for j in range(y_length-1,-1,-1):
        tl.store(path + (index) * S + j, 1)
        if (index > 0): # (index == j) is not checked due to max_neg_val init
            v_left = tl.load(value+ (index) * S+ j-1)#.to(tl.float32)
            v_leftdown = tl.load(value+(index-1) * S + j-1)#.to(tl.float32)
            if (v_left < v_leftdown):
                index += - 1


@torch.no_grad()
def maximum_path_triton(path, value, t_x, t_y, max_neg_val=-1e32):
    B,T,S = path.shape
    BLOCK_SIZE_X = max(triton.next_power_of_2(T), 16)
    num_warps = 1 # Must be 1 to prevent wrong output caused by slicing the operation
    with torch.cuda.device(value.device.index):
        maximum_path[(B, )](
            path, value, t_x, t_y,
            B, T, S,
            max_neg_val = max_neg_val,
            num_warps = num_warps,
            BLOCK_SIZE_X = BLOCK_SIZE_X)
    return path

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import torch
import triton
from super_monotonic_align import maximum_path as maximum_path_triton
from cython_monotonic_align import maximum_path as maximum_path_cython
from jit_monotonic_align import maximum_path1 as maximum_path_jit_v1
from jit_monotonic_align import maximum_path2 as maximum_path_jit_v2


def identical_test(B,T,S):
    value = torch.randn((B, T, S), dtype=torch.float32, device='cuda')
    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device='cuda')
    path_c = maximum_path_cython(value, attn_mask)
    path_jit1 = maximum_path_jit_v1(value, attn_mask)
    path_jit2 = maximum_path_jit_v2(value, attn_mask)
    path_tri = maximum_path_triton(value.clone(), attn_mask)

    # not 100% equal due to precision issue
    assert torch.allclose(path_c, path_tri, atol=1e-2, rtol=0), f"Failed on shape=({B,T,S})\n{path_c}\n{path_tri}\ndiff:{(path_c-path_tri).abs().sum()}"
    assert torch.allclose(path_c, path_jit1, atol=1e-2, rtol=0), f"Failed on shape=({B,T,S})\n{path_c}\n{path_jit1}\ndiff:{(path_c-path_jit1).abs().sum()}"
    assert torch.allclose(path_c, path_jit2, atol=1e-2, rtol=0), f"Failed on shape=({B,T,S})\n{path_c}\n{path_jit2}\ndiff:{(path_c-path_jit2).abs().sum()}"

# benchmark
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['T'],
        x_vals=[128 * i for i in range(1, 17)],
        line_arg='provider',
        line_vals= ['triton', 'jit_v1', 'jit_v2', 'cython'],
        line_names=['Triton', 'JIT_v1', 'JIT_v2', 'Cython'],
        styles=[('blue', '-'), ('green', '-'), ('red', '-'), ('orange', '-')],
        ylabel='ms',
        plot_name='MAS in ms',
        y_log=True,
        args={'B': 16},
    ))
def bench_mas(B, T, provider, device='cuda'):
    from cython_monotonic_align import maximum_path as maximum_path_cython
    # create data
    quantiles = [0.5, 0.2, 0.8]

    S = 4*T
    value = torch.randn((B, T, S), dtype=torch.float32, device=device)
    attn_mask = torch.ones((B, T, S), dtype=torch.int32, device=device)

    # utility functions
    if provider == 'triton':

        def y_fwd():
            return maximum_path_triton(value, attn_mask) # noqa: F811, E704

    if provider == 'cython':

        def y_fwd():
            return maximum_path_cython(value, attn_mask) # noqa: F811, E704

    if provider == 'jit_v1':

        def y_fwd():
            return maximum_path_jit_v1(value, attn_mask)

    if provider == 'jit_v2':

        def y_fwd():
            return maximum_path_jit_v2(value, attn_mask)

    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)

    return (ms), (max_ms), (min_ms)

if __name__ == "__main__":
    for (b,t,s) in [(32, 16, 16), (32, 128, 512), (32, 256, 1024), (32, 511, 2048)]:
        identical_test(b,t,s)
        print(f"Passed on shape=({b},{t},{s})")
    bench_mas.run(save_path='.', print_data=True)

--------------------------------------------------------------------------------