├── gfd
│   ├── __init__.py
│   ├── utils.py
│   ├── beam.py
│   ├── tokenizer.py
│   ├── model.py
│   └── gfd.py
├── tests
│   ├── conftest.py
│   ├── test_model.py
│   ├── test_kv_cache.py
│   ├── test_tokenizer.py
│   └── test_beam.py
├── config_files
│   ├── prompt
│   │   ├── fleurs-hk-prompt.yaml
│   │   ├── formosa-long-prompt.yaml
│   │   ├── ml-lecture-prompt.yaml
│   │   ├── noisy-librispeech-prompt.yaml
│   │   ├── ml-lecture-long-prompt.yaml
│   │   ├── atco2-asr-only-prompt.yaml
│   │   └── atco2-asr-prompt.yaml
│   └── model
│       ├── whisper-en.yaml
│       ├── whisper-zhtw.yaml
│       ├── gfd-asr-en.yaml
│       └── gfd-asr-zhtw.yaml
├── assets
│   └── teaser.png
├── setup.py
├── requirements.txt
├── benchmarks
│   ├── run_single_file.py
│   ├── run_benchmark.py
│   └── calculate_mer.py
├── .gitignore
├── LICENSE
└── README.md

/gfd/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/config_files/prompt/fleurs-hk-prompt.yaml:
--------------------------------------------------------------------------------
1 | asr_prompt: '以下是繁體中文轉錄文檔'  # English: "The following is a Traditional Chinese transcription document"
2 | llm_prompt: '以下是繁體中文轉錄文檔:'  # English: same text, with a trailing colon
--------------------------------------------------------------------------------
/config_files/prompt/formosa-long-prompt.yaml:
--------------------------------------------------------------------------------
1 | asr_prompt: '以下是繁體中文轉錄文檔'  # English: "The following is a Traditional Chinese transcription document"
2 | llm_prompt: '以下是繁體中文轉錄文檔:'
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtkresearch/generative-fusion-decoding/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/config_files/model/whisper-en.yaml:
--------------------------------------------------------------------------------
1 | asr_model_path: 'openai/whisper-large-v2'
2 | asr_device: 'cuda:0'
3 | lang: 'en'
--------------------------------------------------------------------------------
/config_files/model/whisper-zhtw.yaml:
--------------------------------------------------------------------------------
1 | asr_model_path: 'openai/whisper-large-v2'
2 | asr_device: 'cuda:0'
3 | lang: 'zh'
--------------------------------------------------------------------------------
/config_files/prompt/ml-lecture-prompt.yaml:
--------------------------------------------------------------------------------
1 | asr_prompt: '繁體中文'  # English: "Traditional Chinese"
2 | llm_prompt: '以下是繁體中文轉錄文檔,講者的講解存在code-switching,有些詞彙的語言是英文:'  # English: "The following is a Traditional Chinese transcription document; the speaker code-switches, and some terms are in English:"
--------------------------------------------------------------------------------
/config_files/prompt/noisy-librispeech-prompt.yaml:
--------------------------------------------------------------------------------
1 | asr_prompt: ''
2 | llm_prompt: 'The following is a transcription of a spoken sentence:'
--------------------------------------------------------------------------------
/config_files/prompt/ml-lecture-long-prompt.yaml:
--------------------------------------------------------------------------------
1 | asr_prompt: '以下是繁體中文轉錄文檔,講者的講解存在code-switching,有些詞彙的語言是英文'  # English: "The following is a Traditional Chinese transcription document; the speaker code-switches, and some terms are in English"
2 | llm_prompt: '以下是繁體中文轉錄文檔,講者的講解存在code-switching,有些詞彙的語言是英文:'
--------------------------------------------------------------------------------
/config_files/prompt/atco2-asr-only-prompt.yaml:
-------------------------------------------------------------------------------- 1 | asr_prompt: "Alfa Bravo Charlie Delta Echo Foxtrot Golf Hotel India Juliett Kilo Lima Mike November Oscar Papa Quebec Romeo Sierra Tango Uniform Victor Whiskey Xray Yankee Zulu One Two Three Four Five Six Seven Eight Nine Zero\nDayton radio, November One Two Three Four Five on one two two point two, over Springfield V-O-R, over.\nNew York Radio, Mooney Three One One Echo.\nColumbia Ground, Cessna Three One Six Zero Foxtrot, south ramp, I-F-R Memphis." 2 | llm_prompt: '' -------------------------------------------------------------------------------- /gfd/utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | from collections import namedtuple 4 | 5 | def process_config(file_path, args=argparse.Namespace()): 6 | with open(file_path, 'r') as file: 7 | config_dict = yaml.safe_load(file) 8 | 9 | for arg in vars(args): 10 | if getattr(args, arg) is not None and arg in config_dict.keys(): 11 | config_dict[arg] = getattr(args, arg) 12 | Config = namedtuple('Config', config_dict.keys()) 13 | return Config(**config_dict) 14 | 15 | def combine_config(config1, config2): 16 | CombinedConfig = namedtuple('Config', config1._fields + config2._fields) 17 | return CombinedConfig(*(config1+config2)) -------------------------------------------------------------------------------- /config_files/model/gfd-asr-en.yaml: -------------------------------------------------------------------------------- 1 | asr_model_path: 'openai/whisper-large-v2' 2 | llm_model_path: 'mistralai/Mistral-7B-v0.1' 3 | lang: 'en' 4 | asr_device: 'cuda:0' 5 | llm_device: 'cuda:0' 6 | 7 | seg_with_overlap: True 8 | use_cache: 'dynamic' 9 | fuse_strategy: 'simple' 10 | fusing_r: 0.2 11 | asr_attn_implementation: 'sdpa' 12 | llm_attn_implementation: NULL 13 | llm_temp: 1.7 14 | transcription_cutoff: 4000 15 | 16 | repetition_penalty: 2.0 17 | repetition_penalty_last: 50 18 | repetition_penalty_window: 50 19 | repetition_penalty_threshold: 1.0 20 | beam_terminated_strategy: 'when_all_end' 21 | beam_select_strategy: 'best' 22 | beam_max_decode_len: 448 23 | beam_max_len_diff: 20 24 | beam_max_len: -1 25 | beam_min_len: 9999 26 | logprob_min: -100000 27 | -------------------------------------------------------------------------------- /config_files/model/gfd-asr-zhtw.yaml: -------------------------------------------------------------------------------- 1 | asr_model_path: 'openai/whisper-large-v2' 2 | llm_model_path: 'MediaTek-Research/Breeze-7B-32k-Base-v1_0' 3 | lang: 'zh' 4 | asr_device: 'cuda:0' 5 | llm_device: 'cuda:0' 6 | 7 | seg_with_overlap: False 8 | use_cache: 'dynamic' 9 | fuse_strategy: 'simple' 10 | fusing_r: 0.2 11 | asr_attn_implementation: 'sdpa' 12 | llm_attn_implementation: Null 13 | llm_temp: 1.7 14 | transcription_cutoff: 500 15 | 16 | repetition_penalty: 2.0 17 | repetition_penalty_last: 50 18 | repetition_penalty_window: 50 19 | repetition_penalty_threshold: 1.0 20 | beam_terminated_strategy: 'when_all_end' 21 | beam_select_strategy: 'best' 22 | beam_max_decode_len: 448 23 | beam_max_len_diff: 20 24 | beam_max_len: -1 25 | beam_min_len: 9999 26 | logprob_min: -100000 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='gfd', 5 | version='0.0.1', 6 | license='Apache License 2.0', 7 | 
author='',
8 |     author_email='',
9 |     description='mtkresearch',
10 |     long_description=open('README.md').read(),
11 |     long_description_content_type='text/markdown',
12 |     url='',
13 |     packages=setuptools.find_packages(),
14 |     python_requires='>=3.6',
15 |     include_package_data=True,
16 |     classifiers=[
17 |         'Development Status :: 4 - Beta',
18 |         'Intended Audience :: Developers',
19 |         'Intended Audience :: Science/Research',
20 |         'License :: OSI Approved :: Apache Software License',
21 |         'Operating System :: OS Independent',
22 |         'Programming Language :: Python :: 3',
23 |         'Programming Language :: Python :: 3.6',
24 |         'Programming Language :: Python :: 3.7',
25 |     ],
26 |     install_requires=[
27 |     ],
28 |     extras_require={
29 |     }
30 | )
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | 
5 | import torch
6 | import numpy as np
7 | from transformers import RepetitionPenaltyLogitsProcessor
8 | import unittest
9 | from unittest.mock import patch, MagicMock
10 | 
11 | from gfd.tokenizer import LlamaByteTokenizer
12 | from gfd.model import BreezeByte
13 | 
14 | class TestBreezeByte(unittest.TestCase):
15 |     @patch('gfd.model.AutoModelForCausalLM.from_pretrained')
16 |     def test_get_logprob(self, MockModel):
17 |         # Mock the tokenizer
18 |         mock_tokenizer = MagicMock()
19 |         mock_tokenizer.get_alternative_ids.return_value = [[0], [1], [2], [3], [4]]
20 | 
21 |         # Mock the model
22 |         mock_model = MockModel.return_value
23 |         mock_model.device = "cuda:0"
24 |         mock_model.return_dict = True
25 | 
26 |         mock_logits = torch.randn(1, 5, 10).to(mock_model.device)
27 |         mock_model.return_value.logits = mock_logits
28 | 
29 |         # Mock config
30 |         mock_config = MagicMock()
31 |         mock_config.llm_model_path = 'fakepath/to/model'
32 |         mock_config.llm_device = 'cuda:0'
33 |         mock_config.llm_attn_implementation = 'default'
34 |         mock_config.llm_temp = 1.0
35 |         mock_config.repetition_penalty = 1.2
36 |         mock_config.repetition_penalty_threshold = 1.1
37 |         mock_config.repetition_penalty_last = 2
38 |         mock_config.repetition_penalty_window = 1
39 | 
40 |         # Initialize BreezeByte with the mock config
41 |         breeze_byte = BreezeByte(mock_config)
42 |         breeze_byte.llm = mock_model
43 | 
44 |         prefix_decoding_ids = [1, 2]
45 |         llm_ids = [3, 4, 5]
46 | 
47 |         # get_logprob returns a (logprob, normalizer_adjust_n) pair, so unpack both values
48 |         logprob, normalizer_adjust_n = breeze_byte.get_logprob(prefix_decoding_ids, llm_ids, mock_tokenizer)
49 |         self.assertIsInstance(logprob, float)
50 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.29.3
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | asttokens==2.4.1
5 | async-timeout==4.0.3
6 | attrs==23.2.0
7 | audioread==3.0.1
8 | certifi==2024.2.2
9 | cffi==1.16.0
10 | charset-normalizer==3.3.2
11 | click==8.1.7
12 | cmake==3.29.2
13 | decorator==5.1.1
14 | dill==0.3.8
15 | einops==0.7.0
16 | exceptiongroup==1.2.1
17 | executing==2.0.1
18 | filelock==3.13.4
19 | frozenlist==1.4.1
20 | fsspec==2024.3.1
21 | huggingface-hub==0.22.2
22 | idna==3.7
23 | # ipython==8.24.0
24 | jedi==0.19.1
25 | Jinja2==3.1.3
26 | jiwer==3.0.4
27 | joblib==1.4.0
28 | lazy_loader==0.4
29 | librosa==0.10.1
30 | lit==18.1.4
31 | # llvmlite==0.42.0
32 | MarkupSafe==2.1.5
33 | matplotlib-inline==0.1.7
34 | more-itertools==10.2.0
35 | mpmath==1.3.0
36 | msgpack==1.0.8
37 | multidict==6.0.5
38 | multiprocess==0.70.16
39 | 
# networkx==3.3 40 | ninja==1.11.1.1 41 | # numba==0.59.1 42 | numpy<2.0.0 43 | openai-whisper==20231117 44 | opencc-python-reimplemented==0.1.7 45 | packaging==24.0 46 | pandas==2.2.2 47 | parso==0.8.4 48 | pexpect==4.9.0 49 | platformdirs==4.2.1 50 | pooch==1.8.1 51 | prompt-toolkit==3.0.43 52 | psutil==5.9.8 53 | ptyprocess==0.7.0 54 | pure-eval==0.2.2 55 | pycparser==2.22 56 | Pygments==2.18.0 57 | python-dateutil==2.9.0.post0 58 | pytz==2024.1 59 | PyYAML==6.0.1 60 | rapidfuzz==3.9.0 61 | regex==2024.4.16 62 | safetensors==0.4.3 63 | scikit-learn==1.4.2 64 | scipy==1.13.0 65 | sentencepiece==0.2.0 66 | six==1.16.0 67 | sortedcontainers==2.4.0 68 | soundfile==0.12.1 69 | soxr==0.3.7 70 | stack-data==0.6.3 71 | sympy==1.12 72 | threadpoolctl==3.4.0 73 | tiktoken==0.6.0 74 | tokenizers==0.19.1 75 | torch 76 | tqdm==4.66.2 77 | traitlets==5.14.3 78 | transformers==4.40.1 79 | triton==2.1.0 80 | typing_extensions==4.11.0 81 | tzdata==2024.1 82 | urllib3==2.2.1 83 | wcwidth==0.2.13 84 | xxhash==3.4.1 85 | yarl==1.9.4 86 | datasets==2.19.1 87 | requests 88 | -------------------------------------------------------------------------------- /benchmarks/run_single_file.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | 5 | from gfd.gfd import Breezper 6 | from gfd.utils import process_config, combine_config 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser(description="Override config settings with command-line arguments.") 10 | parser.add_argument('--model_name', type=str, help='The model for testing the benchmark dataset') 11 | parser.add_argument('--setting', type=str, help='benchmark dataset settings for specified model') 12 | parser.add_argument('--audio_file_path', type=str, help='Path to the audio file sample') 13 | parser.add_argument('--result_output_path', type=str, help='Path to save dataset with predictions from the model') 14 | return parser.parse_args() 15 | 16 | def main(): 17 | setting_configs = { 18 | 'gfd': {'asr-en': process_config('config_files/model/gfd-asr-en.yaml'), 19 | 'asr-zhtw': process_config('config_files/model/gfd-asr-zhtw.yaml'), 20 | 'asr-en-lmoff': process_config('config_files/model/gfd-asr-en.yaml', args=argparse.Namespace(**{'fusing_r': 0.0})), 21 | 'asr-zhtw-lmoff': process_config('config_files/model/gfd-asr-zhtw.yaml', args=argparse.Namespace(**{'fusing_r': 0.0})) 22 | } 23 | } 24 | args = parse_args() 25 | setting_config = setting_configs[args.model_name][args.setting] 26 | if "en" in args.setting: 27 | prompt_config = process_config('config_files/prompt/noisy-librispeech-prompt.yaml') 28 | elif "zhtw" in args.setting: 29 | prompt_config = process_config('config_files/prompt/formosa-long-prompt.yaml') 30 | combined_config = combine_config(prompt_config, setting_config) 31 | 32 | model = Breezper(combined_config) 33 | result = model.get_transcription(args.audio_file_path, asr_prompt=combined_config.asr_prompt, llm_prompt=combined_config.llm_prompt) 34 | print(f'Result: {result}') 35 | 36 | with open(args.result_output_path, 'w') as f: 37 | json.dump(result, f, ensure_ascii=False) 38 | 39 | if __name__== '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /tests/test_kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import unittest 4 | from unittest.mock import MagicMock 5 | from transformers import 
AutoModelForCausalLM, Cache, DynamicCache 6 | 7 | from gfd.tokenizer import LlamaByteTokenizer 8 | 9 | class TestKVCache(unittest.TestCase): 10 | def setUp(self): 11 | self.device = 'cuda:0' 12 | self.llm = AutoModelForCausalLM.from_pretrained( 13 | 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T', 14 | device_map=self.device, 15 | torch_dtype=torch.float16, 16 | attn_implementation=None 17 | ) 18 | 19 | def test_kv_cache(self): 20 | with torch.no_grad(): 21 | ids_for_test = [1, 29537, 886, 275, 90, 154, 4548, 1154 , 4354, 3484 , 4844, 4848, 4] 22 | 23 | # expected logits 24 | expected_logits = self.llm( 25 | input_ids=torch.tensor([ids_for_test], device=self.device), 26 | use_cache=True, 27 | return_dict=True).logits 28 | 29 | cache = DynamicCache() 30 | step = 2 31 | for i in range(step, len(ids_for_test), step): 32 | model_kwargs = self.llm.prepare_inputs_for_generation( 33 | torch.tensor([ids_for_test[:i]], device=self.device), 34 | past_key_values=cache, 35 | attention_mask=None, 36 | inputs_embeds=None, 37 | cache_position=None, 38 | use_cache=True, 39 | ) 40 | 41 | outputs = self.llm(**model_kwargs) 42 | cache = outputs.past_key_values 43 | last_logits = outputs.logits 44 | 45 | expected_last_logits = expected_logits[:,i-step:i,:] 46 | 47 | torch.testing.assert_close(last_logits.cpu().argmax(-1), expected_last_logits.cpu().argmax(-1), atol=1e-6, rtol=0.0) 48 | torch.testing.assert_close(torch.topk(last_logits.cpu(), 2), torch.topk(expected_last_logits.cpu(), 2), atol=1e-6, rtol=0.0) 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /gfd/beam.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | 5 | 6 | DecodingBeam = namedtuple( 7 | 'DecodingBeam', 8 | ['asr_score', 'llm_score', 'fuse_score', 'reach_end', 9 | 'asr_ids', 'llm_ids', 'asr_prefix_ids', 'llm_prefix_ids', 'asr_logprob', 'llm_logprob'] 10 | ) 11 | 12 | class BeamsControler: 13 | def __init__(self, config, n_beam, asr_eos_id): 14 | self.config = config 15 | self.n_beam = n_beam 16 | self.asr_eos_id = asr_eos_id 17 | 18 | self.beams = [] 19 | self._next_beams = [] 20 | 21 | def list(self): 22 | return self.beams 23 | 24 | def add(self, asr_score, llm_score, fuse_score, 25 | asr_prefix_ids, asr_ids, asr_logprob, 26 | llm_prefix_ids, llm_ids, llm_logprob): 27 | 28 | reach_end = (asr_ids[-1] == self.asr_eos_id) or \ 29 | len(asr_prefix_ids + asr_ids) > self.config.beam_max_decode_len 30 | 31 | beam = DecodingBeam( 32 | asr_score=asr_score, 33 | llm_score=llm_score, 34 | fuse_score=fuse_score, 35 | reach_end=reach_end, 36 | asr_ids=asr_ids, 37 | llm_ids=llm_ids, 38 | asr_prefix_ids=asr_prefix_ids, 39 | llm_prefix_ids=llm_prefix_ids, 40 | asr_logprob=asr_logprob, 41 | llm_logprob=llm_logprob 42 | ) 43 | self._next_beams.append(beam) 44 | 45 | def add_beam(self, beam): 46 | self._next_beams.append(beam) 47 | 48 | def update(self): 49 | self.beams = sorted(self._next_beams, key=lambda beam: beam.fuse_score, reverse=True)[:self.n_beam] 50 | self._next_beams = [] 51 | 52 | def is_terminated(self): 53 | min_len = self.config.beam_min_len 54 | max_len = self.config.beam_max_len 55 | for beam in self.beams: 56 | min_len = min(min_len, len(beam.asr_ids)) 57 | max_len = max(max_len, len(beam.asr_ids)) 58 | if max_len - min_len > self.config.beam_max_len_diff: 59 | return True 60 | 61 | if self.config.beam_terminated_strategy == 'when_all_end': 62 | return all([beam.reach_end for 
beam in self.beams])
63 |         else:
64 |             raise NotImplementedError()
65 | 
66 |     def get_result(self, asr_tokenizer):
67 |         if self.config.beam_select_strategy == 'best':
68 |             transcription = asr_tokenizer.decode(self.beams[0].asr_ids, skip_special_tokens=True)
69 |         elif self.config.beam_select_strategy == 'longest':
70 |             max_len = 0
71 |             transcription = ''
72 |             for beam in self.beams:
73 |                 tmp = asr_tokenizer.decode(beam.asr_ids, skip_special_tokens=True)
74 |                 if len(tmp) > max_len:
75 |                     transcription = tmp
76 |                     max_len = len(tmp)
77 |         else:
78 |             raise NotImplementedError()
79 |         return transcription
80 | 
--------------------------------------------------------------------------------
/tests/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from collections import defaultdict
3 | 
4 | from gfd.tokenizer import LlamaByteTokenizer
5 | 
6 | @pytest.fixture(scope="class")
7 | def byte_tokenizer():
8 |     return LlamaByteTokenizer.from_pretrained("MediaTek-Research/Breeze-7B-Instruct-v1_0")
9 | 
10 | class TestLlamaByteTokenizer:
11 |     def test_tokenize_from_byte(self, byte_tokenizer):
12 |         byte_str = b'hello world'
13 |         expected_ids = byte_tokenizer.encode('hello world', add_special_tokens=False)
14 |         result = byte_tokenizer.tokenize_from_byte(byte_str)
15 | 
16 |         assert result == expected_ids
17 | 
18 |     def test_convert_ids_to_bytes_english_without_special_characters(self, byte_tokenizer):
19 |         input_ids = [6312, 28709, 1526] # hello world
20 |         tokens = byte_tokenizer.convert_ids_to_tokens(input_ids)
21 |         expected_bytes_lists = [b' hell', b'o', b' world']
22 |         result = byte_tokenizer.convert_ids_to_bytes(input_ids)
23 | 
24 |         assert result == expected_bytes_lists
25 | 
26 |     def test_convert_ids_to_bytes_english_with_special_characters(self, byte_tokenizer):
27 |         # Original Input String: Hello! I stayed up late last night and I felt like dying...?
28 | input_ids = [22557, 28808, 315, 10452, 582, 3909, 1432, 2125, 304, 315, 2770, 737, 13074, 1101, 28804] 29 | tokens = byte_tokenizer.convert_ids_to_tokens(input_ids) 30 | expected_bytes_lists = [b' Hello', b'!', b' I', b' stayed', b' up', b' late', b' last', b' night', 31 | b' and', b' I', b' felt', b' like', b' dying', b'...', b'?'] 32 | result = byte_tokenizer.convert_ids_to_bytes(input_ids) 33 | 34 | assert result == expected_bytes_lists 35 | 36 | 37 | def test_convert_ids_to_bytes_chinese_without_special_characters(self, byte_tokenizer): 38 | # Original Input String: 我每天都好累 39 | input_ids = [28705, 29242, 29513, 43136, 51557, 31719] 40 | tokens = byte_tokenizer.convert_ids_to_tokens(input_ids) 41 | expected_bytes_lists = [s.encode('utf-8') for s in tokens] 42 | result = byte_tokenizer.convert_ids_to_bytes(input_ids) 43 | 44 | assert result == expected_bytes_lists 45 | 46 | def test_convert_ids_to_bytes_chinese_with_special_characters(self, byte_tokenizer): 47 | # Original Input String: # '最近梅雨季,一直下雨真的超煩...!!!!別下了,真心拜託。' 48 | input_ids = [28705, 42529, 31223, 31115, 31740, 28924, 42405, 48282, 42398, 29800, 33781, 1101, 19010, 49 | 30798, 46562, 28924, 45930, 47542, 28944] 50 | tokens = byte_tokenizer.convert_ids_to_tokens(input_ids) 51 | expected_bytes_lists = [s.encode('utf-8') for s in tokens] 52 | result = byte_tokenizer.convert_ids_to_bytes(input_ids) 53 | 54 | assert result == expected_bytes_lists 55 | 56 | def test_get_matched_ids_from_prefix_result_matched(self, byte_tokenizer): 57 | byte_prefix_match = b'\xe8\x9f\x8b' 58 | matched_ids = byte_tokenizer.get_matched_ids_from_prefix(byte_prefix_match) 59 | expected_matched_ids = [61871] 60 | 61 | assert matched_ids == expected_matched_ids 62 | 63 | def test_get_matched_ids_from_prefix_result_no_matched(self, byte_tokenizer): 64 | byte_prefix_no_match = b'\xe3\x96\x87' 65 | matched_ids = byte_tokenizer.get_matched_ids_from_prefix(byte_prefix_no_match) 66 | expected_matched_ids = [] 67 | 68 | assert matched_ids == expected_matched_ids 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /gfd/tokenizer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from transformers import LlamaTokenizerFast, WhisperTokenizer 5 | from transformers.models.llama.tokenization_llama import SPIECE_UNDERLINE 6 | 7 | 8 | class ByteTokenizer: 9 | def tokenize_from_byte(self, byte_str): 10 | str_part = byte_str.decode('utf8', errors='ignore') 11 | return self(str_part, add_special_tokens=False).input_ids 12 | 13 | def convert_ids_to_bytes(self, ids): 14 | raise NotImplementedError 15 | 16 | def get_matched_ids_from_prefix(self, byte_prefix): 17 | if not hasattr(self, '_prefix_to_ids'): 18 | self._prefix_to_ids = defaultdict(list) 19 | for i in range(self.vocab_size): 20 | b = self.convert_ids_to_bytes(i) 21 | for j in range(1,len(b)): 22 | self._prefix_to_ids[b[:j]].append(i) 23 | 24 | return self._prefix_to_ids.get(byte_prefix, []) 25 | 26 | def get_alternative_ids(self, seq_ids): 27 | alternative_ids = [None] * len(seq_ids) 28 | prefix_from_last = b'' 29 | pointer_from_last = 1 30 | while pointer_from_last <= len(seq_ids): 31 | prefix_from_last = self.convert_ids_to_bytes(seq_ids[-pointer_from_last]) + prefix_from_last 32 | alternative_ids[-pointer_from_last] = self.get_matched_ids_from_prefix(prefix_from_last) 33 | pointer_from_last += 1 34 | 35 | return alternative_ids 36 | 37 | 38 | class 
LlamaByteTokenizer(LlamaTokenizerFast, ByteTokenizer): 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | self.bytetokens_to_ids = {} 42 | for s,i in self.get_vocab().items(): 43 | b = self._convert_token_to_byte(s) 44 | if b in self.bytetokens_to_ids: 45 | if self.bytetokens_to_ids[b] < i: 46 | self.bytetokens_to_ids[b] = i 47 | else: 48 | self.bytetokens_to_ids[b] = i 49 | 50 | def convert_ids_to_bytes(self, ids): 51 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=False) 52 | if isinstance(tokens, str): 53 | return self._convert_token_to_byte(tokens) 54 | return [self._convert_token_to_byte(t) for t in tokens] 55 | 56 | def _convert_token_to_byte(self, token): 57 | SPIECE_UNDERLINE = "▁" 58 | if token.startswith(SPIECE_UNDERLINE) and len(token) > 1: 59 | token = " " + token.lstrip(SPIECE_UNDERLINE) 60 | 61 | if token.startswith("<0x"): # '<0xAB>' -> 'AB' -> b'\xAB' 62 | bs = bytes.fromhex(f'{token[3:5]}') 63 | else: 64 | bs = token.encode("utf8") 65 | return bs 66 | 67 | def tokenize_from_byte(self, byte_str): 68 | str_part = byte_str.decode('utf8', errors='ignore') 69 | encoded_str_part = str_part.encode('utf8') 70 | 71 | str_part_tokenized = self(str_part, add_special_tokens=False).input_ids 72 | leftover_string = byte_str[len(encoded_str_part):] 73 | for byte_int in leftover_string: 74 | byte_character = bytes([byte_int]) 75 | str_part_tokenized.append(self.bytetokens_to_ids[byte_character]) 76 | 77 | return str_part_tokenized 78 | 79 | 80 | class WhisperByteTokenizer(WhisperTokenizer, ByteTokenizer): 81 | def __init__(self, *args, **kwargs): 82 | super().__init__(*args, **kwargs) 83 | 84 | def convert_ids_to_bytes(self, ids, skip_special_tokens=True): 85 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 86 | return [bytes([self.byte_decoder[c] for c in s]) for s in tokens] 87 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /tests/test_beam.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import numpy as np 4 | import unittest 5 | from unittest.mock import MagicMock 6 | 7 | from gfd.beam import BeamsControler, DecodingBeam 8 | 9 | 10 | class MockTokenizer: 11 | def decode(self, ids, skip_special_tokens=True): 12 | return ''.join(map(str, ids)) 13 | 14 | class TestBeamsControler(unittest.TestCase): 15 | def setUp(self): 16 | # Create a mock config with necessary attributes 17 | self.mock_config = MagicMock() 18 | self.mock_config.beam_max_decode_len = 10 19 | self.mock_config.beam_min_len = 1 20 | self.mock_config.beam_max_len = 10 21 | self.mock_config.beam_max_len_diff = 20 22 | self.mock_config.beam_terminated_strategy = 'when_all_end' 23 | self.mock_config.beam_select_strategy = 'best' 24 | 25 | # Initialize BeamsControler with the mock config 26 | self.controller = BeamsControler(config=self.mock_config, n_beam=3, asr_eos_id=1) 27 | 28 | def test_add(self): 29 | self.controller.add( 30 | asr_score=0.5, llm_score=0.5, fuse_score=1.0, 31 | asr_prefix_ids=[1, 2], asr_ids=[3, 4], asr_logprob=-10, 32 | llm_prefix_ids=[5, 6], llm_ids=[7, 8], llm_logprob=-20 33 | ) 34 | 35 | beam = self.controller._next_beams[0] 36 | assert len(self.controller._next_beams) == 1 37 | assert beam.asr_score == 0.5 38 | assert beam.llm_score == 0.5 39 | assert beam.fuse_score == 1.0 40 | assert beam.reach_end == False 41 | 42 | def test_add_beam(self): 43 | beam = DecodingBeam( 44 | asr_score=0.5, llm_score=0.5, fuse_score=1.0, reach_end=False, 45 | asr_ids=[3, 4], llm_ids=[7, 8], asr_prefix_ids=[1, 2], llm_prefix_ids=[5, 6], 46 | asr_logprob=-10, llm_logprob=-20 47 | ) 48 | 49 | self.controller.add_beam(beam) 50 | 51 | assert len(self.controller._next_beams) == 1 52 | assert self.controller._next_beams[0] == beam 53 | 54 | def test_update(self): 55 | beam1 = DecodingBeam(asr_score=0.6, llm_score=0.4, fuse_score=0.5, reach_end=False, 56 | asr_ids=[1, 2], llm_ids=[3, 4], asr_prefix_ids=[], llm_prefix_ids=[], 57 | asr_logprob=-10, llm_logprob=-20) 58 | beam2 = DecodingBeam(asr_score=0.7, llm_score=0.5, fuse_score=0.6, reach_end=False, 59 | asr_ids=[1, 2], llm_ids=[3, 4], asr_prefix_ids=[], llm_prefix_ids=[], 60 | asr_logprob=-10, llm_logprob=-20) 61 | self.controller._next_beams = [beam1, beam2] 62 | 63 | self.controller.update() 64 | 65 | assert len(self.controller.beams) == 2 66 | assert self.controller.beams[0] == beam2 67 | assert self.controller.beams[1] == beam1 68 | 69 | def test_is_terminated_is_True_when_all_beams_reach_end(self): 70 | beam1 = DecodingBeam(0.5, 0.5, 1.0, True, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 71 | beam2 = DecodingBeam(0.6, 0.4, 1.0, True, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [], [], -10, -10) 72 | beam3 = DecodingBeam(0.7, 0.3, 1.0, True, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [], [], -10, -10) 73 | self.controller.add_beam(beam1) 74 | self.controller.add_beam(beam2) 75 | self.controller.add_beam(beam3) 76 | 77 | self.controller.update() 78 | 79 | assert self.controller.is_terminated() == True 80 | 81 | def test_is_terminated_is_True_exceed_length_diff(self): 82 | beam1 = DecodingBeam(0.5, 0.5, 1.0, False, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 83 | beam2 = DecodingBeam(0.6, 0.4, 1.0, False, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [], [], -10, -10) 84 | beam3 = DecodingBeam(0.7, 0.3, 1.0, False, [i for 
i in range(1, 30)], [i for i in range(1, 30)], [], [], -10, -10) 85 | self.controller.add_beam(beam1) 86 | self.controller.add_beam(beam2) 87 | self.controller.add_beam(beam3) 88 | 89 | self.controller.update() 90 | 91 | assert self.controller.is_terminated() == True 92 | 93 | def test_is_terminated_is_False(self): 94 | beam1 = DecodingBeam(0.5, 0.5, 1.0, False, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 95 | beam2 = DecodingBeam(0.6, 0.4, 1.0, False, [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [], [], -10, -10) 96 | beam3 = DecodingBeam(0.7, 0.3, 1.0, False, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [], [], -10, -10) 97 | self.controller.add_beam(beam1) 98 | self.controller.add_beam(beam2) 99 | self.controller.add_beam(beam3) 100 | 101 | self.controller.update() 102 | 103 | assert self.controller.is_terminated() == False 104 | 105 | def test_get_result_select_strategy_best(self): 106 | self.mock_config.beam_select_strategy = 'best' 107 | beam1 = DecodingBeam(0.5, 0.5, 0.7, True, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 108 | beam2 = DecodingBeam(0.6, 0.4, 0.9, True, [4, 5, 6], [4, 5, 6], [], [], -10, -10) 109 | beam3 = DecodingBeam(0.7, 0.3, 0.8, True, [7, 8, 9], [7, 8, 9], [], [], -10, -10) 110 | self.controller.add_beam(beam1) 111 | self.controller.add_beam(beam2) 112 | self.controller.add_beam(beam3) 113 | self.controller.update() 114 | 115 | mock_asr_tokenizer = MockTokenizer() 116 | result = self.controller.get_result(mock_asr_tokenizer) 117 | 118 | assert result == '456' 119 | 120 | def test_get_result_select_strategy_longest(self): 121 | self.mock_config.beam_select_strategy = 'longest' 122 | beam1 = DecodingBeam(0.5, 0.5, 1.0, True, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 123 | beam2 = DecodingBeam(0.6, 0.4, 1.0, True, [4, 5, 6, 7, 8], [4, 5, 6, 7, 8], [], [], -10, -10) 124 | beam3 = DecodingBeam(0.7, 0.3, 1.0, True, [7, 8, 9], [7, 8, 9], [], [], -10, -10) 125 | 126 | self.controller.add_beam(beam1) 127 | self.controller.add_beam(beam2) 128 | self.controller.add_beam(beam3) 129 | self.controller.update() 130 | 131 | mock_asr_tokenizer = MockTokenizer() 132 | result = self.controller.get_result(mock_asr_tokenizer) 133 | 134 | assert result == '45678' 135 | 136 | def test_get_result_select_strategy_unsupported(self): 137 | self.mock_config.beam_select_strategy = 'unsupported' 138 | beam1 = DecodingBeam(0.5, 0.5, 1.0, True, [1, 2, 3], [1, 2, 3], [], [], -10, -10) 139 | self.controller.add_beam(beam1) 140 | self.controller.update() 141 | mock_asr_tokenizer = MockTokenizer() 142 | 143 | with self.assertRaises(NotImplementedError): 144 | self.controller.get_result(mock_asr_tokenizer) -------------------------------------------------------------------------------- /benchmarks/run_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import argparse 5 | import subprocess 6 | from copy import deepcopy 7 | 8 | from datasets import load_dataset, load_from_disk 9 | import numpy as np 10 | from transformers import pipeline 11 | 12 | from gfd.gfd import Breezper 13 | from gfd.utils import process_config, combine_config 14 | 15 | class Evaluator: 16 | def __init__(self, model_name, config, transcription_column_name, temp_output_dir): 17 | self.config = config 18 | self.transcription_column_name = transcription_column_name 19 | self.temp_output_dir = temp_output_dir 20 | self.model_name = model_name 21 | 22 | if self.model_name == 'gfd': 23 | self.model = Breezper(config) 24 | if self.model_name 
== 'whisper':
25 |             self.model = pipeline(task='automatic-speech-recognition',
26 |                                   model=self.config.asr_model_path,
27 |                                   device=self.config.asr_device
28 |                                   )
29 | 
30 |     def evaluate(self, example, idx):
31 |         if self.model_name == 'gfd':
32 |             return self.get_breezper_prediction(example, idx)
33 |         elif self.model_name == 'whisper':
34 |             return self.get_whisper_beam_prediction(example, idx)
35 |         else:
36 |             raise ValueError("Unsupported model type. Please choose 'gfd' or 'whisper'.")
37 | 
38 |     def get_breezper_prediction(self, example, idx):
39 |         if os.path.exists(os.path.join(self.temp_output_dir, f'prediction_{idx}.json')):
40 |             with open(os.path.join(self.temp_output_dir, f'prediction_{idx}.json'), 'r') as f:
41 |                 js = json.load(f)
42 |                 example['prediction'] = js["prediction"]
43 |             return example
44 | 
45 |         breezper_transcription = self.model.get_transcription(example['audio']['array'], asr_prompt=self.config.asr_prompt,
46 |                                                               llm_prompt=self.config.llm_prompt)
47 |         example['prediction'] = breezper_transcription
48 |         result = {
49 |             'id': idx,
50 |             'transcription': example[self.transcription_column_name],
51 |             'prediction': breezper_transcription
52 |         }
53 | 
54 |         with open(os.path.join(self.temp_output_dir, f'prediction_{idx}.json'), 'w') as f:
55 |             json.dump(result, f, ensure_ascii=False)
56 | 
57 |         return example
58 | 
59 |     def get_whisper_beam_prediction(self, example, idx):
60 |         whisper_transcription = self.model(example['audio'], generate_kwargs={'task': 'transcribe',
61 |                                            'num_beams': 5, 'language': f'<|{self.config.lang}|>'})['text']
62 |         example['prediction'] = whisper_transcription
63 |         result = {
64 |             'id': idx,
65 |             'transcription': example[self.transcription_column_name],
66 |             'prediction': whisper_transcription
67 |         }
68 | 
69 |         with open(os.path.join(self.temp_output_dir, f'prediction_{idx}.json'), 'w') as f:
70 |             json.dump(result, f, ensure_ascii=False)
71 | 
72 |         return example
73 | 
74 | 
75 | def test_benchmark(ds, model_name, config, transcription_column_name, output_dir):
76 |     temp_result_dir = f'{output_dir}/temp_results'
77 |     # exist_ok also covers the case where output_dir exists but temp_results does not
78 |     os.makedirs(output_dir, exist_ok=True)
79 |     os.makedirs(temp_result_dir, exist_ok=True)
80 |     evaluator = Evaluator(model_name, config, transcription_column_name, temp_result_dir)
81 |     ds = ds.map(evaluator.evaluate, with_indices=True)
82 |     ds.save_to_disk(os.path.join(output_dir, 'ds_result'))
83 | 
84 | 
85 | def parse_args():
86 |     parser = argparse.ArgumentParser(description="Run Benchmark datasets.")
87 |     parser.add_argument('--dataset_name', type=str, help='The benchmark dataset for testing')
88 |     parser.add_argument('--model_name', type=str, help='The model for testing the benchmark dataset')
89 |     parser.add_argument('--setting', type=str, help='benchmark dataset settings for specified model')
90 |     parser.add_argument('--output_dir', type=str, default='results/', help='Directory to save results of the model output')
91 | 
92 |     return parser.parse_args()
93 | 
94 | def main():
95 |     setting_configs = {
96 |         'gfd': {'asr-en': process_config('config_files/model/gfd-asr-en.yaml'),
97 |                 'asr-zhtw': process_config('config_files/model/gfd-asr-zhtw.yaml'),
98 |                 'asr-en-lmoff': process_config('config_files/model/gfd-asr-en.yaml', args=argparse.Namespace(**{'fusing_r': 0.0})),
99 |                 'asr-zhtw-lmoff': process_config('config_files/model/gfd-asr-zhtw.yaml', args=argparse.Namespace(**{'fusing_r': 0.0}))
100 |                 },
101 |         'whisper': { 'whisper-en': process_config('config_files/model/whisper-en.yaml'),
102 |                      'whisper-zhtw': 
process_config('config_files/model/whisper-zhtw.yaml') 103 | } 104 | } 105 | 106 | args = parse_args() 107 | 108 | setting_config = setting_configs[args.model_name][args.setting] 109 | if args.dataset_name == 'ml-lecture-2021-long': 110 | prompt_config = process_config('config_files/prompt/ml-lecture-long-prompt.yaml') 111 | combined_config = combine_config(prompt_config, setting_config) 112 | ds = load_dataset('generative-fusion-decoding/ml-lecture-2021-long', split='test') # Use Internet Download 113 | # ds = load_from_disk('../benchmark_dataset/ml_lecture_long', keep_in_memory = True) 114 | transcription_column_name = 'transcription' 115 | elif args.dataset_name == 'formosa-long': 116 | prompt_config = process_config('config_files/prompt/formosa-long-prompt.yaml') 117 | combined_config = combine_config(prompt_config, setting_config) 118 | ds = load_dataset('Mediatek-Research/formosaspeech', split='test') # Use Internet Download 119 | # ds = load_from_disk('../../../data/speech/formosaspeech/test', keep_in_memory = True) 120 | transcription_column_name = 'text' 121 | elif args.dataset_name == 'fleurs-hk': 122 | prompt_config = process_config('config_files/prompt/fleurs-hk-prompt.yaml') 123 | combined_config = combine_config(prompt_config, setting_config) 124 | ds = load_dataset('google/fleurs', 'yue_hant_hk', split='test') # Use Internet Download 125 | # ds = load_from_disk('../../../data/speech/fleurs/yue_hant_hk/test', keep_in_memory = True) 126 | transcription_column_name = 'transcription' 127 | elif args.dataset_name.startswith('noisy-librispeech'): 128 | signal_to_noise_ratio = args.dataset_name.split('-')[-1] 129 | prompt_config = process_config('config_files/prompt/noisy-librispeech-prompt.yaml') 130 | combined_config = combine_config(prompt_config, setting_config) 131 | ds = load_dataset("distil-whisper/librispeech_asr-noise", "test-pub-noise")[signal_to_noise_ratio] # Use Internet Download 132 | transcription_column_name = 'text' 133 | elif args.dataset_name == 'atco2': 134 | prompt_config = process_config('config_files/prompt/atco2-asr-prompt.yaml') 135 | combined_config = combine_config(prompt_config, setting_config) 136 | ds = load_dataset("jlvdoorn/atco2-asr", split = "train") # Use Internet Download 137 | transcription_column_name = 'text' 138 | 139 | if args.model_name == 'gfd': 140 | test_benchmark(ds, args.model_name, combined_config, transcription_column_name, args.output_dir) 141 | elif args.model_name == 'whisper': 142 | test_benchmark(ds, args.model_name, setting_config, transcription_column_name, args.output_dir) 143 | 144 | if __name__== '__main__': 145 | main() 146 | -------------------------------------------------------------------------------- /benchmarks/calculate_mer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import regex 4 | import re 5 | 6 | import json 7 | import jiwer 8 | import pandas as pd 9 | import hanzidentifier 10 | from opencc import OpenCC 11 | from whisper.normalizers import EnglishTextNormalizer 12 | from datasets import load_from_disk 13 | import Levenshtein 14 | 15 | cc = OpenCC("s2t") 16 | normalizer = EnglishTextNormalizer() 17 | greek_to_phonetic = { 18 | 'α': 'alpha', 19 | 'β': 'beta', 20 | 'γ': 'gamma', 21 | 'δ': 'delta', 22 | 'ε': 'epsilon', 23 | 'ζ': 'zeta', 24 | 'η': 'eta', 25 | 'θ': 'theta', 26 | 'ι': 'iota', 27 | 'κ': 'kappa', 28 | 'λ': 'lambda', 29 | 'μ': 'mu', 30 | 'ν': 'nu', 31 | 'ξ': 'xi', 32 | 'ο': 'omicron', 33 | 'π': 'pi', 34 | 'ρ': 'rho', 35 | 
'σ': 'sigma', 36 | 'τ': 'tau', 37 | 'υ': 'upsilon', 38 | 'φ': 'phi', 39 | 'χ': 'chi', 40 | 'ψ': 'psi', 41 | 'ω': 'omega', 42 | 'Α': 'Alpha', 43 | 'Β': 'Beta', 44 | 'Γ': 'Gamma', 45 | 'Δ': 'Delta', 46 | 'Ε': 'Epsilon', 47 | 'Ζ': 'Zeta', 48 | 'Η': 'Eta', 49 | 'Θ': 'Theta', 50 | 'Ι': 'Iota', 51 | 'Κ': 'Kappa', 52 | 'Λ': 'Lambda', 53 | 'Μ': 'Mu', 54 | 'Ν': 'Nu', 55 | 'Ξ': 'Xi', 56 | 'Ο': 'Omicron', 57 | 'Π': 'Pi', 58 | 'Ρ': 'Rho', 59 | 'Σ': 'Sigma', 60 | 'Τ': 'Tau', 61 | 'Υ': 'Upsilon', 62 | 'Φ': 'Phi', 63 | 'Χ': 'Chi', 64 | 'Ψ': 'Psi', 65 | 'Ω': 'Omega', 66 | # Add lowercase and uppercase variants for all Greek letters 67 | } 68 | 69 | def separate_text(x): 70 | newx = "" 71 | for i,c in enumerate(x): 72 | if ord(c) > 1000: 73 | newx += " " + c + " " 74 | elif c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" and x[i-1] not in "ABCDEFGHIJKLMNOPQRSTYVWXYZ": 75 | newx += " " + c 76 | else: 77 | newx += c 78 | newx = re.sub(r'\s+',' ',newx) 79 | newx = newx.strip() 80 | 81 | return newx 82 | 83 | 84 | def demathify(x): 85 | for k,v in greek_to_phonetic.items(): 86 | x = x.replace(k," "+v+" ") 87 | x = x.replace("*", " star ") 88 | x = x.replace("′", " prime ") 89 | x = x.replace("%", " percent ") 90 | 91 | return x 92 | 93 | 94 | def clean_repeating_end(text, repeat_threshold): 95 | x = list(text)[::-1] 96 | for k in range(1,len(x)//repeat_threshold): 97 | repeat = x[0:k] * repeat_threshold 98 | motiv = x[0:k] 99 | 100 | if repeat == x[0:k * repeat_threshold]: 101 | print(repeat) 102 | while x[0:k] == motiv: 103 | x = x[k:] 104 | idx = -1 105 | while len(x) and x[-1] == motiv[idx]: 106 | x = x[-1:] 107 | idx -= 1 108 | x = motiv + x 109 | ret = " ".join(x[::-1]) 110 | print(ret) 111 | return ret 112 | 113 | return text 114 | 115 | 116 | def predscleaner(x): 117 | x = x.strip() 118 | if "<|transcribe|>" in x: 119 | x = x.split("<|transcribe|>")[1] 120 | x = re.sub('<[^>]*>','', x) 121 | l = len(x) 122 | while True: 123 | x = clean_repeating_end(x, 3) 124 | if len(x) == l: 125 | break 126 | l = len(x) 127 | x = demathify(x) 128 | x = re.sub(r'[^\w\s]',' ',x) 129 | 130 | x = x.strip() 131 | if hanzidentifier.is_simplified(x): 132 | x = cc.convert(x) 133 | x = x.lower() 134 | x = separate_text(x) 135 | x = normalizer(x) 136 | 137 | return x 138 | 139 | 140 | def goldcleaner_ml(x): 141 | x = demathify(x) 142 | x = re.sub(r'[^\w\s]','',x) 143 | x = x.strip() 144 | x = x.lower() 145 | x = normalizer(x) 146 | 147 | return x 148 | 149 | 150 | def goldcleaner_formosa(x): 151 | x = demathify(x) 152 | x = re.sub(r'[^\w\s]','',x) 153 | x = x.strip() 154 | if hanzidentifier.is_simplified(x): 155 | x = cc.convert(x) 156 | x = x.lower() 157 | x = normalizer(x) 158 | x = separate_text(x) 159 | 160 | return x 161 | 162 | 163 | def goldcleaner_librispeech(x): 164 | x = re.sub(r'[^\w\s]','',x) 165 | x = x.strip() 166 | x = x.lower() 167 | x = normalizer(x) 168 | 169 | return x 170 | 171 | def predscleaner_acto2(x): 172 | x = x.strip() 173 | if "<|transcribe|>" in x: 174 | x = x.split("<|transcribe|>")[1] 175 | x = re.sub('<[^>]*>','', x) 176 | l = len(x) 177 | while True: 178 | x = clean_repeating_end(x, 5) 179 | if len(x) == l: 180 | break 181 | l = len(x) 182 | x = re.sub(r'[^\w\s]',' ',x) 183 | 184 | x = x.strip() 185 | x = x.lower() 186 | x = normalizer(x) 187 | 188 | return x 189 | 190 | 191 | # def edit_distance_with_whitespace(transcription, prediction): 192 | # def compute_edit_distance(s1, s2): 193 | # m = len(s1) 194 | # n = len(s2) 195 | # dp = [[0 for x in range(n + 1)] for x in range(m + 1)] 196 | 197 | # for i in 
range(m + 1):
198 | #             for j in range(n + 1):
199 | #                 if i == 0:
200 | #                     dp[i][j] = j
201 | #                 elif j == 0:
202 | #                     dp[i][j] = i
203 | #                 elif s1[i - 1] == s2[j - 1]:
204 | #                     dp[i][j] = dp[i - 1][j - 1]
205 | #                 else:
206 | #                     dp[i][j] = 1 + min(dp[i][j - 1], dp[i - 1][j], dp[i - 1][j - 1])
207 | 
208 | #         return dp[m][n]
209 | 
210 | #     # Function to find English words in a string
211 | #     def find_english_words(s):
212 | #         return [match.span() for match in re.finditer(r'\b[a-zA-Z]+(?:\s+[a-zA-Z]+)*\b', s)]
213 | 
214 | #     # Compute the standard edit distance
215 | #     standard_edit_distance = compute_edit_distance(transcription, prediction)
216 | 
217 | #     # Find English words in both strings
218 | #     transcription_english_words = find_english_words(transcription)
219 | #     prediction_english_words = find_english_words(prediction)
220 | 
221 | #     # Check for whitespace insertion
222 | #     index_to_add_whitespace = []
223 | #     index_to_delete_whitespace = []
224 | 
225 | #     # Check for whitespace insertion in prediction
226 | #     for start, end in prediction_english_words:
227 | #         for i in range(start + 1, end):
228 | #             if prediction[i] != ' ':
229 | #                 new_prediction = prediction[:i] + ' ' + prediction[i:]
230 | #                 if standard_edit_distance > compute_edit_distance(new_prediction, transcription):
231 | #                     index_to_add_whitespace.append(i)
232 | 
233 | #     # Check for whitespace deletion in prediction
234 | #     for start, end in prediction_english_words:
235 | #         for i in range(start, end - 1):
236 | #             if prediction[i] == ' ':
237 | #                 new_prediction = prediction[:i] + prediction[i+1:]
238 | #                 if standard_edit_distance > compute_edit_distance(new_prediction, transcription):
239 | #                     index_to_delete_whitespace.append(i)
240 | 
241 | #     # Sort the indexes
242 | #     index_to_add_whitespace.sort()
243 | #     index_to_delete_whitespace.sort()
244 | 
245 | #     # Apply deletions first
246 | #     new_prediction = list(prediction)
247 | #     offset = 0
248 | #     for i in index_to_delete_whitespace:
249 | #         del new_prediction[i - offset]
250 | #         offset += 1
251 | 
252 | #     # Apply additions
253 | #     offset = 0
254 | #     for i in index_to_add_whitespace:
255 | #         new_prediction.insert(i + offset, ' ')
256 | #         offset += 1
257 | 
258 | #     new_prediction = ''.join(new_prediction)
259 | 
260 | #     return new_prediction
261 | 
262 | 
263 | def parse_args():
264 |     parser = argparse.ArgumentParser(description="Evaluate Benchmark Result.")
265 |     parser.add_argument('--dataset_name', type=str, help='The benchmark dataset for testing')
266 |     parser.add_argument('--output_dir', type=str, help='The output directory for benchmark dataset')
267 | 
268 |     return parser.parse_args()
269 | 
270 | args = parse_args()
271 | if args.dataset_name == 'ml-lecture-2021-long':
272 |     transcription_column_name = 'transcription'
273 |     goldcleaner_function = goldcleaner_ml
274 | elif args.dataset_name == 'formosa-long':
275 |     transcription_column_name = 'text'
276 |     goldcleaner_function = goldcleaner_formosa
277 | elif args.dataset_name == 'fleurs-hk':
278 |     transcription_column_name = 'transcription'
279 |     goldcleaner_function = goldcleaner_formosa
280 | elif args.dataset_name.startswith('noisy-librispeech'):
281 |     transcription_column_name = 'text'
282 |     goldcleaner_function = goldcleaner_librispeech
283 | elif args.dataset_name == 'atco2':
284 |     transcription_column_name = 'text'
285 |     predscleaner = predscleaner_acto2
286 |     goldcleaner_function = goldcleaner_librispeech
287 | 
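# Illustrative sketch (hypothetical strings): after cleaning, separate_text()
# splits CJK text into single characters while English stays word-level, so the
# jiwer.wer call below effectively computes a mixed error rate (MER).
#
#     gold = goldcleaner_formosa('今天天氣很好 it is sunny')   # CJK chars become space-separated units
#     pred = predscleaner('今天天氣很好 it was sunny')
#     mer = jiwer.wer([gold], [pred])                          # one word-level substitution: 'is' -> 'was'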
288 | # Try to get the dataset with predictions; otherwise fall back to the per-sample json files
289 | use_dataset = True
290 | try:
291 |     ds = load_from_disk(os.path.join(args.output_dir, 'ds_result'))
292 | except Exception:
293 |     print('Load from disk failed.')
294 |     use_dataset = False
295 | 
296 | if use_dataset:
297 |     try:
298 |         sub_ds = ds.select_columns([transcription_column_name, 'prediction'])
299 |         df = sub_ds.to_pandas()
300 |     except Exception:
301 |         print('Extracting dataset failed. Column name error.')
302 |         use_dataset = False
303 | 
304 | if not use_dataset: # use json
305 |     all_samples = []
306 |     path = os.path.join(args.output_dir, 'temp_results')
307 |     for f in os.listdir(path):
308 |         fpath = os.path.join(path, f)
309 |         with open(fpath, "r") as ff:
310 |             js = json.load(ff)
311 |             all_samples.append(js)
312 |     df = pd.DataFrame.from_dict(all_samples)
313 |     print('Number of samples loaded from json: ', len(df))
314 | 
315 | df["preds"] = [predscleaner(x) for x in df["prediction"]]
316 | df["gold"] = [goldcleaner_function(x) for x in df[transcription_column_name]]
317 | 
318 | wer = jiwer.wer(list(df["gold"]), list(df["preds"]))
319 | print(wer)
320 | 
--------------------------------------------------------------------------------
/gfd/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | from collections import Counter
5 | 
6 | import torch
7 | import numpy as np
8 | from transformers import AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor, DynamicCache
9 | 
10 | DEBUG = 0
11 | 
12 | class KVCache:
13 |     def __init__(self):
14 |         self._id_to_cache = {}
15 |         self._id_to_logits = {}
16 |         self._query_count = Counter()
17 |         self._newly_added = set()
18 | 
19 |     def add(self, decode_ids, kv, logits):
20 |         if isinstance(decode_ids, torch.Tensor):
21 |             decode_ids = tuple(decode_ids.squeeze().tolist())
22 | 
23 |         if decode_ids not in self._id_to_cache:
24 |             self._id_to_cache[decode_ids] = kv
25 |             self._id_to_logits[decode_ids] = logits
26 |             self._query_count[decode_ids] = 0
27 |             self._newly_added.add(decode_ids)
28 | 
29 |     def query(self, decode_ids):
30 |         # Return the entry with the longest matching prefix of decode_ids and increment its query count
31 |         if isinstance(decode_ids, torch.Tensor):
32 |             decode_ids = tuple(decode_ids.squeeze().tolist())
33 | 
34 |         for i in range(len(decode_ids), 0, -1):
35 |             prefix = decode_ids[:i]
36 |             if prefix in self._id_to_cache:
37 |                 self._query_count[prefix] += 1
38 |                 return self._id_to_cache[prefix], self._id_to_logits[prefix], len(prefix)
39 | 
40 |         return None, None, 0
41 | 
42 |     def remove_unused(self):
43 |         # Delete entries that were neither queried nor newly added since the last cleanup
44 |         keys_to_remove = [key for key in self._id_to_cache.keys() if self._query_count[key] == 0 and key not in self._newly_added]
45 |         if DEBUG >= 2:
46 |             print(keys_to_remove)
47 |             print(self._query_count)
48 |             print(self._newly_added)
49 |         for key in keys_to_remove:
50 |             del self._id_to_cache[key]
51 |             del self._id_to_logits[key]
52 |             del self._query_count[key]
53 |         # Reset the query bookkeeping for the next round
54 |         self._query_count = Counter()
55 |         self._newly_added = set()
56 |         torch.cuda.empty_cache()
57 |         if DEBUG >= 1:
58 |             print(len(self._id_to_cache))
59 |         if DEBUG >= 2:
60 |             print(self._id_to_cache.keys())
61 | 
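# Illustrative usage sketch for KVCache (hypothetical ids and cached objects):
# the beam-search code looks up the longest cached prefix, registers newly
# computed prefixes, and periodically evicts entries that no beam touched.
#
#     cache = KVCache()
#     kv, logits, matched = cache.query((1, 2, 3))       # miss -> (None, None, 0)
#     cache.add((1, 2, 3), kv=past_key_values, logits=logits)
#     kv, logits, matched = cache.query((1, 2, 3, 4))    # longest-prefix hit -> matched == 3
#     cache.remove_unused()                              # drop entries neither queried nor newly added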
62 | 
63 | class ByteModel:
64 |     def _get_logprob(self, standard_ids, alternative_ids, logits, start_token_id=1):
65 |         # standard_seq           [A B C D E F]
66 |         # standard_seq_probs     [1, B_prob, C_prob, D_prob, E_prob, F_prob]
67 |         # alternative_ids        [        [D_alts][E_alts][F_alts]]
68 |         # nseq_vocab_probs       (bs, (D_lg, E_lg) n_vocab)
69 |         # shielded_token_length   x x x x
70 | 
71 |         shifted_standard_ids = standard_ids[0, 1:]
72 |         shifted_alternative_ids = alternative_ids[1:]
73 |         shifted_logits = logits[0, :-1, :]
74 |         shifted_start_token_id = start_token_id - 1
75 |         all_logprobs = torch.log(torch.softmax(shifted_logits, dim=-1))
76 | 
77 |         standard_logprobs = all_logprobs[
78 |             torch.arange(0, shifted_standard_ids.size(0)),
79 |             shifted_standard_ids
80 |         ]
81 |         standard_logprobs[:shifted_start_token_id] = 0.
82 | 
83 |         rolling_logprobs = torch.cumsum(standard_logprobs, dim=0)
84 | 
85 |         logprob_collect = []
86 |         for i in range(shifted_start_token_id, len(all_logprobs)):
87 |             if len(shifted_alternative_ids[i]) == 0:
88 |                 continue
89 |             alternatives = torch.tensor(shifted_alternative_ids[i],
90 |                                         device=self.device, dtype=torch.long)
91 | 
92 |             prev_logprob = rolling_logprobs[i-1]
93 |             selected_logprobs = all_logprobs[i, alternatives]
94 |             logprob_collect.append(prev_logprob + torch.logsumexp(selected_logprobs, dim=0))
95 |         logprob_collect.append(rolling_logprobs[-1])
96 |         logprob_collect_stacked = torch.stack(logprob_collect)
97 | 
98 |         # length to normalize depends on the max contributed sequence length
99 |         normalizer_adjust_n = min(torch.argmax(logprob_collect_stacked) - len(logprob_collect_stacked) + 2, 0)
100 | 
101 |         logprob = torch.logsumexp(logprob_collect_stacked, dim=0)
102 | 
103 |         return logprob, normalizer_adjust_n
104 | 
105 | 
106 | class BreezeByte(ByteModel):
107 |     def __init__(self, config):
108 |         self.config = config
109 |         self.llm = AutoModelForCausalLM.from_pretrained(
110 |             self.config.llm_model_path,
111 |             device_map=self.config.llm_device,
112 |             torch_dtype=torch.float16,
113 |             attn_implementation=self.config.llm_attn_implementation
114 |         )
115 |         self.device = self.llm.device
116 |         self.kv_cache = KVCache()
117 |         self.static_cache = None
118 | 
119 |     def _process_prompt_in_batches(self, static_decoding_ids, batch_size=256):  # chunked prefill; the batch_size default is assumed
120 |         past_key_values = None
121 |         for i in range(0, len(static_decoding_ids), batch_size):  # feed the long prompt piecewise, carrying the KV cache forward
122 |             past_key_values = self.llm(input_ids=torch.tensor([static_decoding_ids[i:i + batch_size]], device=self.device), past_key_values=past_key_values, use_cache=True).past_key_values
123 |         return past_key_values
124 |     def _initialize_static_cache(self, static_decoding_ids):
125 |         if self.static_cache is None:
126 |             model_kwargs = self.llm.prepare_inputs_for_generation(
127 |                 torch.tensor([static_decoding_ids], device=self.device),
128 |                 attention_mask=None,
129 |                 inputs_embeds=None,
130 |                 cache_position=None,
131 |                 use_cache=True,
132 |             )
133 |             outputs = self.llm(**model_kwargs)
134 |             self.static_cache = outputs.past_key_values
135 | 
136 |     def get_logprob(self, prefix_decoding_ids, llm_ids, llm_tokenizer):
137 |         standard_ids = prefix_decoding_ids + llm_ids
138 |         alternative_ids = llm_tokenizer.get_alternative_ids(standard_ids)
139 | 
140 |         with torch.no_grad():
141 |             standard_ids = torch.tensor([standard_ids], device=self.device)
142 |             logits = self.llm(
143 |                 input_ids=standard_ids,
144 |                 return_dict=True
145 |             ).logits
146 |             logits = logits.float() / self.config.llm_temp
147 | 
148 |             if self.config.repetition_penalty > self.config.repetition_penalty_threshold:
149 |                 logits_processor = RepetitionPenaltyLogitsProcessor(penalty=self.config.repetition_penalty)
150 |                 for i in range(1, min(logits.size(1), self.config.repetition_penalty_last + 1)):
151 |                     logits[:, -i, :] = logits_processor(standard_ids[:, -i - self.config.repetition_penalty_window:-i], logits[:, -i, :])
152 | 
153 |             logprob, normalizer_adjust_n = self._get_logprob(standard_ids, alternative_ids, logits,
154 |                                                              start_token_id=len(prefix_decoding_ids))
155 | 
156 |         return logprob.item(), normalizer_adjust_n
157 | 
158 |     def 
class BreezeByte(ByteModel):
    def __init__(self, config):
        self.config = config
        self.llm = AutoModelForCausalLM.from_pretrained(
            self.config.llm_model_path,
            device_map=self.config.llm_device,
            torch_dtype=torch.float16,
            attn_implementation=self.config.llm_attn_implementation
        )
        self.device = self.llm.device
        self.kv_cache = KVCache()
        self.static_cache = None

    def _process_prompt_in_batches(self, static_decoding_ids, batch_size=512):
        # Editor's sketch completing an unfinished stub (still unused): prefill
        # the prompt KV cache in fixed-size chunks rather than one long pass.
        past_key_values = None
        for i in range(0, len(static_decoding_ids), batch_size):
            batch_input_ids = torch.tensor([static_decoding_ids[i:i + batch_size]], device=self.device)
            past_key_values = self.llm(input_ids=batch_input_ids, past_key_values=past_key_values, use_cache=True).past_key_values
        return past_key_values

    def _initialize_static_cache(self, static_decoding_ids):
        if self.static_cache is None:
            model_kwargs = self.llm.prepare_inputs_for_generation(
                torch.tensor([static_decoding_ids], device=self.device),
                attention_mask=None,
                inputs_embeds=None,
                cache_position=None,
                use_cache=True,
            )
            outputs = self.llm(**model_kwargs)
            self.static_cache = outputs.past_key_values

    def get_logprob(self, prefix_decoding_ids, llm_ids, llm_tokenizer):
        standard_ids = prefix_decoding_ids + llm_ids
        alternative_ids = llm_tokenizer.get_alternative_ids(standard_ids)

        with torch.no_grad():
            standard_ids = torch.tensor([standard_ids], device=self.device)
            logits = self.llm(
                input_ids=standard_ids,
                return_dict=True
            ).logits
            logits = logits.float() / self.config.llm_temp

            if self.config.repetition_penalty > self.config.repetition_penalty_threshold:
                logits_processor = RepetitionPenaltyLogitsProcessor(penalty=self.config.repetition_penalty)
                for i in range(1, min(logits.size(1), self.config.repetition_penalty_last + 1)):
                    logits[:, -i, :] = logits_processor(standard_ids[:, -i - self.config.repetition_penalty_window:-i], logits[:, -i, :])

            logprob, normalizer_adjust_n = self._get_logprob(standard_ids, alternative_ids, logits,
                start_token_id=len(prefix_decoding_ids))

        return logprob.item(), normalizer_adjust_n

    def get_logprob_cache_static(self, prefix_decoding_ids, llm_ids, llm_tokenizer):
        self._initialize_static_cache(prefix_decoding_ids)

        standard_ids = prefix_decoding_ids + llm_ids
        alternative_ids = llm_tokenizer.get_alternative_ids(standard_ids)

        with torch.no_grad():
            standard_ids = torch.tensor([standard_ids], device=self.device)
            logits = self.llm(
                input_ids=standard_ids,
                past_key_values=self.static_cache,
                return_dict=True
            ).logits
            logits = logits.float() / self.config.llm_temp

            if self.config.repetition_penalty > self.config.repetition_penalty_threshold:
                logits_processor = RepetitionPenaltyLogitsProcessor(penalty=self.config.repetition_penalty)
                for i in range(1, min(logits.size(1), self.config.repetition_penalty_last + 1)):
                    logits[:, -i, :] = logits_processor(standard_ids[:, -i - self.config.repetition_penalty_window:-i], logits[:, -i, :])

            logprob, normalizer_adjust_n = self._get_logprob(standard_ids, alternative_ids, logits,
                start_token_id=len(prefix_decoding_ids))

        return logprob.item(), normalizer_adjust_n

    def get_logprob_cache_dynamic(self, prefix_decoding_ids, llm_ids, llm_tokenizer):
        standard_ids = prefix_decoding_ids + llm_ids
        alternative_ids = llm_tokenizer.get_alternative_ids(standard_ids)

        with torch.no_grad():
            standard_ids = torch.tensor([standard_ids], device=self.device)
            past_key_values, past_logits, matched_length = self.kv_cache.query(standard_ids)

            if matched_length == standard_ids.shape[1]:
                logits = past_logits
            else:
                model_kwargs = self.llm.prepare_inputs_for_generation(
                    standard_ids[:, matched_length:],
                    past_key_values=past_key_values,
                    attention_mask=None,
                    inputs_embeds=None,
                    cache_position=None,
                    use_cache=True,
                )
                outputs = self.llm(**model_kwargs)

                logits = torch.cat((past_logits, outputs.logits), dim=1) if past_logits is not None else outputs.logits

                if logits.size(1) != standard_ids.shape[1]:  # the output logits do not cover all of standard_ids; recompute from scratch
                    model_kwargs = self.llm.prepare_inputs_for_generation(
                        standard_ids,
                        attention_mask=None,
                        inputs_embeds=None,
                        cache_position=None,
                        use_cache=True,
                    )
                    outputs = self.llm(**model_kwargs)
                    logits = outputs.logits

                self.kv_cache.add(standard_ids, outputs.past_key_values, logits)

            logits = logits.float() / self.config.llm_temp

            if self.config.repetition_penalty > self.config.repetition_penalty_threshold:
                logits_processor = RepetitionPenaltyLogitsProcessor(penalty=self.config.repetition_penalty)
                for i in range(1, min(logits.size(1), self.config.repetition_penalty_last + 1)):
                    logits[:, -i, :] = logits_processor(standard_ids[:, -i - self.config.repetition_penalty_window:-i], logits[:, -i, :])

            logprob, normalizer_adjust_n = self._get_logprob(standard_ids, alternative_ids, logits,
                start_token_id=len(prefix_decoding_ids))

        return logprob.item(), normalizer_adjust_n

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. 
Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Generative Fusion Decoding

Python code for the paper "[Let’s Fuse Step by Step: A Generative Fusion Decoding Algorithm with LLMs for Multi-modal Text Recognition](https://arxiv.org/abs/2405.14259)" by Chan-Jan Hsu*, [Yi-Chang Chen](https://ycc.idv.tw/about-me)*, Feng-Ting Liao, Pei-Chen Ho, Yu-Hsiang Wang, Po-Chun Hsu, Da-shan Shiu

*Equal contribution




## Approach

![Project teaser](assets/teaser.png)

We introduce "Generative Fusion Decoding" (GFD), a novel shallow fusion framework used to integrate Large Language Models (LLMs) into multi-modal text recognition systems such as automatic speech recognition (ASR) and optical character recognition (OCR). We derive the formulas necessary for GFD to operate across the mismatched token spaces of different models by mapping text token space to byte token space, enabling seamless fusion during the decoding process. The framework is plug-and-play, compatible with various auto-regressive models, and does not require re-training for feature alignment, thus overcoming the limitations of previous fusion techniques. We highlight three main advantages of GFD: First, by simplifying the complexity of aligning different model sample spaces, GFD allows LLMs to correct errors in tandem with the recognition model, reducing computation latency. Second, GFD fully capitalizes on the in-context learning ability of LLMs, increasing robustness in long-form and instruction-aware speech recognition. Third, GFD enables fusing recognition models deficient in Chinese text recognition with LLMs extensively trained on Chinese. Our evaluation demonstrates that GFD significantly improves performance in ASR and OCR tasks, with ASR reaching state-of-the-art performance on the NTUML2021 benchmark. GFD provides a significant step forward in model integration, offering a unified solution that could be widely applicable to leveraging existing pre-trained models through step-by-step fusion.

## Setup

1. 
**Clone the repository:**
```
git clone https://github.com/mtkresearch/generative-fusion-decoding.git
cd generative-fusion-decoding
```
2. **Create a Python virtual environment:**
```
python -m venv venv
source venv/bin/activate  # On Windows use `venv\Scripts\activate`
```
3. **Install the required packages:**
```
pip install -r requirements.txt
```
4. **Run the setup script:**
```
python setup.py install
```

## Run GFD
### GPU Memory Requirements
We run GFD on a single A6000 machine.

Memory breakdown: ASR (Whisper Large) ~3 GB; LLM (Breeze/Mistral) ~14 GB.

🤗 You can try it out hassle-free with Kaggle T4 GPUs [here](https://www.kaggle.com/code/a24998667/generative-fusion-decoding-example)!

### On Single File
To run the script, the following four arguments are required:
- `--model_name`: This argument specifies which type of model to use. The available option is:
  + `gfd`: The generative fusion decoding method.
- `--setting`: The argument specifies the configuration setting for the model. The available settings depend on the `model_name`:
  + `asr-zhtw`: The complete version of our method's configuration for testing on a Traditional Chinese sample.
  + `asr-zhtw-lmoff`: Uses our custom beam search method on the ASR model, neglecting the output from the LLM (fusing_r = 0) for a Traditional Chinese sample.
  + `asr-en`: The complete version of our method's configuration for testing on an English sample.
  + `asr-en-lmoff`: Uses our custom beam search method on the ASR model, neglecting the output from the LLM (fusing_r = 0) for an English sample.
- `--audio_file_path`: The path to the audio file that you want to process.
- `--result_output_path`: The path where the output result will be saved.

**Example Usage**
```
python benchmarks/run_single_file.py --model_name gfd --setting asr-zhtw --audio_file_path demo_examples/zh_news.wav --result_output_path output.txt
```
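If you prefer to drive GFD from Python instead of the CLI, the same pipeline can be invoked directly. A minimal sketch (the config and audio paths are illustrative, the prompts are the ones shipped under `config_files/prompt`, and `run_single_file.py` remains the supported entry point):
```
from gfd.utils import process_config
from gfd.gfd import Breezper

config = process_config('config_files/model/gfd-asr-zhtw.yaml')
breezper = Breezper(config)
transcription = breezper.get_transcription(
    'demo_examples/zh_news.wav',
    asr_prompt='以下是繁體中文轉錄文檔',
    llm_prompt='以下是繁體中文轉錄文檔:',
)
print(transcription)
```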
### On Benchmark Dataset
To run the benchmark dataset, the following four arguments are required:
- `--dataset_name`: Each dataset we tested has a short version name for easy reference. When you run `benchmarks/run_benchmark.py`, the script will automatically download the specified dataset from Hugging Face. Below is a list of the short version names of the datasets used.
  + ml-lecture-2021-long: A dataset of long-form audio recordings from NTU 2021 machine learning lectures.
  + formosa-long: A dataset of long-form audio recordings in Traditional Chinese.
  + fleurs-hk: The Google Fleurs dataset, using the yue_hant_hk split.
  + noisy-librispeech-10: LibriSpeech with noise added to the audio (SNR = 10).
  + noisy-librispeech-5: LibriSpeech with noise added to the audio (SNR = 5).
  + atco2: Air Traffic Control Voice Communication dataset.
- `--model_name`: This argument specifies which type of model to use. There are two options:
  + `gfd`: The generative fusion decoding method.
  + `whisper`: The Hugging Face Whisper generation method.
- `--setting`: The argument specifies the configuration setting for the model. The available settings depend on the `model_name`:

  For **gfd**:
  + `asr-zhtw`: The complete version of our method's configuration for testing on the Traditional Chinese dataset.
  + `asr-zhtw-lmoff`: Uses our custom beam search method on the ASR model, neglecting the output from the LLM (fusing_r = 0) for the Traditional Chinese dataset.
  + `asr-en`: The complete version of our method's configuration for testing on the English dataset.
  + `asr-en-lmoff`: Uses our custom beam search method on the ASR model, neglecting the output from the LLM (fusing_r = 0) for the English dataset.

  For **whisper**:
  + `whisper-zhtw`: The configuration for the Traditional Chinese dataset.
  + `whisper-en`: The configuration for the English dataset.

- `--output_dir`: The argument specifies the path to the directory where the model output will be stored. The outputs are stored in two subfolders:
  + `temp_results`: Stores the result of each sample as a JSON file.
  + `ds_result`: Stores the whole dataset along with the model predictions.

**Example Usage**

Here are some example commands for different configurations:
+ Using the `gfd` model with the `asr-zhtw` setting on the `ml-lecture-2021-long` dataset
```
python benchmarks/run_benchmark.py --dataset_name ml-lecture-2021-long --model_name gfd --setting asr-zhtw --output_dir result/
```
+ Using the `whisper` model with the `whisper-zhtw` setting on the `ml-lecture-2021-long` dataset
```
python benchmarks/run_benchmark.py --dataset_name ml-lecture-2021-long --model_name whisper --setting whisper-zhtw --output_dir result/
```
**Using Multiple GPUs**

If you have multiple GPUs, you can change the device configuration in the config file.

## Configuration

There are configuration files for the GFD and Whisper models under `config_files/model`, covering Traditional Chinese and English for both models.
- GFD:
  - Traditional Chinese: `gfd-asr-zhtw.yaml`
  - English: `gfd-asr-en.yaml`
- Whisper:
  - Traditional Chinese: `whisper-zhtw.yaml`
  - English: `whisper-en.yaml`

`config_files/prompt` contains task-specific ASR and LLM prompt configurations for `gfd`. Prompt configuration files are named `{short version dataset name}-prompt.yaml`.

The general configuration files `gfd-asr-zhtw.yaml` and `gfd-asr-en.yaml` contain various configuration options. Below are the meanings and choices for each argument, divided into three parts based on the likelihood of needing to reset them.

### Core Arguments

- **`asr_model_path`**: Path to the Automatic Speech Recognition (ASR) model for speech recognition.

- **`llm_model_path`**: Path to the Language Model (LLM) for language processing tasks.

- **`lang`**: Language code for the ASR model: 'en' for English and 'zh' for Chinese.

- **`asr_device`**: Device to run the ASR model on.

- **`llm_device`**: Device to run the LLM on.

### Arguments that can optionally be reset

- **`force_character_mode`**: Output mode of characters when `lang == 'zh'`; options include `'tc'` for Traditional Chinese characters, `'sc'` for Simplified Chinese characters, and `None` for no specific mode.

- **`seg_with_overlap`**: Default is `False`. When set to `True`, the audio will be segmented with a short interval of overlap; when set to `False`, the audio will be segmented without any overlap.

- **`fuse_strategy`**: Default is `simple`. 
The fused score is the weighted sum of the ASR score and the LLM score: `score = fusing_r * llm_score + (1 - fusing_r) * asr_score` (see the sketch after this list).

- **`use_cache`**: Default is `dynamic`. When set to `dynamic`, the model runs with the key-value (KV) cache enabled, which speeds up processing, especially for long-form audio. A `static` mode, which caches only the fixed prompt prefix, is also implemented. If set to `None`, the KV cache is disabled; if you are facing memory issues, consider setting it to `None` to release memory.

- **`fusing_r`**: Fusing ratio used in the fuse strategy to combine ASR and LLM outputs.

- **`asr_attn_implementation`**: ASR attention implementation; options include "eager" (manual implementation of the attention), "sdpa" (attention using torch.nn.functional.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.

- **`llm_attn_implementation`**: LLM attention implementation; options include "eager" (manual implementation of the attention), "sdpa" (attention using torch.nn.functional.scaled_dot_product_attention), or "flash_attention_2" (attention using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.

- **`llm_temp`**: LLM temperature parameter used to modulate the next-token probabilities.

- **`transcription_cutoff`**: Transcription cutoff limit. This argument specifies the maximum number of characters to retain from the previous transcription when it is appended to the LLM prompt; if the previous transcription exceeds this limit, it is truncated to the specified length.
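To make the fusion concrete, here is a minimal sketch of the `simple` strategy (mirroring `Breezper.fuse` in `gfd/gfd.py`) with the shipped `fusing_r = 0.2`; both inputs are length-normalized log-probability scores, and the numbers are illustrative:
```
def fuse(asr_score, llm_score, fusing_r=0.2):
    # Weighted sum of the two per-token log-probability scores
    return (1 - fusing_r) * asr_score + fusing_r * llm_score

print(fuse(-0.60, -0.75))  # 0.8 * -0.60 + 0.2 * -0.75 = -0.63
```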
### Arguments that most likely don't need to be reset

- **`repetition_penalty`**: The penalty applied to repeated tokens during the generation process. A higher value increases the penalty, making the model less likely to repeat the same tokens. The penalty is applied only if `repetition_penalty` is greater than `repetition_penalty_threshold`.

- **`repetition_penalty_last`**: The number of final tokens to which the repetition penalty is applied.

- **`repetition_penalty_window`**: The window size for applying the repetition penalty. The penalty is computed over tokens within this window preceding the token being processed. For example, if `repetition_penalty_window` is set to `50`, the penalty is computed over the last 50 tokens before the current token.

- **`repetition_penalty_threshold`**: The threshold for applying the repetition penalty. If `repetition_penalty` is greater than this threshold, the penalty mechanism is activated.

- **`beam_terminated_strategy`**: Beam search termination strategy. The default is `when_all_end`, which terminates beam search when all beams reach the end.

- **`beam_select_strategy`**: Beam selection strategy; options include `'best'`, which selects the beam with the highest score, and `'longest'`, which selects the beam with the longest transcription result.

- **`beam_max_decode_len`**: Maximum decode length for beam search, which specifies the maximum length of the decoded sequence during beam search.

- **`beam_max_len_diff`**: Maximum length difference for beam search, which specifies the maximum difference in length between the beams during beam search.

- **`beam_max_len`**: Maximum length for beam search, which specifies the maximum length of a decoded beam. The default value of `-1` means no limit.

- **`beam_min_len`**: Minimum length for beam search.

- **`logprob_min`**: Floor applied to the ASR and LLM log probabilities when computing scores.

## Evaluate the Result

After running the model on the benchmark dataset, you can evaluate the result by calculating the Mixed Error Rate (MER) using the provided `benchmarks/calculate_mer.py` script. The script requires the following arguments:

+ `--dataset_name`: The short version name of the benchmark dataset that you want to evaluate.
+ `--output_dir`: The output directory that stores the output from the model.

**Example Usage**

```
python benchmarks/calculate_mer.py --dataset_name ml-lecture-2021-long --output_dir result/
```
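Under the hood, the script loads the saved predictions and references, cleans them, and computes the error rate with `jiwer`. A minimal sketch of the core computation (the strings are illustrative):
```
import jiwer

gold = ["the quick brown fox", "hello world"]
preds = ["the quick brown box", "hello world"]
print(jiwer.wer(gold, preds))  # 1 substitution / 6 reference words ≈ 0.167
```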
## Warning

**Warning**: This project uses tokenizers with [custom tokenizer functions](https://github.com/mtkresearch/generative-fusion-decoding/blob/main/gfd/tokenizer.py), mostly to deal with byte-string tokenization, and has only been tested with the Mistral and Breeze models. Using other models may result in errors or unexpected behavior. Please ensure compatibility before using it with other models.

## Acknowledgements
If you like our work, please cite:
```
@article{hsu2024let,
  title={Let's Fuse Step by Step: A Generative Fusion Decoding Algorithm with LLMs for Multi-modal Text Recognition},
  author={Hsu, Chan-Jan and Chen, Yi-Chang and Liao, Feng-Ting and Ho, Pei-Chen and Wang, Yu-Hsiang and Hsu, Po-Chun and Shiu, Da-Shan},
  journal={arXiv preprint arXiv:2405.14259},
  year={2024}
}
```

--------------------------------------------------------------------------------
/gfd/gfd.py:
--------------------------------------------------------------------------------
import os
import time

import torch
import numpy as np
import librosa
from transformers import WhisperForConditionalGeneration, WhisperProcessor, GenerationConfig

from gfd.beam import BeamsControler
from gfd.model import BreezeByte
from gfd.tokenizer import LlamaByteTokenizer, WhisperByteTokenizer

DEBUG = 1

class SuppressTokenWarper():
    def __init__(self, surpress_tokens, min_value):
        self.surpress_tokens = surpress_tokens
        self.min_value = min_value

    def __call__(self, scores):
        scores[..., self.surpress_tokens] = self.min_value
        return scores

class Breezper:
    def __init__(self, config):
        self.config = config
        self.asr = WhisperForConditionalGeneration.from_pretrained(
            self.config.asr_model_path, torch_dtype=torch.float16, device_map=self.config.asr_device,
            attn_implementation=self.config.asr_attn_implementation)
        self.device = self.asr.device
        self.asr_processor = WhisperProcessor.from_pretrained(self.config.asr_model_path)
        self.asr_tokenizer = WhisperByteTokenizer.from_pretrained(self.config.asr_model_path)

        self.breeze_byte = BreezeByte(config)
        self.llm_tokenizer = LlamaByteTokenizer.from_pretrained(self.config.llm_model_path)

        self.asr_prefix_prompt_template = "<|startofprev|> {prompt} "
        if self.config.lang == 'en':
            self.asr_prefix_template = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
            self.asr_default_prompt = None
        elif self.config.lang == 'zh':
            self.asr_prefix_template = "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>"
            self.asr_default_prompt = '繁體中文'

        self.llm_prefix_template = "{prompt}"

        asr_config = GenerationConfig().from_pretrained(self.config.asr_model_path)
        self.suppress_tokens = asr_config.suppress_tokens
        self.begin_suppress_tokens = asr_config.begin_suppress_tokens
        self.surpress_token_func = SuppressTokenWarper(self.suppress_tokens, min_value=float("-inf"))
        self.surpress_begin_token_func = SuppressTokenWarper(self.begin_suppress_tokens + self.suppress_tokens, min_value=float("-inf"))

    def _chunk_audio(self, y, sr):
        # Currently unused: fixed-size chunking with stride; segmentation is
        # done by _chunk_audio_by_whisper below.
        chunk_size = self.config.chunk_sec * sr
        stride_size = self.config.stride_sec * sr
        length = len(y)
        count = 1 + max(0, int(np.ceil((length - chunk_size) / (chunk_size - stride_size))))
        for i in range(count):
            start = i * (chunk_size - stride_size)
            end = min(start + chunk_size, length)
            chunked_y = y[start:end]
            yield chunked_y

    def _chunk_audio_by_whisper(self, y, sr, seg_with_overlap):
        input_features = self.asr_processor(y, sampling_rate=sr,
            return_tensors="pt", truncation=False).input_features.half().to(self.device)
        res = self.asr.generate(prompt_condition_type='first-segment', input_features=input_features, num_beams=5,
            return_segments=True, return_dict_in_generate=True)

        intervals = []
        # Return segments aggregated up to the target duration
        if seg_with_overlap == True:
            target_duration = 30
            sub_intervals = [(chunk['start'].item(), chunk['end'].item()) for chunk in res['segments'][0]]
            curr_start, curr_end = sub_intervals[0][0], sub_intervals[0][1]
            last_start = None
            for next_start, next_end in sub_intervals[1:]:
                if curr_end - curr_start + (next_end - next_start) > target_duration:
                    intervals.append((curr_start, curr_end))
                    curr_start = last_start
                else:
                    last_start = next_start
                curr_end = next_end
            intervals.append((curr_start, curr_end))  # flush the final in-progress segment
        elif seg_with_overlap == False:
            last_sequence = None
            for chunk in res['segments'][0]:
                curr_sequence = chunk['result']['sequences']
                if last_sequence is None or not torch.equal(last_sequence, curr_sequence):
                    intervals.append((chunk['start'].item(), chunk['end'].item()))
                    last_sequence = curr_sequence
                else:
                    old_start, old_end = intervals.pop()
                    intervals.append((min(old_start, chunk['start'].item()), max(old_end, chunk['end'].item())))
        else:
            raise NotImplementedError

        for start, end in intervals:
            yield y[int(start*sr): int(end*sr)]

    def get_transcription(self, fpath_or_audio, sr=16000, num_beams=5, asr_prompt='', llm_prompt=''):
        if isinstance(fpath_or_audio, str):
            y, sr = librosa.load(fpath_or_audio, sr=sr)
        else:
            y = fpath_or_audio

        if DEBUG:
            print('asr prompt:', asr_prompt)
            print('llm prompt:', llm_prompt)

        if len(y) <= sr * 30:
            transcription = self._get_transcription(y, sr, num_beams, asr_prompt=asr_prompt, llm_prompt=llm_prompt, use_cache=self.config.use_cache)
        else:
            transcription = ''
            for chunked_y in self._chunk_audio_by_whisper(y, sr, seg_with_overlap=self.config.seg_with_overlap):
                last_transcription = self._get_transcription(chunked_y, sr, num_beams, asr_prompt=asr_prompt, llm_prompt=llm_prompt+transcription[-self.config.transcription_cutoff:], use_cache=self.config.use_cache)
                transcription += 
last_transcription + ' ' 119 | asr_prompt = last_transcription 120 | 121 | if DEBUG: 122 | print('current transcription:', transcription) 123 | time.sleep(3) 124 | if DEBUG > 2: 125 | input("Enter to continue ...") 126 | 127 | return transcription 128 | 129 | def fuse(self, asr_score, llm_score): 130 | if self.config.fuse_strategy == 'simple': 131 | return (1 - self.config.fusing_r) * asr_score + self.config.fusing_r * llm_score 132 | else: 133 | raise NotImplementedError() 134 | 135 | def _get_prefix_decoding_ids(self, asr_prompt, llm_prompt): 136 | # asr 137 | asr_prompt = asr_prompt if asr_prompt else self.asr_default_prompt 138 | asr_prefix = ( 139 | self.asr_prefix_prompt_template.format(prompt=asr_prompt) 140 | + self.asr_prefix_template 141 | ) 142 | asr_prefix_decoding_ids = self.asr_tokenizer( 143 | asr_prefix, 144 | add_special_tokens=False 145 | ).input_ids 146 | 147 | # llm 148 | llm_prefix_decoding_ids = self.llm_tokenizer.tokenize_from_byte( 149 | self.llm_prefix_template.format( 150 | prompt=llm_prompt 151 | ).encode('utf8') 152 | ) 153 | 154 | return asr_prefix_decoding_ids, llm_prefix_decoding_ids 155 | 156 | def _asr_forward(self, encoder_outputs, decoder_input_ids, k, supress_func=None): 157 | with torch.no_grad(): 158 | logits = self.asr( 159 | encoder_outputs=encoder_outputs, 160 | decoder_input_ids=torch.tensor(decoder_input_ids, device=self.device), 161 | return_dict=True 162 | ).logits 163 | if supress_func is not None: 164 | logits = supress_func(logits) 165 | logprobs = torch.log(torch.softmax(logits, dim=-1)) 166 | next_logprobs, inds = torch.topk(logprobs[0, -1, :], k, dim=-1) 167 | 168 | return next_logprobs, inds 169 | 170 | def _get_transcription(self, y, sr, num_beams, asr_prompt, llm_prompt, use_cache=None): 171 | input_features = self.asr_processor(y, sampling_rate=sr, 172 | return_tensors="pt").input_features.half().to(self.device) 173 | encoder_outputs = self.asr.get_encoder()(input_features, return_dict=True) 174 | 175 | beams = BeamsControler( 176 | config=self.config, 177 | n_beam=num_beams, 178 | asr_eos_id=self.asr_tokenizer.eos_token_id) 179 | 180 | asr_prefix_decoding_ids, llm_prefix_decoding_ids = self._get_prefix_decoding_ids(asr_prompt, llm_prompt) 181 | next_asr_logprobs, asr_inds = self._asr_forward( 182 | encoder_outputs, 183 | asr_prefix_decoding_ids, 184 | k=1, 185 | supress_func=self.surpress_begin_token_func 186 | ) 187 | for ind, next_asr_logprob in zip(asr_inds, next_asr_logprobs): 188 | next_id = ind.item() 189 | next_asr_logprob = next_asr_logprob.item() 190 | asr_score, llm_score = self._calcualte_asr_llm_score( 191 | asr_normalized_len=1, 192 | asr_logprob=next_asr_logprob, 193 | llm_normalized_len=1, 194 | llm_logprob=None 195 | ) 196 | 197 | fuse_score = self.fuse(asr_score, llm_score) 198 | beams.add( 199 | asr_score=asr_score, 200 | llm_score=llm_score, 201 | fuse_score=fuse_score, 202 | asr_prefix_ids=asr_prefix_decoding_ids, 203 | asr_ids=[next_id], 204 | asr_logprob=next_asr_logprob, 205 | llm_prefix_ids=llm_prefix_decoding_ids, 206 | llm_ids=[], 207 | llm_logprob=None, 208 | ) 209 | self._update_asr_llm_mean_and_std(beams._next_beams) 210 | beams.update() 211 | 212 | while True: 213 | for beam in beams.list(): 214 | if beam.reach_end: 215 | beams.add_beam(beam) 216 | else: 217 | next_asr_logprobs, asr_inds = self._asr_forward( 218 | encoder_outputs, 219 | beam.asr_prefix_ids + beam.asr_ids, 220 | k=num_beams, 221 | supress_func=self.surpress_token_func 222 | ) 223 | 224 | asr_inds = [x.item() for x in asr_inds] 225 | 
next_asr_logprobs = [x.item() for x in next_asr_logprobs]

                    # Important: if the ASR model ranks the end-of-sequence token
                    # first, keep only that candidate so the beam can terminate.
                    if asr_inds[0] == self.asr_tokenizer.eos_token_id:
                        next_asr_logprobs = next_asr_logprobs[0:1]
                        asr_inds = asr_inds[0:1]

                    # Otherwise drop an end-of-sequence token that is not ranked first.
                    elif self.asr_tokenizer.eos_token_id in asr_inds:
                        p = asr_inds.index(self.asr_tokenizer.eos_token_id)
                        next_asr_logprobs = next_asr_logprobs[:p] + next_asr_logprobs[p+1:]
                        asr_inds = asr_inds[:p] + asr_inds[p+1:]

                    asr_new_content = self.asr_tokenizer.convert_ids_to_bytes(
                        beam.asr_ids, skip_special_tokens=True
                    )
                    new_content = b''.join(asr_new_content)

                    llm_ids = self.llm_tokenizer.tokenize_from_byte(new_content)
                    if use_cache == 'dynamic':
                        llm_logprob, normalizer_adjust_n = self.breeze_byte.get_logprob_cache_dynamic(
                            prefix_decoding_ids=llm_prefix_decoding_ids,
                            llm_ids=llm_ids,
                            llm_tokenizer=self.llm_tokenizer
                        )
                    elif use_cache == 'static':
                        llm_logprob, normalizer_adjust_n = self.breeze_byte.get_logprob_cache_static(
                            prefix_decoding_ids=llm_prefix_decoding_ids,
                            llm_ids=llm_ids,
                            llm_tokenizer=self.llm_tokenizer
                        )
                    else:
                        llm_logprob, normalizer_adjust_n = self.breeze_byte.get_logprob(
                            prefix_decoding_ids=llm_prefix_decoding_ids,
                            llm_ids=llm_ids,
                            llm_tokenizer=self.llm_tokenizer
                        )

                    assert normalizer_adjust_n <= 0

                    for next_id, next_asr_logprob in zip(asr_inds, next_asr_logprobs):
                        asr_logprob = next_asr_logprob + beam.asr_logprob
                        asr_score, llm_score = self._calcualte_asr_llm_score(
                            asr_normalized_len=len(beam.asr_ids)+1,
                            asr_logprob=asr_logprob,
                            llm_normalized_len=len(llm_ids) + normalizer_adjust_n,
                            llm_logprob=llm_logprob
                        )
                        fuse_score = self.fuse(asr_score, llm_score)
                        beams.add(
                            asr_score=asr_score,
                            llm_score=llm_score,
                            fuse_score=fuse_score,
                            asr_prefix_ids=beam.asr_prefix_ids,
                            asr_ids=beam.asr_ids + [next_id],
                            asr_logprob=asr_logprob,
                            llm_prefix_ids=beam.llm_prefix_ids,
                            llm_ids=llm_ids,
                            llm_logprob=llm_logprob,
                        )
            self._update_asr_llm_mean_and_std(beams._next_beams)
            beams.update()
            self.breeze_byte.kv_cache.remove_unused()

            if DEBUG > 1:
                for k, beam in enumerate(beams.list()):
                    print(f'''[{k}] asr_score={beam.asr_score}, llm_score={beam.llm_score}, fuse_score={beam.fuse_score},
{self.asr_tokenizer.decode(beam.asr_ids)}''')
                print()
            elif DEBUG > 0:
                beam = beams.list()[0]
                print(f'''[0] asr_score={beam.asr_score}, llm_score={beam.llm_score}, fuse_score={beam.fuse_score},
{self.asr_tokenizer.decode(beam.asr_ids)}
''')

            if beams.is_terminated():
                break

        transcription = beams.get_result(self.asr_tokenizer)
        return transcription

    def _update_asr_llm_mean_and_std(self, beams_list):
        # DEPRECATED
        pass

    def _calcualte_asr_llm_score(self, asr_normalized_len, asr_logprob, llm_normalized_len, llm_logprob):
        # The 'not >' comparisons also clamp NaN values to logprob_min.
        if not (asr_logprob > self.config.logprob_min):
            asr_logprob = self.config.logprob_min

        if llm_logprob is None or not (llm_logprob > self.config.logprob_min):
            llm_logprob = self.config.logprob_min
        asr_score = asr_logprob / asr_normalized_len if asr_normalized_len > 0 else self.config.logprob_min 
319 | llm_score = llm_logprob / llm_normalized_len if llm_normalized_len > 0 else self.config.logprob_min 320 | 321 | return asr_score, llm_score 322 | -------------------------------------------------------------------------------- /config_files/prompt/atco2-asr-prompt.yaml: -------------------------------------------------------------------------------- 1 | asr_prompt: "Alfa Bravo Charlie Delta Echo Foxtrot Golf Hotel India Juliett Kilo Lima Mike November Oscar Papa Quebec Romeo Sierra Tango Uniform Victor Whiskey Xray Yankee Zulu One Two Three Four Five Six Seven Eight Nine Zero\nDayton radio, November One Two Three Four Five on one two two point two, over Springfield V-O-R, over.\nNew York Radio, Mooney Three One One Echo.\nColumbia Ground, Cessna Three One Six Zero Foxtrot, south ramp, I-F-R Memphis." 2 | llm_prompt: | 3 | Section 2. Radio Communications Phraseology 4 | and Techniques 5 | 1. General 6 | 1. Radio communications are a critical link in the ATC system. The link can be a strong bond between pilot and controller or it can be broken with surprising speed and disastrous results. Discussion herein provides basic procedures for new pilots and also highlights safe operating concepts for all pilots. 7 | 2. The single, most important thought in pilot‐controller communications is understanding. It is essential, therefore, that pilots acknowledge each radio communication with ATC by using the appropriate aircraft call sign. Brevity is important, and contacts should be kept as brief as possible, but controllers must know what you want to do before they can properly carry out their control duties. And you, the pilot, must know exactly what the controller wants you to do. Since concise phraseology may not always be adequate, use whatever words are necessary to get your message across. Pilots are to maintain vigilance in monitoring air traffic control radio communications frequencies for potential traffic conflicts with their aircraft especially when operating on an active runway and/or when conducting a final approach to landing. 8 | 3. All pilots will find the Pilot/Controller Glossary very helpful in learning what certain words or phrases mean. Good phraseology enhances safety and is the mark of a professional pilot. Jargon, chatter, and “CB” slang have no place in ATC communications. The Pilot/Controller Glossary is the same glossary used in FAA Order JO 7110.65, Air Traffic Control. We recommend that it be studied and reviewed from time to time to sharpen your communication skills. 9 | 2. Radio Technique 10 | 1. Listen before you transmit. Many times you can get the information you want through ATIS or by monitoring the frequency. Except for a few situations where some frequency overlap occurs, if you hear someone else talking, the keying of your transmitter will be futile and you will probably jam their receivers causing them to repeat their call. If you have just changed frequencies, pause, listen, and make sure the frequency is clear. 11 | 2. Think before keying your transmitter. Know what you want to say and if it is lengthy; e.g., a flight plan or IFR position report, jot it down. 12 | 3. The microphone should be very close to your lips and after pressing the mike button, a slight pause may be necessary to be sure the first word is transmitted. Speak in a normal, conversational tone. 13 | 4. When you release the button, wait a few seconds before calling again. 
The controller or FSS specialist may be jotting down your number, looking for your flight plan, transmitting on a different frequency, or selecting the transmitter for your frequency. 14 | 5. Be alert to the sounds or the lack of sounds in your receiver. Check your volume, recheck your frequency, and make sure that your microphone is not stuck in the transmit position. Frequency blockage can, and has, occurred for extended periods of time due to unintentional transmitter operation. This type of interference is commonly referred to as a “stuck mike,” and controllers may refer to it in this manner when attempting to assign an alternate frequency. If the assigned frequency is completely blocked by this type of interference, use the procedures described for en route IFR radio frequency outage to establish or reestablish communications with ATC. 15 | 6. Be sure that you are within the performance range of your radio equipment and the ground station equipment. Remote radio sites do not always transmit and receive on all of a facility's available frequencies, particularly with regard to VOR sites where you can hear but not reach a ground station's receiver. Remember that higher altitudes increase the range of VHF “line of sight” communications. 16 | 3. Contact Procedures 17 | 1. Initial Contact. 18 | 1. The terms initial contact or initial callup means the first radio call you make to a given facility or the first call to a different controller or FSS specialist within a facility. Use the following format: 19 | 1. Name of the facility being called; 20 | 2. Your full aircraft identification as filed in the flight plan or as discussed in paragraph 4-2-4, Aircraft Call Signs; 21 | 3. When operating on an airport surface, state your position. 22 | 4. The type of message to follow or your request if it is short; and 23 | 5. The word “Over” if required. 24 | EXAMPLE- 25 | 1. “New York Radio, Mooney Three One One Echo.” 26 | 2. “Columbia Ground, Cessna Three One Six Zero Foxtrot, south ramp, I-F-R Memphis.” 27 | 3. “Miami Center, Baron Five Six Three Hotel, request V-F-R traffic advisories.” 28 | 2. Many FSSs are equipped with Remote Communications Outlets (RCOs) and can transmit on the same frequency at more than one location. The frequencies available at specific locations are indicated on charts above FSS communications boxes. To enable the specialist to utilize the correct transmitter, advise the location and the frequency on which you expect a reply. 29 | EXAMPLE- 30 | St. Louis FSS can transmit on frequency 122.3 at either Farmington, Missouri, or Decatur, Illinois, if you are in the vicinity of Decatur, your callup should be “Saint Louis radio, Piper Six Niner Six Yankee, receiving Decatur One Two Two Point Three.” 31 | 3. If radio reception is reasonably assured, inclusion of your request, your position or altitude, and the phrase “(ATIS) Information Charlie received” in the initial contact helps decrease radio frequency congestion. Use discretion; do not overload the controller with information unneeded or superfluous. If you do not get a response from the ground station, recheck your radios or use another transmitter, but keep the next contact short. 32 | EXAMPLE- 33 | “Atlanta Center, Duke Four One Romeo, request V-F-R traffic advisories, Twenty Northwest Rome, seven thousand five hundred, over.” 34 | 2. Initial Contact When Your Transmitting and Receiving Frequencies are Different. 35 | 1. 
If you are attempting to establish contact with a ground station and you are receiving on a different frequency than that transmitted, indicate the VOR name or the frequency on which you expect a reply. Most FSSs and control facilities can transmit on several VOR stations in the area. Use the appropriate FSS call sign as indicated on charts. 36 | EXAMPLE- 37 | New York FSS transmits on the Kennedy, the Hampton, and the Calverton VORTACs. If you are in the Calverton area, your callup should be “New York radio, Cessna Three One Six Zero Foxtrot, receiving Calverton V-O-R, over.” 38 | 2. If the chart indicates FSS frequencies above the VORTAC or in the FSS communications boxes, transmit or receive on those frequencies nearest your location. 39 | 3. When unable to establish contact and you wish to call any ground station, use the phrase “ANY RADIO (tower) (station), GIVE CESSNA THREE ONE SIX ZERO FOXTROT A CALL ON (frequency) OR (V-O-R).” If an emergency exists or you need assistance, so state. 40 | 3. Subsequent Contacts and Responses to Callup from a Ground Facility. 41 | Use the same format as used for the initial contact except you should state your message or request with the callup in one transmission. The ground station name and the word “Over” may be omitted if the message requires an obvious reply and there is no possibility for misunderstandings. You should acknowledge all callups or clearances unless the controller or FSS specialist advises otherwise. There are some occasions when controllers must issue time‐critical instructions to other aircraft, and they may be in a position to observe your response, either visually or on radar. If the situation demands your response, take appropriate action or immediately advise the facility of any problem. Acknowledge with your aircraft identification, either at the beginning or at the end of your transmission, and one of the words “Wilco,” “Roger,” “Affirmative,” “Negative,” or other appropriate remarks; e.g., “PIPER TWO ONE FOUR LIMA, ROGER.” If you have been receiving services; e.g., VFR traffic advisories and you are leaving the area or changing frequencies, advise the ATC facility and terminate contact. 42 | 4. Acknowledgement of Frequency Changes. 43 | 1. When advised by ATC to change frequencies, acknowledge the instruction. If you select the new frequency without an acknowledgement, the controller's workload is increased because there is no way of knowing whether you received the instruction or have had radio communications failure. 44 | 2. At times, a controller/specialist may be working a sector with multiple frequency assignments. In order to eliminate unnecessary verbiage and to free the controller/specialist for higher priority transmissions, the controller/specialist may request the pilot “(Identification), change to my frequency 134.5.” This phrase should alert the pilot that the controller/specialist is only changing frequencies, not controller/specialist, and that initial callup phraseology may be abbreviated. 45 | EXAMPLE- 46 | “United Two Twenty-Two on one three four point five” or “one three four point five, United Two Twenty-Two.” 47 | 5. Compliance with Frequency Changes. 48 | When instructed by ATC to change frequencies, select the new frequency as soon as possible unless instructed to make the change at a specific time, fix, or altitude. A delay in making the change could result in an untimely receipt of important information. 
If you are instructed to make the frequency change at a specific time, fix, or altitude, monitor the frequency you are on until reaching the specified time, fix, or altitudes unless instructed otherwise by ATC. 49 | REFERENCE- 50 | AIM, Para 5-3-1, ARTCC Communications. 51 | 4. Aircraft Call Signs 52 | 1. Precautions in the Use of Call Signs. 53 | 1. Improper use of call signs can result in pilots executing a clearance intended for another aircraft. Call signs should never be abbreviated on an initial contact or at any time when other aircraft call signs have similar numbers/sounds or identical letters/number; e.g., Cessna 6132F, Cessna 1622F, Baron 123F, Cherokee 7732F, etc. 54 | EXAMPLE- 55 | Assume that a controller issues an approach clearance to an aircraft at the bottom of a holding stack and an aircraft with a similar call sign (at the top of the stack) acknowledges the clearance with the last two or three numbers of the aircraft's call sign. If the aircraft at the bottom of the stack did not hear the clearance and intervene, flight safety would be affected, and there would be no reason for either the controller or pilot to suspect that anything is wrong. This kind of “human factors” error can strike swiftly and is extremely difficult to rectify. 56 | 2. Pilots, therefore, must be certain that aircraft identification is complete and clearly identified before taking action on an ATC clearance. ATC specialists will not abbreviate call signs of air carrier or other civil aircraft having authorized call signs. ATC specialists may initiate abbreviated call signs of other aircraft by using the prefix and the last three digits/letters of the aircraft identification after communications are established. The pilot may use the abbreviated call sign in subsequent contacts with the ATC specialist. When aware of similar/identical call signs, ATC specialists will take action to minimize errors by emphasizing certain numbers/letters, by repeating the entire call sign, by repeating the prefix, or by asking pilots to use a different call sign temporarily. Pilots should use the phrase “VERIFY CLEARANCE FOR (your complete call sign)” if doubt exists concerning proper identity. 57 | 3. Civil aircraft pilots should state the aircraft type, model or manufacturer's name, followed by the digits/letters of the registration number. When the aircraft manufacturer's name or model is stated, the prefix “N” is dropped; e.g., Aztec Two Four Six Four Alpha. 58 | EXAMPLE- 59 | 1. Bonanza Six Five Five Golf. 60 | 2. Breezy Six One Three Romeo Experimental (omit “Experimental” after initial contact). 61 | 4. Air Taxi or other commercial operators not having FAA authorized call signs should prefix their normal identification with the phonetic word “Tango.” 62 | EXAMPLE- 63 | Tango Aztec Two Four Six Four Alpha. 64 | 5. Air carriers and commuter air carriers having FAA authorized call signs should identify themselves by stating the complete call sign (using group form for the numbers) and the word “super” or “heavy” if appropriate. 65 | EXAMPLE- 66 | 1. United Twenty-Five Heavy. 67 | 2. Midwest Commuter Seven Eleven. 68 | 6. Military aircraft use a variety of systems including serial numbers, word call signs, and combinations of letters/numbers. Examples include Army Copter 48931; Air Force 61782; REACH 31792; Pat 157; Air Evac 17652; Navy Golf Alfa Kilo 21; Marine 4 Charlie 36, etc. 69 | 2. Air Ambulance Flights. 
70 | Because of the priority afforded air ambulance flights in the ATC system, extreme discretion is necessary when using the term “MEDEVAC.” It is only intended for those missions of an urgent medical nature and to be utilized only for that portion of the flight requiring priority handling. It is important for ATC to be aware of a flight's MEDEVAC status, and it is the pilot's responsibility to ensure that this information is provided to ATC. 71 | 1. To receive priority handling from ATC, the pilot must verbally identify the flight in radio transmissions by stating “MEDEVAC” followed by the FAA authorized call sign (ICAO 3LD, US Special, or local) or the aircraft civil “N” registration numbers/letters. 72 | EXAMPLE- 73 | If the aircraft identification of the flight indicates DAL51, the pilot states “MEDEVAC Delta Fifty One.” 74 | If the aircraft identification of the flight indicates MDSTR1, the pilot states “MEDEVAC Medstar One.” 75 | If the aircraft identification of the flight indicates N123G or LN123G, the pilot states “MEDEVAC One Two Three Golf”. 76 | 2. If requested by the pilot, ATC will provide additional assistance (e.g., landline notifications) to expedite ground handling of patients, vital organs, or urgently needed medical materials. When possible make these requests to ATC via methods other than through ATC radio frequencies. 77 | 3. MEDEVAC flights may include: 78 | 1. Civilian air ambulance flights responding to medical emergencies (e.g., first call to an accident scene, carrying patients, organ donors, organs, or other urgently needed lifesaving medical material). 79 | 2. Air carrier and air taxi flights responding to medical emergencies. The nature of these medical emergency flights usually concerns the transportation of urgently needed lifesaving medical materials or vital organs, but can include inflight medical emergencies. It is imperative that the company/pilot determine, by the nature/urgency of the specific medical cargo, if priority ATC assistance is required. 80 | 4. When filing a flight plan, pilots may include “L” for MEDEVAC with the aircraft registration letters/digits and/or include “MEDEVAC” in Item 11 (Remarks) of the flight plan or Item 18 (Other Information) of an international flight plan. However, ATC will only use these flight plan entries for informational purposes or as a visual indicator. ATC will only provide priority handling when the pilot verbally identifies the “MEDEVAC” status of the flight as described in subparagraph b1 above. 81 | NOTE- 82 | Civilian air ambulance aircraft operating VFR and without a filed flight plan are eligible for priority handling in accordance with subparagraph b1 above. 83 | 5. ATC will also provide priority handling to HOSP and AIR EVAC flights when verbally requested. These aircraft may file “HOSP” or “AIR EVAC” in either Item 11 (Remarks) of the flight plan or Item 18 of an international flight plan. For aircraft identification in radio transmissions, civilian pilots will use normal call signs when filing “HOSP” and military pilots will use the “EVAC” call sign. 84 | 3. Student Pilots Radio Identification. 85 | 1. The FAA desires to help student pilots in acquiring sufficient practical experience in the environment in which they will be required to operate. To receive additional assistance while operating in areas of concentrated air traffic, student pilots need only identify themselves as a student pilot during their initial call to an FAA radio facility. 
3. Student Pilots Radio Identification.
1. The FAA desires to help student pilots acquire sufficient practical experience in the environment in which they will be required to operate. To receive additional assistance while operating in areas of concentrated air traffic, student pilots need only identify themselves as student pilots during their initial call to an FAA radio facility.
EXAMPLE-
Dayton tower, Fleetwing One Two Three Four, student pilot.
2. This special identification will alert FAA ATC personnel and enable them to provide student pilots with such extra assistance and consideration as they may need. It is recommended that student pilots identify themselves as such on initial contact with each clearance delivery (prior to taxiing), ground control, tower, approach and departure control frequency, or FSS contact.
5. Description of Interchange or Leased Aircraft
1. Controllers issue traffic information based on familiarity with airline equipment and color/markings. When an air carrier dispatches a flight using another company's equipment and the pilot does not advise the terminal ATC facility, the resulting confusion in aircraft identification can compromise safety.
2. Pilots flying an “interchange” or “leased” aircraft not bearing the colors/markings of the company operating the aircraft should inform the terminal ATC facility on first contact of the name of the operating company and trip number, followed by the company name as displayed on the aircraft, and the aircraft type.
EXAMPLE-
Air Cal Three Eleven, United (interchange/lease), Boeing Seven Two Seven.
6. Ground Station Call Signs
Pilots, when calling a ground station, should begin with the name of the facility being called, followed by the type of facility, as indicated in TBL 4-2-1.
TBL 4-2-1
Calling a Ground Station
| Facility | Call Sign |
|-----------------------------------------------|-----------------------------|
| Airport UNICOM | “Shannon UNICOM” |
| FAA Flight Service Station | “Chicago Radio” |
| Airport Traffic Control Tower | “Augusta Tower” |
| Clearance Delivery Position (IFR) | “Dallas Clearance Delivery” |
| Ground Control Position in Tower | “Miami Ground” |
| Radar or Nonradar Approach Control Position | “Oklahoma City Approach” |
| Radar Departure Control Position | “St. Louis Departure” |
| FAA Air Route Traffic Control Center | “Washington Center” |

7. Phonetic Alphabet
The International Civil Aviation Organization (ICAO) phonetic alphabet is used by FAA personnel when communications conditions are such that the information cannot be readily received without its use. ATC facilities may also request pilots to use phonetic letter equivalents when aircraft with similar-sounding identifications are receiving communications on the same frequency. Pilots should use the phonetic alphabet when identifying their aircraft during initial contact with air traffic control facilities. Additionally, use the phonetic equivalents for single letters and to spell out groups of letters or difficult words during adverse communications conditions. (See TBL 4-2-2.)
TBL 4-2-2
Phonetic Alphabet/Morse Code
| Character | Morse Code | Telephony | Phonic (Pronunciation) |
|-----------|------------|-----------|------------------------|
| A | ● − | Alfa | (AL-FAH) |
| B | − ● ● ● | Bravo | (BRAH-VOH) |
| C | − ● − ● | Charlie | (CHAR-LEE) or (SHAR-LEE) |
| D | − ● ● | Delta | (DELL-TAH) |
| E | ● | Echo | (ECK-OH) |
| F | ● ● − ● | Foxtrot | (FOKS-TROT) |
| G | − − ● | Golf | (GOLF) |
| H | ● ● ● ● | Hotel | (HOH-TEL) |
| I | ● ● | India | (IN-DEE-AH) |
| J | ● − − − | Juliett | (JEW-LEE-ETT) |
| K | − ● − | Kilo | (KEY-LOH) |
| L | ● − ● ● | Lima | (LEE-MAH) |
| M | − − | Mike | (MIKE) |
| N | − ● | November | (NO-VEM-BER) |
| O | − − − | Oscar | (OSS-CAH) |
| P | ● − − ● | Papa | (PAH-PAH) |
| Q | − − ● − | Quebec | (KEH-BECK) |
| R | ● − ● | Romeo | (ROW-ME-OH) |
| S | ● ● ● | Sierra | (SEE-AIR-RAH) |
| T | − | Tango | (TANG-GO) |
| U | ● ● − | Uniform | (YOU-NEE-FORM) or (OO-NEE-FORM) |
| V | ● ● ● − | Victor | (VIK-TAH) |
| W | ● − − | Whiskey | (WISS-KEY) |
| X | − ● ● − | Xray | (ECKS-RAY) |
| Y | − ● − − | Yankee | (YANG-KEY) |
| Z | − − ● ● | Zulu | (ZOO-LOO) |
| 1 | ● − − − − | One | (WUN) |
| 2 | ● ● − − − | Two | (TOO) |
| 3 | ● ● ● − − | Three | (TREE) |
| 4 | ● ● ● ● − | Four | (FOW-ER) |
| 5 | ● ● ● ● ● | Five | (FIFE) |
| 6 | − ● ● ● ● | Six | (SIX) |
| 7 | − − ● ● ● | Seven | (SEV-EN) |
| 8 | − − − ● ● | Eight | (AIT) |
| 9 | − − − − ● | Nine | (NIN-ER) |
| 0 | − − − − − | Zero | (ZEE-RO) |
8. Figures
1. Figures indicating hundreds and thousands in round numbers, as for ceiling heights and upper wind levels up to 9,900, must be spoken in accordance with the following.
EXAMPLE-
1. 500 . . . . . . . . five hundred
2. 4,500 . . . . . . . . four thousand five hundred
2. Numbers above 9,900 must be spoken by separating the digits preceding the word “thousand.”
EXAMPLE-
1. 10,000 . . . . . . . . one zero thousand
2. 13,500 . . . . . . . . one three thousand five hundred
3. Transmit airway or jet route numbers as follows.
EXAMPLE-
1. V12 . . . . . . . . Victor Twelve
2. J533 . . . . . . . . J Five Thirty-Three
4. All other numbers must be transmitted by pronouncing each digit.
EXAMPLE-
10 . . . . . . . . one zero
5. When a radio frequency contains a decimal point, the decimal point is spoken as “POINT.”
EXAMPLE-
122.1 . . . . . . . . one two two point one
NOTE-
ICAO procedures require the decimal point be spoken as “DECIMAL.” The FAA will honor such usage by military aircraft and all other aircraft required to use ICAO procedures.
9. Altitudes and Flight Levels
1. Up to but not including 18,000 feet MSL, state the separate digits of the thousands plus the hundreds, if appropriate.
EXAMPLE-
1. 12,000 . . . . . . . . one two thousand
2. 12,500 . . . . . . . . one two thousand five hundred
2. At and above 18,000 feet MSL (FL 180), state the words “flight level” followed by the separate digits of the flight level.
EXAMPLE-
1. 190 . . . . . . . . Flight Level One Niner Zero
2. 275 . . . . . . . . Flight Level Two Seven Five
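Paragraphs 8 and 9 amount to a small set of rewrite rules for numbers. The sketch below (illustrative only, not part of the gfd package) implements them for the round-hundred cases shown in the examples above.

```python
DIGITS = {"0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
          "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "niner"}

def speak_digits(number: str) -> str:
    """Pronounce each digit separately, e.g. '10' -> 'one zero'."""
    return " ".join(DIGITS[c] for c in number)

def speak_altitude(feet: int) -> str:
    """Below 18,000 ft MSL: separate digits of the thousands, plus the
    hundreds if appropriate. Assumes round hundreds, as in the examples."""
    thousands, remainder = divmod(feet, 1000)
    parts = []
    if thousands:
        parts.append(speak_digits(str(thousands)) + " thousand")
    if remainder:
        parts.append(DIGITS[str(remainder // 100)] + " hundred")
    return " ".join(parts)

def speak_flight_level(level: int) -> str:
    """At and above FL 180: 'flight level' plus separate digits."""
    return "flight level " + speak_digits(str(level))

def speak_frequency(frequency: str) -> str:
    """Radio frequencies: the decimal point is spoken as 'point'."""
    whole, _, fraction = frequency.partition(".")
    if fraction:
        return speak_digits(whole) + " point " + speak_digits(fraction)
    return speak_digits(whole)

# speak_altitude(12500)    -> 'one two thousand five hundred'
# speak_altitude(4500)     -> 'four thousand five hundred'
# speak_flight_level(190)  -> 'flight level one niner zero'
# speak_frequency("122.1") -> 'one two two point one'
```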
10. Directions
The three digits of bearing, course, heading, or wind direction should always be magnetic. The word “true” must be added when it applies.
EXAMPLE-
1. (Magnetic course) 005 . . . . . . . . zero zero five
2. (True course) 050 . . . . . . . . zero five zero true
3. (Magnetic bearing) 360 . . . . . . . . three six zero
4. (Magnetic heading) 100 . . . . . . . . heading one zero zero
5. (Wind direction) 220 . . . . . . . . wind two two zero
11. Speeds
State the separate digits of the speed followed by the word “KNOTS,” except that controllers may omit the word “KNOTS” when using speed adjustment procedures; e.g., “REDUCE/INCREASE SPEED TO TWO FIVE ZERO.”
EXAMPLE-
(Speed) 250 . . . . . . . . two five zero knots
(Speed) 190 . . . . . . . . one niner zero knots
State the separate digits of the Mach number preceded by the word “Mach.”
EXAMPLE-
(Mach number) 1.5 . . . . . . . . Mach one point five
(Mach number) 0.64 . . . . . . . . Mach point six four
(Mach number) 0.7 . . . . . . . . Mach point seven
12. Time
1. The FAA uses Coordinated Universal Time (UTC) for all operations. The word “local” or the time zone equivalent must be used to denote local time when local time is given during radio and telephone communications. The term “Zulu” may be used to denote UTC.
EXAMPLE-
0920 UTC . . . . . . . . zero niner two zero,
zero one two zero pacific or local,
or one twenty AM
2. To convert from Standard Time to Coordinated Universal Time:
TBL 4-2-3
Standard Time to Coordinated Universal Time
| Time Zone | Add Hours |
|-------------------------|-----------|
| Eastern Standard Time | 5 hours |
| Central Standard Time | 6 hours |
| Mountain Standard Time | 7 hours |
| Pacific Standard Time | 8 hours |
| Alaska Standard Time | 9 hours |
| Hawaii Standard Time | 10 hours |

NOTE-
For daylight time, subtract 1 hour.
3. A reference may be made to local daylight or standard time utilizing the 24-hour clock system. The hour is indicated by the first two figures and the minutes by the last two figures.
EXAMPLE-
0000 . . . . . . . . zero zero zero zero
0920 . . . . . . . . zero niner two zero
4. Time may be stated in minutes only (two figures) in radiotelephone communications when no misunderstanding is likely to occur.
5. Current time in use at a station is stated to the nearest quarter minute so that pilots may use this information for time checks. Fractions of a quarter minute less than 8 seconds are stated as the preceding quarter minute; fractions of a quarter minute of 8 seconds or more are stated as the succeeding quarter minute.
EXAMPLE-
0929:05 . . . . . . . . time, zero niner two niner
0929:10 . . . . . . . . time, zero niner two niner and one-quarter
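The time conventions in paragraph 12 are likewise mechanical: a fixed offset per zone (TBL 4-2-3) and rounding to the nearest quarter minute for time checks. The sketch below is illustrative only; the zone abbreviations are shorthand of my own for the table rows, and speak_digits comes from the previous sketch.

```python
from datetime import datetime, timedelta

# Offsets from TBL 4-2-3; the abbreviations are shorthand for the table rows.
UTC_OFFSET_HOURS = {"EST": 5, "CST": 6, "MST": 7, "PST": 8, "AKST": 9, "HST": 10}

def to_utc(local: datetime, zone: str, daylight: bool = False) -> datetime:
    """Add the TBL 4-2-3 offset; for daylight time, subtract one hour."""
    return local + timedelta(hours=UTC_OFFSET_HOURS[zone] - (1 if daylight else 0))

QUARTER = {0: "", 1: " and one-quarter", 2: " and one-half",
           3: " and three-quarters"}

def time_check(hhmm: str, seconds: int) -> str:
    """State the time to the nearest quarter minute: fractions under
    8 seconds round down, 8 seconds or more round up (subparagraph 5)."""
    quarters, fraction = divmod(seconds, 15)
    if fraction >= 8:
        quarters += 1
    if quarters == 4:  # rounds up into the next whole minute
        quarters = 0
        hhmm = (datetime.strptime(hhmm, "%H%M")
                + timedelta(minutes=1)).strftime("%H%M")
    return "time, " + speak_digits(hhmm) + QUARTER[quarters]

# to_utc(datetime(2024, 1, 1, 1, 20), "PST") -> datetime(2024, 1, 1, 9, 20)
# time_check("0929", 5)  -> 'time, zero niner two niner'
# time_check("0929", 10) -> 'time, zero niner two niner and one-quarter'
```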
13. Communications with Tower when Aircraft Transmitter or Receiver or Both are Inoperative
1. Arriving Aircraft.
1. Receiver inoperative.
1. If you have reason to believe your receiver is inoperative, remain outside or above the Class D surface area until the direction and flow of traffic have been determined; then advise the tower of your aircraft type, position, altitude, and intention to land, and request that you be controlled with light signals.
REFERENCE-
AIM, Para 4-3-13, Traffic Control Light Signals.
2. When you are approximately 3 to 5 miles from the airport, advise the tower of your position and join the airport traffic pattern. From this point on, watch the tower for light signals. Thereafter, if a complete pattern is made, transmit your position downwind and/or turning base leg.
2. Transmitter inoperative. Remain outside or above the Class D surface area until the direction and flow of traffic have been determined; then join the airport traffic pattern. Monitor the primary local control frequency as depicted on Sectional Charts for landing or traffic information, and look for a light signal which may be addressed to your aircraft. During hours of daylight, acknowledge tower transmissions or light signals by rocking your wings. At night, acknowledge by blinking the landing or navigation lights. To acknowledge tower transmissions during daylight hours, hovering helicopters will turn in the direction of the controlling facility and flash the landing light. While in flight, helicopters should acknowledge receipt of a transmission by making shallow banks in opposite directions. At night, helicopters will acknowledge receipt of transmissions by flashing either the landing or the search light.
3. Transmitter and receiver inoperative. Remain outside or above the Class D surface area until the direction and flow of traffic have been determined; then join the airport traffic pattern and maintain visual contact with the tower to receive light signals. Acknowledge light signals as noted above.
2. Departing Aircraft. If you experience radio failure prior to leaving the parking area, make every effort to have the equipment repaired. If you are unable to have the malfunction repaired, call the tower by telephone and request authorization to depart without two-way radio communications. If tower authorization is granted, you will be given departure information and requested to monitor the tower frequency or watch for light signals, as appropriate. During daylight hours, acknowledge tower transmissions or light signals by moving the ailerons or rudder. At night, acknowledge by blinking the landing or navigation lights. If a radio malfunction occurs after departing the parking area, watch the tower for light signals or monitor the tower frequency.
REFERENCE-
14 CFR Section 91.125 and 14 CFR Section 91.129.
14. Communications for VFR Flights
1. FSSs and Supplemental Weather Service Locations (SWSL) are allocated frequencies for different functions; for example, in Alaska, certain FSSs provide Local Airport Advisory on 123.6 MHz or other frequencies, which can be found in the Chart Supplement. If you are in doubt as to what frequency to use, 122.2 MHz is assigned to the majority of FSSs as a common en route simplex frequency.
NOTE-
In order to expedite communications, state the frequency being used and the aircraft location during initial callup.
EXAMPLE-
Dayton radio, November One Two Three Four Five on one two two point two, over Springfield V-O-R, over.
2. Certain VOR voice channels are being utilized for recorded broadcasts; for example, ATIS. These services and the appropriate frequencies are listed in the Chart Supplement. On VFR flights, pilots are urged to monitor these frequencies. When in contact with a control facility, notify the controller if you plan to leave the frequency to monitor these broadcasts.

Here is an ATC transcription:
--------------------------------------------------------------------------------