├── src ├── __init__.py ├── whisper │ ├── version.py │ ├── __main__.py │ ├── assets │ │ └── mel_filters.npz │ ├── normalizers │ │ ├── __init__.py │ │ └── basic.py │ ├── triton_ops.py │ ├── audio.py │ ├── __init__.py │ ├── model.py │ ├── utils.py │ ├── tokenizer.py │ └── timing.py ├── ui │ ├── ctkAlert │ │ ├── __init__.py │ │ ├── icons │ │ │ ├── info_dark.png │ │ │ ├── error_dark.png │ │ │ ├── error_light.png │ │ │ ├── info_light.png │ │ │ ├── success_dark.png │ │ │ ├── warning_dark.png │ │ │ ├── success_light.png │ │ │ └── warning_light.png │ │ └── ctkalert.py │ ├── ctkLoader │ │ ├── __init__.py │ │ └── ctkloader.py │ ├── icons │ │ ├── close_dark.png │ │ ├── close_light.png │ │ ├── settings_dark.png │ │ ├── audio_file_dark.png │ │ ├── audio_file_light.png │ │ ├── microphone_dark.png │ │ ├── microphone_light.png │ │ ├── settings_light.png │ │ ├── subtitle_light.png │ │ ├── subtitles_dark.png │ │ ├── translate_light.png │ │ └── translation_dark.png │ ├── download_models.py │ ├── style.py │ ├── icons.py │ ├── ctk_tooltip.py │ ├── add_subtitles.py │ ├── translate.py │ ├── live_transcribe.py │ ├── transcribe.py │ ├── ctkdropdown.py │ └── settings.py └── logic │ ├── __init__.py │ ├── settings.py │ ├── transcriber.py │ ├── gpu_details.py │ ├── model_requirements.py │ └── live_transcriber.py ├── demo ├── 1.PNG ├── 2.PNG ├── 3.PNG ├── 4.PNG ├── 5.PNG ├── 6.PNG ├── 7.PNG ├── 8.PNG ├── 9.PNG ├── 10.PNG └── 11.PNG ├── requirements.txt ├── assets └── icons │ ├── logo.ico │ ├── logo.png │ ├── close_dark.png │ ├── close_light.png │ ├── github_dark.png │ ├── help_dark.png │ ├── help_light.png │ ├── paypal_dark.png │ ├── github_light.png │ ├── paypal_light.png │ ├── settings_dark.png │ ├── audio_file_dark.png │ ├── audio_file_light.png │ ├── microphone_dark.png │ ├── microphone_light.png │ ├── settings_light.png │ ├── subtitles_dark.png │ ├── subtitles_light.png │ ├── translation_dark.png │ └── translation_light.png ├── README.md ├── setup.py └── main.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/whisper/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "20231117" 2 | -------------------------------------------------------------------------------- /src/ui/ctkAlert/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctkalert import CTkAlert 2 | -------------------------------------------------------------------------------- /src/ui/ctkLoader/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctkloader import CTkLoader 2 | -------------------------------------------------------------------------------- /src/whisper/__main__.py: -------------------------------------------------------------------------------- 1 | from .transcribe import cli 2 | 3 | cli() 4 | -------------------------------------------------------------------------------- /demo/1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/1.PNG -------------------------------------------------------------------------------- /demo/2.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/2.PNG 
-------------------------------------------------------------------------------- /demo/3.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/3.PNG -------------------------------------------------------------------------------- /demo/4.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/4.PNG -------------------------------------------------------------------------------- /demo/5.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/5.PNG -------------------------------------------------------------------------------- /demo/6.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/6.PNG -------------------------------------------------------------------------------- /demo/7.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/7.PNG -------------------------------------------------------------------------------- /demo/8.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/8.PNG -------------------------------------------------------------------------------- /demo/9.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/9.PNG -------------------------------------------------------------------------------- /demo/10.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/10.PNG -------------------------------------------------------------------------------- /demo/11.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/11.PNG -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/requirements.txt -------------------------------------------------------------------------------- /assets/icons/logo.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/logo.ico -------------------------------------------------------------------------------- /assets/icons/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/logo.png -------------------------------------------------------------------------------- /assets/icons/close_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/close_dark.png 
-------------------------------------------------------------------------------- /assets/icons/close_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/close_light.png -------------------------------------------------------------------------------- /assets/icons/github_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/github_dark.png -------------------------------------------------------------------------------- /assets/icons/help_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/help_dark.png -------------------------------------------------------------------------------- /assets/icons/help_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/help_light.png -------------------------------------------------------------------------------- /assets/icons/paypal_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/paypal_dark.png -------------------------------------------------------------------------------- /src/ui/icons/close_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/close_dark.png -------------------------------------------------------------------------------- /src/ui/icons/close_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/close_light.png -------------------------------------------------------------------------------- /assets/icons/github_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/github_light.png -------------------------------------------------------------------------------- /assets/icons/paypal_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/paypal_light.png -------------------------------------------------------------------------------- /assets/icons/settings_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/settings_dark.png -------------------------------------------------------------------------------- /src/ui/icons/settings_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/settings_dark.png -------------------------------------------------------------------------------- /assets/icons/audio_file_dark.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/audio_file_dark.png -------------------------------------------------------------------------------- /assets/icons/audio_file_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/audio_file_light.png -------------------------------------------------------------------------------- /assets/icons/microphone_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/microphone_dark.png -------------------------------------------------------------------------------- /assets/icons/microphone_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/microphone_light.png -------------------------------------------------------------------------------- /assets/icons/settings_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/settings_light.png -------------------------------------------------------------------------------- /assets/icons/subtitles_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/subtitles_dark.png -------------------------------------------------------------------------------- /assets/icons/subtitles_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/subtitles_light.png -------------------------------------------------------------------------------- /assets/icons/translation_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/translation_dark.png -------------------------------------------------------------------------------- /src/ui/icons/audio_file_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/audio_file_dark.png -------------------------------------------------------------------------------- /src/ui/icons/audio_file_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/audio_file_light.png -------------------------------------------------------------------------------- /src/ui/icons/microphone_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/microphone_dark.png -------------------------------------------------------------------------------- /src/ui/icons/microphone_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/microphone_light.png 
-------------------------------------------------------------------------------- /src/ui/icons/settings_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/settings_light.png -------------------------------------------------------------------------------- /src/ui/icons/subtitle_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/subtitle_light.png -------------------------------------------------------------------------------- /src/ui/icons/subtitles_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/subtitles_dark.png -------------------------------------------------------------------------------- /src/ui/icons/translate_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/translate_light.png -------------------------------------------------------------------------------- /src/ui/icons/translation_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/translation_dark.png -------------------------------------------------------------------------------- /assets/icons/translation_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/translation_light.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/info_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/info_dark.png -------------------------------------------------------------------------------- /src/whisper/assets/mel_filters.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/whisper/assets/mel_filters.npz -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/error_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/error_dark.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/error_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/error_light.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/info_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/info_light.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/success_dark.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/success_dark.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/warning_dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/warning_dark.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/success_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/success_light.png -------------------------------------------------------------------------------- /src/ui/ctkAlert/icons/warning_light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/warning_light.png -------------------------------------------------------------------------------- /src/whisper/normalizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic import BasicTextNormalizer as BasicTextNormalizer 2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer 3 | -------------------------------------------------------------------------------- /src/logic/__init__.py: -------------------------------------------------------------------------------- 1 | from .gpu_details import GPUInfo 2 | from .settings import SettingsHandler 3 | from .model_requirements import ModelRequirements 4 | from .transcriber import Transcriber 5 | -------------------------------------------------------------------------------- /src/ui/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .. 
import whisper 4 | 5 | 6 | def download_model(): 7 | models = whisper.available_models() 8 | cache_folder = os.path.join(os.path.expanduser('~'), f'.cache{os.path.sep}whisper{os.path.sep}') 9 | 10 | for model in models: 11 | print(f"Downloading {model}") 12 | whisper._download(url=whisper._MODELS[model], root=cache_folder, in_memory=False) 13 | 14 | 15 | download_model() 16 | -------------------------------------------------------------------------------- /src/ui/style.py: -------------------------------------------------------------------------------- 1 | FONTS = { 2 | "title_bold": ("Inter", 24, "bold"), 3 | "title": ("Inter", 22, "normal"), 4 | "subtitle_bold": ("Inter", 18, "bold"), 5 | "subtitle": ("Inter", 18, "normal"), 6 | "normal_bold": ("Inter", 15, "bold"), 7 | "normal": ("Inter", 15, "normal"), 8 | "small": ("Inter", 13, "normal"), 9 | } 10 | 11 | DROPDOWN = { 12 | "corner_radius": 2, 13 | "alpha": 1.0, 14 | "frame_corner_radius": 5, 15 | "x": 0, 16 | "hover": False, 17 | "justify": "left" 18 | } 19 | -------------------------------------------------------------------------------- /src/ui/icons.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import customtkinter as ctk 4 | from PIL import Image 5 | 6 | current_path = os.path.dirname(os.path.realpath(__file__)) 7 | icon_path = f"{current_path}{os.path.sep}icons{os.path.sep}" 8 | 9 | icons = { 10 | "close": ctk.CTkImage(dark_image=Image.open(f"{icon_path}close_dark.png"), 11 | light_image=Image.open(f"{icon_path}close_light.png"), size=(30, 30)), 12 | "audio_file": ctk.CTkImage(dark_image=Image.open(f"{icon_path}audio_file_dark.png"), 13 | light_image=Image.open(f"{icon_path}audio_file_light.png"), size=(30, 30)), 14 | "translate": ctk.CTkImage(dark_image=Image.open(f"{icon_path}translation_dark.png"), 15 | light_image=Image.open(f"{icon_path}translate_light.png"), size=(30, 30)), 16 | "microphone": ctk.CTkImage(dark_image=Image.open(f"{icon_path}microphone_dark.png"), 17 | light_image=Image.open(f"{icon_path}microphone_light.png"), size=(30, 30)), 18 | "subtitle": ctk.CTkImage(dark_image=Image.open(f"{icon_path}subtitles_dark.png"), 19 | light_image=Image.open(f"{icon_path}subtitle_light.png"), size=(30, 30)), 20 | "settings": ctk.CTkImage(dark_image=Image.open(f"{icon_path}settings_dark.png"), 21 | light_image=Image.open(f"{icon_path}settings_light.png"), size=(30, 30)) 22 | } 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 |

Whisper Transcriber GUI

6 | 7 | ### 8 | 9 |

A modern desktop application offering a suite of tools for transcribing speech from audio and video, along with a variety of other useful utilities.

10 | 11 | ### 12 | 13 |

Demo

14 | 15 | ### 16 | 17 |

Check the demo folder for the rest of the screenshots

18 | 19 |

20 | 21 |

22 | 23 | ### 24 | 25 |

Installation

26 | 27 | ### 28 | 29 |

Setup

30 | 31 | ### 32 | 33 | ``` 34 | git clone https://github.com/rudymohammadbali/Whisper-Transcriber.git 35 | ``` 36 | ``` 37 | cd Whisper-Transcriber 38 | ``` 39 | ``` 40 | python setup.py 41 | ``` 42 | and finally run it: 43 | ``` 44 | python main.py 45 | ``` 46 | ### 47 | 48 |
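Note: transcription decodes audio through the ffmpeg CLI (see the comment in src/whisper/audio.py), so make sure ffmpeg is installed and on your PATH. A quick check:
```
ffmpeg -version
```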

Special thanks to @Akascape for bringing this app to life

49 | 50 |

Support

51 | 52 | ### 53 | 54 |

If you'd like to support my ongoing efforts in sharing fantastic open-source projects, you can contribute by making a donation via PayPal.

55 | 56 | ### 57 | 58 |
59 | 60 | paypal logo 61 | 62 |
63 | 64 | ### 65 | -------------------------------------------------------------------------------- /src/logic/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | DEFAULT_CONFIG = { 6 | "theme": "system", 7 | "color_theme": "blue", 8 | "model_size": "base", 9 | "language": "auto detection", 10 | "device": "cpu" 11 | } 12 | 13 | 14 | class SettingsHandler: 15 | def __init__(self): 16 | current_dir = os.path.dirname(__file__) 17 | file_path = os.path.join(current_dir, '..', '..', f'assets{os.path.sep}config') 18 | self.config_path = os.path.normpath(file_path) 19 | self.settings_file_path = os.path.join(self.config_path, "settings.json") 20 | 21 | def save_settings(self, **new_settings: dict): 22 | try: 23 | existing_settings = self.load_settings() 24 | existing_settings.update(new_settings) 25 | if not os.path.exists(self.config_path): 26 | os.mkdir(self.config_path) 27 | 28 | with open(self.settings_file_path, 'w') as file: 29 | json.dump(existing_settings, file) 30 | return "Your settings have been saved successfully." 31 | except FileNotFoundError: 32 | pass 33 | except PermissionError: 34 | pass 35 | 36 | def load_settings(self): 37 | try: 38 | if os.path.exists(self.settings_file_path) and os.path.getsize(self.settings_file_path) > 0: 39 | with open(self.settings_file_path, "r") as file: 40 | loaded_settings = json.load(file) 41 | return loaded_settings 42 | else: 43 | with open(self.settings_file_path, 'w') as file: 44 | json.dump(DEFAULT_CONFIG, file) 45 | return DEFAULT_CONFIG 46 | except FileNotFoundError: 47 | return DEFAULT_CONFIG 48 | 49 | def reset_settings(self): 50 | try: 51 | if not os.path.exists(self.config_path): 52 | os.mkdir(self.config_path) 53 | 54 | with open(self.settings_file_path, 'w') as file: 55 | json.dump(DEFAULT_CONFIG, file) 56 | return "Your settings have been reset to default values." 
57 | except FileNotFoundError: 58 | pass 59 | except PermissionError: 60 | pass 61 | -------------------------------------------------------------------------------- /src/whisper/normalizers/basic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unicodedata 3 | 4 | import regex 5 | 6 | # non-ASCII letters that are not separated by "NFKD" normalization 7 | ADDITIONAL_DIACRITICS = { 8 | "œ": "oe", 9 | "Œ": "OE", 10 | "ø": "o", 11 | "Ø": "O", 12 | "æ": "ae", 13 | "Æ": "AE", 14 | "ß": "ss", 15 | "ẞ": "SS", 16 | "đ": "d", 17 | "Đ": "D", 18 | "ð": "d", 19 | "Ð": "D", 20 | "þ": "th", 21 | "Þ": "th", 22 | "ł": "l", 23 | "Ł": "L", 24 | } 25 | 26 | 27 | def remove_symbols_and_diacritics(s: str, keep=""): 28 | """ 29 | Replace any other markers, symbols, and punctuations with a space, 30 | and drop any diacritics (category 'Mn' and some manual mappings) 31 | """ 32 | return "".join( 33 | c 34 | if c in keep 35 | else ADDITIONAL_DIACRITICS[c] 36 | if c in ADDITIONAL_DIACRITICS 37 | else "" 38 | if unicodedata.category(c) == "Mn" 39 | else " " 40 | if unicodedata.category(c)[0] in "MSP" 41 | else c 42 | for c in unicodedata.normalize("NFKD", s) 43 | ) 44 | 45 | 46 | def remove_symbols(s: str): 47 | """ 48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics 49 | """ 50 | return "".join( 51 | " " if unicodedata.category(c)[0] in "MSP" else c 52 | for c in unicodedata.normalize("NFKC", s) 53 | ) 54 | 55 | 56 | class BasicTextNormalizer: 57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): 58 | self.clean = ( 59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols 60 | ) 61 | self.split_letters = split_letters 62 | 63 | def __call__(self, s: str): 64 | s = s.lower() 65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets 66 | s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis 67 | s = self.clean(s).lower() 68 | 69 | if self.split_letters: 70 | s = " ".join(regex.findall(r"\X", s, regex.U)) 71 | 72 | s = re.sub( 73 | r"\s+", " ", s 74 | ) # replace any successive whitespace characters with a space 75 | 76 | return s 77 | -------------------------------------------------------------------------------- /src/logic/transcriber.py: -------------------------------------------------------------------------------- 1 | from .. 
import whisper 2 | 3 | from .settings import SettingsHandler 4 | from deep_translator import GoogleTranslator 5 | 6 | 7 | class Transcriber: 8 | def __init__(self, audio: str): 9 | self.audio = audio 10 | get_config = SettingsHandler() 11 | config = get_config.load_settings() 12 | 13 | model = config.get("model_size", "base").lower() 14 | self.language = config.get("language", "auto detection").lower()  # default must match DEFAULT_CONFIG and the check in audio_recognition 15 | device = config.get("device", "cpu").lower() 16 | self.fp16 = False 17 | 18 | if self.language == "english" and model not in ["large", "large-v1", "large-v2"]: 19 | model += ".en" 20 | 21 | if device == "gpu": 22 | device = "cuda" 23 | self.fp16 = True 24 | 25 | self.load_model = whisper.load_model(name=model, device=device) 26 | 27 | self.options = { 28 | "task": "transcribe", 29 | "fp16": self.fp16 30 | } 31 | 32 | def audio_recognition(self, cancel_func=lambda: False):  # no-op cancel check by default; the builtin `any` would raise when called with no arguments 33 | print("Audio Recognition started...") 34 | if self.language == "auto detection": 35 | audio = whisper.load_audio(self.audio) 36 | audio = whisper.pad_or_trim(audio) 37 | 38 | mel = whisper.log_mel_spectrogram(audio).to(self.load_model.device) 39 | 40 | _, probs = self.load_model.detect_language(mel) 41 | detected_language = max(probs, key=probs.get) 42 | result = whisper.transcribe(model=self.load_model, audio=self.audio, 43 | language=detected_language, **self.options, cancel_func=cancel_func) 44 | else: 45 | result = whisper.transcribe(model=self.load_model, audio=self.audio, 46 | **self.options, language=self.language, cancel_func=cancel_func) 47 | 48 | return result 49 | 50 | def translate_audio(self, cancel_func, to_language): 51 | print("Audio Translate started...") 52 | result = self.audio_recognition(cancel_func=cancel_func) 53 | text = str(result["text"]).strip() 54 | language = to_language 55 | translated_text = GoogleTranslator(source='auto', target=language).translate(text=text) 56 | # langs_list = GoogleTranslator().get_supported_languages() 57 | return translated_text 58 | -------------------------------------------------------------------------------- /src/ui/ctkLoader/ctkloader.py: -------------------------------------------------------------------------------- 1 | import customtkinter as ctk 2 | 3 | CTkFrame = { 4 | "width": 500, 5 | "height": 200, 6 | "corner_radius": 5, 7 | "border_width": 1, 8 | "fg_color": ["#F7F8FA", "#2B2D30"], 9 | "border_color": ["#D3D5DB", "#4E5157"] 10 | } 11 | 12 | CTkButton = { 13 | "width": 150, 14 | "height": 40, 15 | "corner_radius": 5, 16 | "border_width": 1, 17 | "fg_color": ["#3574F0", "#2B2D30"], 18 | "hover_color": ["#3369D6", "#4E5157"], 19 | "border_color": ["#C2D6FC", "#4E5157"], 20 | "text_color": ["#FFFFFF", "#DFE1E5"], 21 | "text_color_disabled": ["#A8ADBD", "#6F737A"] 22 | } 23 | 24 | CTkProgressBar = { 25 | "width": 460, 26 | "height": 5, 27 | "corner_radius": 5, 28 | "border_width": 0, 29 | "fg_color": ["#DFE1E5", "#43454A"], 30 | "progress_color": ["#3574F0", "#3574F0"], 31 | "border_color": ["#3574F0", "#4E5157"] 32 | } 33 | 34 | font_title = ("Inter", 20, "bold") 35 | font_normal = ("Inter", 14, "normal") 36 | 37 | 38 | class CTkLoader(ctk.CTkFrame): 39 | def __init__(self, parent: any, title: str, msg: str, cancel_func: any): 40 | super().__init__(master=parent, **CTkFrame) 41 | self.title = title.capitalize() 42 | self.msg = msg.capitalize() 43 | self.grid_propagate(False) 44 | self.grid_columnconfigure(0, weight=1) 45 | 46 | self.cancel_func = cancel_func 47 | 48 | title = ctk.CTkLabel(self, text=self.title, text_color=("#000000", "#DFE1E5"), font=font_title)
49 | title.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="nw") 50 | 51 | msg = ctk.CTkLabel(self, text=self.msg, text_color=("#000000", "#DFE1E5"), font=font_normal) 52 | msg.grid(row=1, column=0, padx=20, pady=10, sticky="nsw") 53 | 54 | self.progressbar = ctk.CTkProgressBar(self, mode="indeterminate", **CTkProgressBar) 55 | self.progressbar.grid(row=2, column=0, padx=20, pady=10, sticky="sw") 56 | self.progressbar.start() 57 | 58 | button = ctk.CTkButton(self, text="Cancel", command=self.cancel_callback, **CTkButton) 59 | button.grid(row=3, column=0, padx=20, pady=(10, 20), sticky="nsew") 60 | 61 | # self.pack(anchor="center", expand=True) 62 | self.grid(row=0, column=0, sticky="ew", padx=20) 63 | 64 | def hide_loader(self): 65 | self.progressbar.stop() 66 | self.destroy() 67 | 68 | def cancel_callback(self): 69 | self.cancel_func() 70 | self.after(100, self.hide_loader) 71 | -------------------------------------------------------------------------------- /src/logic/gpu_details.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import GPUtil 5 | import torch 6 | 7 | 8 | class GPUInfo: 9 | def __init__(self): 10 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | self.cuda_available = torch.cuda.is_available() 12 | current_dir = os.path.dirname(__file__) 13 | file_path = os.path.join(current_dir, '..', '..', f"assets{os.path.sep}config") 14 | self.config_path = os.path.normpath(file_path) 15 | self.gpu_info_file_path = os.path.join(self.config_path, "gpu_info.json") 16 | 17 | def get_gpu_info(self): 18 | if self.cuda_available: 19 | gpu_count = torch.cuda.device_count() 20 | current_device = torch.cuda.current_device() 21 | gpu_name = torch.cuda.get_device_name(current_device) 22 | total_memory = round(torch.cuda.get_device_properties(current_device).total_memory / (1024 ** 3)) 23 | else: 24 | gpus = GPUtil.getGPUs() 25 | gpu_count = len(gpus) 26 | 27 | if gpu_count == 0: 28 | current_device = "No GPU found." 
29 | gpu_name = "N/A" 30 | total_memory = 0 31 | else: 32 | current_gpu = gpus[0] 33 | current_device = f"{current_gpu.name} ({current_gpu.id})" 34 | gpu_name = current_gpu.name 35 | total_memory = current_gpu.memoryTotal / 1024 36 | 37 | return {"cuda_available": self.cuda_available, 38 | "gpu_count": gpu_count, 39 | "current_gpu": current_device, 40 | "gpu_name": gpu_name, 41 | "total_memory": total_memory} 42 | 43 | def save_gpu_info(self): 44 | settings = self.get_gpu_info() 45 | 46 | try: 47 | if not os.path.exists(self.config_path): 48 | os.mkdir(self.config_path) 49 | 50 | with open(self.gpu_info_file_path, 'w') as file: json.dump(settings, file) 51 | except (FileNotFoundError, PermissionError):  # "except A or B" would only ever catch A 52 | pass 53 | 54 | def load_gpu_info(self): 55 | if not os.path.exists(self.gpu_info_file_path) or os.path.getsize(self.gpu_info_file_path) == 0: 56 | self.save_gpu_info() 57 | try: 58 | with open(self.gpu_info_file_path, 'r') as file: loaded_settings = json.load(file) 59 | except FileNotFoundError: 60 | self.save_gpu_info() 61 | with open(self.gpu_info_file_path, 'r') as file: loaded_settings = json.load(file) 62 | return loaded_settings 63 | -------------------------------------------------------------------------------- /src/logic/model_requirements.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from .gpu_details import GPUInfo 5 | 6 | MODEL_SETTINGS = { 7 | "models": ["Tiny", "Base", "Small", "Medium", "Large", "Large-v1", "Large-v2", "Large-v3"], 8 | "device": ["GPU", "CPU"] 9 | } 10 | 11 | 12 | class ModelRequirements: 13 | def __init__(self): 14 | current_dir = os.path.dirname(__file__) 15 | file_path = os.path.join(current_dir, '..', '..', f'assets{os.path.sep}config') 16 | config_path = os.path.normpath(file_path) 17 | self.filename = os.path.join(config_path, "recommended_models.json") 18 | 19 | @staticmethod 20 | def model_requirements(available_memory, required_memory): 21 | recommended_models = [model for model, memory in required_memory.items() if available_memory >= memory] 22 | return recommended_models 23 | 24 | def read_json_file(self): 25 | try: 26 | with open(self.filename, 'r') as json_file: 27 | data = json.load(json_file) 28 | return data 29 | except FileNotFoundError: 30 | with open(self.filename, 'w') as json_file: 31 | json.dump(MODEL_SETTINGS, json_file, indent=2) 32 | return MODEL_SETTINGS 33 | 34 | def write_json_file(self, data): 35 | with open(self.filename, 'w') as json_file: 36 | json.dump(data, json_file, indent=2) 37 | 38 | def update_model_requirements(self): 39 | gpu_monitor = GPUInfo() 40 | gpu_info = gpu_monitor.load_gpu_info() 41 | available_memory = gpu_info["total_memory"] 42 | cuda = gpu_info["cuda_available"] 43 | 44 | required_memory = { 45 | "Tiny": 1, 46 | "Base": 1, 47 | "Small": 2, 48 | "Medium": 5, 49 | "Large": 10, 50 | "Large-v1": 10, 51 | "Large-v2": 10, 52 | "Large-v3": 10 53 | } 54 | 55 | json_data = self.read_json_file() 56 | 57 | if json_data: 58 | if cuda: 59 | device = ["CPU", "GPU"] 60 | recommended_models = self.model_requirements(available_memory, required_memory) 61 | json_data["models"] = recommended_models 62 | json_data["device"] = device 63 | else: 64 | recommended_models = ["Tiny", "Base", "Small", "Medium", "Large", "Large-v1", "Large-v2", "Large-v3"] 65 | device = ["CPU"] 66 | json_data["models"] = recommended_models 67 | json_data["device"] = device 68 | 69 | self.write_json_file(json_data) 70 | return recommended_models, device 71 |
-------------------------------------------------------------------------------- /src/logic/live_transcriber.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import threading 4 | import queue 5 | 6 | import sounddevice as sd 7 | from .. import whisper 8 | from scipy.io.wavfile import write 9 | 10 | 11 | class LiveTranscriber: 12 | def __init__(self): 13 | self.audio_file_name = "temp.wav" 14 | self.model = whisper.load_model("base") 15 | self.language = "english" 16 | self.transcription_queue = queue.Queue() 17 | self.recording_complete = threading.Event() 18 | self.stop_recording = threading.Event() 19 | 20 | def start_transcription_thread(self): 21 | transcription_thread = threading.Thread(target=self.transcribe_loop) 22 | transcription_thread.start() 23 | 24 | def transcribe_loop(self): 25 | while not self.stop_recording.is_set(): 26 | if not self.recording_complete.wait(timeout=1): continue  # time out so the stop flag is re-checked instead of blocking forever 27 | transcribed_text = self.transcribe() 28 | self.transcription_queue.put(transcribed_text) 29 | self.recording_complete.clear() 30 | 31 | def start_recording_thread(self): 32 | recording_thread = threading.Thread(target=self.record_loop) 33 | recording_thread.start() 34 | 35 | def record_loop(self): 36 | while not self.stop_recording.is_set(): 37 | self.recorder() 38 | self.recording_complete.set() 39 | self.stop_recording.wait(2) 40 | 41 | def start_live_transcription(self): 42 | self.start_recording_thread() 43 | self.start_transcription_thread() 44 | 45 | def stop_live_transcription(self): 46 | self.stop_recording.set()  # signal both threads before removing the temp file 47 | if os.path.exists(self.audio_file_name): 48 | os.remove(self.audio_file_name) 49 | 50 | def transcribe(self): 51 | print("Transcribing...") 52 | 53 | result = self.model.transcribe(audio=self.audio_file_name, cancel_func=self.cancel_callback) 54 | cleaned_result = self.clean_transcription(result['text']) 55 | print(f"Result: {cleaned_result}") 56 | 57 | if os.path.exists(self.audio_file_name): os.remove(self.audio_file_name) 58 | 59 | return cleaned_result 60 | 61 | def recorder(self, duration=5): 62 | print("Say something...") 63 | sample_rate = 44100 64 | recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=2) 65 | sd.wait() 66 | 67 | write(self.audio_file_name, sample_rate, recording) 68 | 69 | print(f"Recording saved as {self.audio_file_name}") 70 | 71 | def cancel_callback(self): 72 | pass 73 | 74 | @staticmethod 75 | def clean_transcription(transcription): 76 | cleaned_text = re.sub(r'[^a-zA-Z\s]', '', transcription).strip() 77 | return cleaned_text 78 | 79 | 80 | # transcriber = LiveTranscriber() 81 | # transcriber.start_live_transcription() 82 | # transcriber.stop_live_transcription() 83 | -------------------------------------------------------------------------------- /src/ui/ctkAlert/ctkalert.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import customtkinter as ctk 4 | from PIL import Image 5 | 6 | current_path = os.path.dirname(os.path.realpath(__file__)) 7 | icon_path = f"{current_path}{os.path.sep}icons{os.path.sep}" 8 | 9 | icons = { 10 | "info": ctk.CTkImage(dark_image=Image.open(f"{icon_path}info_dark.png"), 11 | light_image=Image.open(f"{icon_path}info_light.png"), size=(28, 28)), 12 | "success": ctk.CTkImage(dark_image=Image.open(f"{icon_path}success_dark.png"), 13 | light_image=Image.open(f"{icon_path}success_light.png"), size=(28, 28)), 14 | "error": ctk.CTkImage(dark_image=Image.open(f"{icon_path}error_dark.png"), 15 |
light_image=Image.open(f"{icon_path}error_light.png"), size=(28, 28)), 16 | "warning": ctk.CTkImage(dark_image=Image.open(f"{icon_path}warning_dark.png"), 17 | light_image=Image.open(f"{icon_path}warning_light.png"), size=(28, 28)) 18 | } 19 | 20 | CTkFrame = { 21 | "width": 500, 22 | "height": 200, 23 | "corner_radius": 5, 24 | "border_width": 1, 25 | "fg_color": ["#F7F8FA", "#2B2D30"], 26 | "border_color": ["#D3D5DB", "#4E5157"] 27 | } 28 | 29 | CTkButton = { 30 | "width": 100, 31 | "height": 30, 32 | "corner_radius": 5, 33 | "border_width": 0, 34 | "fg_color": ["#3574F0", "#3574F0"], 35 | "hover_color": ["#3369D6", "#366ACF"], 36 | "border_color": ["#C2D6FC", "#4E5157"], 37 | "text_color": ["#FFFFFF", "#DFE1E5"], 38 | "text_color_disabled": ["#A8ADBD", "#6F737A"] 39 | } 40 | 41 | font_title = ("Inter", 20, "bold") 42 | font_normal = ("Inter", 13, "normal") 43 | 44 | 45 | class CTkAlert(ctk.CTkFrame): 46 | def __init__(self, parent: any, status: str, title: str, msg: str): 47 | super().__init__(master=parent, **CTkFrame) 48 | self.status = status 49 | self.title = title.capitalize() 50 | self.msg = msg.capitalize() 51 | self.grid_propagate(False) 52 | self.grid_columnconfigure(0, weight=1) 53 | 54 | title = ctk.CTkLabel(self, text=f" {self.title}", image=icons.get(self.status, icons["info"]), compound="left", 55 | text_color=("#000000", "#DFE1E5"), font=font_title) 56 | title.grid(row=0, column=0, padx=20, pady=20, sticky="nw") 57 | 58 | msg = ctk.CTkLabel(self, text=self.msg, text_color=("#000000", "#DFE1E5"), font=font_normal) 59 | msg.grid(row=1, column=0, padx=40, pady=20, sticky="nsew") 60 | 61 | button = ctk.CTkButton(self, text="OK", command=self.hide_alert, **CTkButton) 62 | button.grid(row=2, column=0, padx=20, pady=20, sticky="se") 63 | 64 | # self.pack(anchor="center", expand=True) 65 | self.grid(row=0, column=0, sticky="ew", padx=10) 66 | 67 | def hide_alert(self): 68 | self.destroy() 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import sys 4 | import time 5 | 6 | try: 7 | import pkg_resources 8 | except ImportError: 9 | if sys.platform.startswith("win"): 10 | subprocess.call('python -m pip install setuptools', shell=True) 11 | else: 12 | subprocess.call('python3 -m pip install setuptools', shell=True) 13 | import pkg_resources 14 | 15 | DIR_PATH = os.path.dirname(os.path.realpath(__file__)) 16 | 17 | # Checking the required folders 18 | folders = ["assets", "src"] 19 | missing_folder = [] 20 | for i in folders: 21 | if not os.path.exists(i): 22 | missing_folder.append(i) 23 | if missing_folder: 24 | print("These folder(s) are missing: " + str(missing_folder)) 25 | print("Please download them from the repository.") 26 | sys.exit() 27 | else: 28 | print("All folders available!") 29 | 30 | # Checking required modules 31 | required = {"Pillow==10.1.0", "sounddevice==0.4.6", "scipy==1.11.4", "deep-translator==1.11.4", "mutagen==1.47.0", "pydub==0.25.1", "tkinterdnd2==0.3.0", "numpy==1.26.2", 32 | "SpeechRecognition==3.10.0", "customtkinter==5.2.1", "tqdm==4.66.1", "numba==0.58.1", "tiktoken==0.5.1", "more-itertools==10.1.0", "GPUtil==1.4.0", "PyAudio==0.2.14", "packaging==23.2"} 33 | installed = {pkg.key for pkg in pkg_resources.working_set} 34 | missing = required - installed 35 | missing_set = [*missing, ] 36 | pytorch_win = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
37 | pytorch_linux = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118" 38 | pytorch_mac = "torch torchvision torchaudio" 39 | 40 | # Download the modules if not installed 41 | if missing: 42 | try: 43 | print("Installing modules...") 44 | for x in range(len(missing_set)): 45 | y = missing_set[x] 46 | if sys.platform.startswith("win"): 47 | subprocess.call('python -m pip install ' + y, shell=True) 48 | else: 49 | subprocess.call('python3 -m pip install ' + y, shell=True) 50 | except Exception: 51 | print("Unable to download! \nThese are the required modules: " + str( 52 | required) + "\nUse 'pip install module_name' to download the modules one by one.") 53 | time.sleep(3) 54 | sys.exit() 55 | 56 | try: 57 | print("Installing Pytorch...") 58 | if sys.platform.startswith("win"): 59 | subprocess.call('python -m pip install ' + pytorch_win, shell=True) 60 | elif sys.platform.startswith("linux"): 61 | subprocess.call('python3 -m pip install ' + pytorch_linux, shell=True) 62 | elif sys.platform.startswith("darwin"): 63 | subprocess.call('python3 -m pip install ' + pytorch_mac, shell=True) 64 | except Exception: 65 | print("Unable to download! \nThese are the required modules: " + str(required) + "\nUse 'pip/pip3 install module_name' to download the modules one by one.") 66 | sys.exit() 67 | else: 68 | print("All required modules installed!") 69 | 70 | # Everything done! 71 | print("Setup Complete!") 72 | time.sleep(5) 73 | sys.exit() 74 | -------------------------------------------------------------------------------- /src/whisper/triton_ops.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import numpy as np 4 | import torch 5 | 6 | try: 7 | import triton 8 | import triton.language as tl 9 | except ImportError: 10 | raise RuntimeError("triton import failed; try `pip install --pre triton`") 11 | 12 | 13 | @triton.jit 14 | def dtw_kernel( 15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr 16 | ): 17 | offsets = tl.arange(0, BLOCK_SIZE) 18 | mask = offsets < M 19 | 20 | for k in range(1, N + M + 1): # k = i + j 21 | tl.debug_barrier() 22 | 23 | p0 = cost + (k - 1) * cost_stride 24 | p1 = cost + k * cost_stride 25 | p2 = cost + k * cost_stride + 1 26 | 27 | c0 = tl.load(p0 + offsets, mask=mask) 28 | c1 = tl.load(p1 + offsets, mask=mask) 29 | c2 = tl.load(p2 + offsets, mask=mask) 30 | 31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0) 32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2) 33 | 34 | cost_ptr = cost + (k + 1) * cost_stride + 1 35 | tl.store(cost_ptr + offsets, cost_row, mask=mask) 36 | 37 | trace_ptr = trace + (k + 1) * trace_stride + 1 38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1)) 39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2)) 40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2)) 41 | 42 | 43 | @lru_cache(maxsize=None) 44 | def median_kernel(filter_width: int): 45 | @triton.jit 46 | def kernel( 47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr 48 | ): # x.shape[-1] == filter_width 49 | row_idx = tl.program_id(0) 50 | offsets = tl.arange(0, BLOCK_SIZE) 51 | mask = offsets < y_stride 52 | 53 | x_ptr = x + row_idx * x_stride # noqa: F841 54 | y_ptr = y + row_idx * y_stride 55 | 56 | LOAD_ALL_ROWS_HERE # noqa: F821 57 | 58 | BUBBLESORT_HERE # noqa: F821 59 | 60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821 61 | 62 | kernel =
triton.JITFunction(kernel.fn) 63 | kernel.src = kernel.src.replace( 64 | " LOAD_ALL_ROWS_HERE", 65 | "\n".join( 66 | [ 67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" 68 | for i in range(filter_width) 69 | ] 70 | ), 71 | ) 72 | kernel.src = kernel.src.replace( 73 | " BUBBLESORT_HERE", 74 | "\n\n".join( 75 | [ 76 | "\n\n".join( 77 | [ 78 | "\n".join( 79 | [ 80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})", 81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})", 82 | f" row{j} = smaller", 83 | f" row{j + 1} = larger", 84 | ] 85 | ) 86 | for j in range(filter_width - i - 1) 87 | ] 88 | ) 89 | for i in range(filter_width // 2 + 1) 90 | ] 91 | ), 92 | ) 93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}") 94 | 95 | return kernel 96 | 97 | 98 | def median_filter_cuda(x: torch.Tensor, filter_width: int): 99 | """Apply a median filter of given width along the last dimension of x""" 100 | slices = x.contiguous().unfold(-1, filter_width, 1) 101 | grid = np.prod(slices.shape[:-2]) 102 | 103 | kernel = median_kernel(filter_width) 104 | y = torch.empty_like(slices[..., 0]) 105 | 106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length() 107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE) 108 | 109 | return y 110 | -------------------------------------------------------------------------------- /src/whisper/audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import lru_cache 3 | from subprocess import CalledProcessError, run 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .utils import exact_div 11 | 12 | # hard-coded audio hyperparameters 13 | SAMPLE_RATE = 16000 14 | N_FFT = 400 15 | HOP_LENGTH = 160 16 | CHUNK_LENGTH = 30 17 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk 18 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input 19 | 20 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2 21 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame 22 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token 23 | 24 | 25 | def load_audio(file: str, sr: int = SAMPLE_RATE): 26 | """ 27 | Open an audio file and read as mono waveform, resampling as necessary 28 | 29 | Parameters 30 | ---------- 31 | file: str 32 | The audio file to open 33 | 34 | sr: int 35 | The sample rate to resample the audio if necessary 36 | 37 | Returns 38 | ------- 39 | A NumPy array containing the audio waveform, in float32 dtype. 40 | """ 41 | 42 | # This launches a subprocess to decode audio while down-mixing 43 | # and resampling as necessary. Requires the ffmpeg CLI in PATH. 44 | # fmt: off 45 | cmd = [ 46 | "ffmpeg", 47 | "-nostdin", 48 | "-threads", "0", 49 | "-i", file, 50 | "-f", "s16le", 51 | "-ac", "1", 52 | "-acodec", "pcm_s16le", 53 | "-ar", str(sr), 54 | "-" 55 | ] 56 | # fmt: on 57 | try: 58 | out = run(cmd, capture_output=True, check=True).stdout 59 | except CalledProcessError as e: 60 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e 61 | 62 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 63 | 64 | 65 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): 66 | """ 67 | Pad or trim the audio array to N_SAMPLES, as expected by the encoder. 
68 | """ 69 | if torch.is_tensor(array): 70 | if array.shape[axis] > length: 71 | array = array.index_select( 72 | dim=axis, index=torch.arange(length, device=array.device) 73 | ) 74 | 75 | if array.shape[axis] < length: 76 | pad_widths = [(0, 0)] * array.ndim 77 | pad_widths[axis] = (0, length - array.shape[axis]) 78 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes]) 79 | else: 80 | if array.shape[axis] > length: 81 | array = array.take(indices=range(length), axis=axis) 82 | 83 | if array.shape[axis] < length: 84 | pad_widths = [(0, 0)] * array.ndim 85 | pad_widths[axis] = (0, length - array.shape[axis]) 86 | array = np.pad(array, pad_widths) 87 | 88 | return array 89 | 90 | 91 | @lru_cache(maxsize=None) 92 | def mel_filters(device, n_mels: int) -> torch.Tensor: 93 | """ 94 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 95 | Allows decoupling librosa dependency; saved using: 96 | 97 | np.savez_compressed( 98 | "mel_filters.npz", 99 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), 100 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), 101 | ) 102 | """ 103 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" 104 | 105 | filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz") 106 | with np.load(filters_path, allow_pickle=False) as f: 107 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) 108 | 109 | 110 | def log_mel_spectrogram( 111 | audio: Union[str, np.ndarray, torch.Tensor], 112 | n_mels: int = 80, 113 | padding: int = 0, 114 | device: Optional[Union[str, torch.device]] = None, 115 | ): 116 | """ 117 | Compute the log-Mel spectrogram of 118 | 119 | Parameters 120 | ---------- 121 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*) 122 | The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz 123 | 124 | n_mels: int 125 | The number of Mel-frequency filters, only 80 is supported 126 | 127 | padding: int 128 | Number of zero samples to pad to the right 129 | 130 | device: Optional[Union[str, torch.device]] 131 | If given, the audio tensor is moved to this device before STFT 132 | 133 | Returns 134 | ------- 135 | torch.Tensor, shape = (80, n_frames) 136 | A Tensor that contains the Mel spectrogram 137 | """ 138 | if not torch.is_tensor(audio): 139 | if isinstance(audio, str): 140 | audio = load_audio(audio) 141 | audio = torch.from_numpy(audio) 142 | 143 | if device is not None: 144 | audio = audio.to(device) 145 | if padding > 0: 146 | audio = F.pad(audio, (0, padding)) 147 | window = torch.hann_window(N_FFT).to(audio.device) 148 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True) 149 | magnitudes = stft[..., :-1].abs() ** 2 150 | 151 | filters = mel_filters(audio.device, n_mels) 152 | mel_spec = filters @ magnitudes 153 | 154 | log_spec = torch.clamp(mel_spec, min=1e-10).log10() 155 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) 156 | log_spec = (log_spec + 4.0) / 4.0 157 | return log_spec 158 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import webbrowser 3 | 4 | import customtkinter as ctk 5 | from PIL import Image 6 | 7 | from src.logic.settings import SettingsHandler 8 | from src.ui.add_subtitles import AddSubtitlesUI 9 | from src.ui.live_transcribe import LiveTranscribeUI 10 | from src.ui.settings import SettingsUI 11 | from 
src.ui.style import FONTS 12 | from src.ui.transcribe import TranscribeUI 13 | from src.ui.translate import TranslateUI 14 | 15 | current_path = os.path.dirname(os.path.realpath(__file__)) 16 | icon_path = f"{current_path}{os.path.sep}assets{os.path.sep}icons{os.path.sep}" 17 | 18 | icons = { 19 | "logo": ctk.CTkImage(dark_image=Image.open(f"{icon_path}logo.png"), 20 | light_image=Image.open(f"{icon_path}logo.png"), size=(50, 50)), 21 | "close": ctk.CTkImage(dark_image=Image.open(f"{icon_path}close_dark.png"), 22 | light_image=Image.open(f"{icon_path}close_light.png"), size=(30, 30)), 23 | "audio_file": ctk.CTkImage(dark_image=Image.open(f"{icon_path}audio_file_dark.png"), 24 | light_image=Image.open(f"{icon_path}audio_file_light.png"), 25 | size=(30, 30)), 26 | "translation": ctk.CTkImage(dark_image=Image.open(f"{icon_path}translation_dark.png"), 27 | light_image=Image.open(f"{icon_path}translation_light.png"), 28 | size=(30, 30)), 29 | "microphone": ctk.CTkImage(dark_image=Image.open(f"{icon_path}microphone_dark.png"), 30 | light_image=Image.open(f"{icon_path}microphone_light.png"), 31 | size=(30, 30)), 32 | "subtitles": ctk.CTkImage(dark_image=Image.open(f"{icon_path}subtitles_dark.png"), 33 | light_image=Image.open(f"{icon_path}subtitles_light.png"), 34 | size=(30, 30)), 35 | "paypal": ctk.CTkImage(dark_image=Image.open(f"{icon_path}paypal_dark.png"), 36 | light_image=Image.open(f"{icon_path}paypal_light.png"), size=(30, 30)), 37 | "settings": ctk.CTkImage(dark_image=Image.open(f"{icon_path}settings_dark.png"), 38 | light_image=Image.open(f"{icon_path}settings_light.png"), 39 | size=(30, 30)), 40 | "help": ctk.CTkImage(dark_image=Image.open(f"{icon_path}help_dark.png"), 41 | light_image=Image.open(f"{icon_path}help_light.png"), size=(30, 30)), 42 | "github": ctk.CTkImage(dark_image=Image.open(f"{icon_path}github_dark.png"), 43 | light_image=Image.open(f"{icon_path}github_light.png"), size=(30, 30)) 44 | } 45 | 46 | logo = f"{icon_path}logo.ico" 47 | 48 | btn = { 49 | "width": 280, 50 | "height": 116, 51 | "text_color": ("#FFFFFF", "#DFE1E5"), 52 | "compound": "left", 53 | "font": ("Inter", 16) 54 | } 55 | 56 | secondary_btn = { 57 | "width": 280, 58 | "height": 100, 59 | "fg_color": ("#221D21", "#2B2D30"), 60 | "hover": False, 61 | "border_width": 0, 62 | "text_color": ("#FFFFFF", "#DFE1E5"), 63 | "compound": "left", 64 | "font": ("Inter", 16) 65 | } 66 | 67 | link_btn = { 68 | "width": 140, 69 | "height": 80, 70 | "fg_color": "transparent", 71 | "hover_color": ("#D3D5DB", "#2B2D30"), 72 | "border_color": ("#D3D5DB", "#2B2D30"), 73 | "border_width": 2, 74 | "text_color": ("#000000", "#DFE1E5"), 75 | "compound": "left", 76 | "font": ("Inter", 16) 77 | } 78 | 79 | 80 | def help_link(): 81 | webbrowser.open("https://github.com/rudymohammadbali/Whisper-Transcriber/discussions/categories/q-a") 82 | 83 | 84 | def github_link(): 85 | webbrowser.open("https://github.com/rudymohammadbali") 86 | 87 | 88 | def paypal_link(): 89 | webbrowser.open("https://www.paypal.com/paypalme/iamironman0") 90 | 91 | 92 | class Testing(ctk.CTk): 93 | def __init__(self): 94 | super().__init__() 95 | self.geometry("620x720") 96 | self.resizable(False, False) 97 | self.iconbitmap(logo) 98 | self.title("Whisper Transcriber") 99 | 100 | settings_handler = SettingsHandler() 101 | settings = settings_handler.load_settings() 102 | theme = settings.get("theme") 103 | color_theme = settings.get("color_theme") 104 | 105 | ctk.set_appearance_mode(theme) 106 | ctk.set_default_color_theme(color_theme) 107 | 108 | 
self.main_ui() 109 | 110 | def main_ui(self): 111 | title = ctk.CTkLabel(self, text="Welcome to Whisper Transcriber", text_color=("#000000", "#DFE1E5"), 112 | font=FONTS["title_bold"], image=icons["logo"], compound="top") 113 | title.grid(row=0, column=0, padx=20, pady=20, sticky="nsew", columnspan=2) 114 | 115 | label = ctk.CTkLabel(self, text="Select a Service", text_color=("#000000", "#DFE1E5"), font=FONTS["subtitle"]) 116 | label.grid(row=1, column=0, padx=20, pady=20, sticky="w") 117 | 118 | btn_1 = ctk.CTkButton(self, text="Transcribe Audio", **btn, image=icons["audio_file"], 119 | command=lambda: TranscribeUI(parent=self)) 120 | btn_1.grid(row=2, column=0, padx=(20, 10), pady=10, sticky="nsew") 121 | btn_2 = ctk.CTkButton(self, text="Translate Audio", **btn, image=icons["translation"], 122 | command=lambda: TranslateUI(parent=self)) 123 | btn_2.grid(row=2, column=1, padx=(10, 20), pady=10, sticky="nsew") 124 | 125 | btn_3 = ctk.CTkButton(self, text="Live Transcriber", **btn, image=icons["microphone"], 126 | command=lambda: LiveTranscribeUI(parent=self)) 127 | btn_3.grid(row=3, column=0, padx=(20, 10), pady=10, sticky="nsew") 128 | btn_4 = ctk.CTkButton(self, text="Add Subtitle", **btn, image=icons["subtitles"], 129 | command=lambda: AddSubtitlesUI(parent=self)) 130 | btn_4.grid(row=3, column=1, padx=(10, 20), pady=10, sticky="nsew") 131 | 132 | btn_5 = ctk.CTkButton(self, text="Settings", font=("Inter", 16), text_color=("#FFFFFF", "#DFE1E5"), width=280, 133 | height=100, image=icons["settings"], command=lambda: SettingsUI(parent=self)) 134 | btn_5.grid(row=4, column=0, padx=(20, 10), pady=20, sticky="nsew") 135 | btn_6 = ctk.CTkButton(self, text="Support on PayPal", **secondary_btn, image=icons["paypal"], 136 | command=paypal_link) 137 | btn_6.grid(row=4, column=1, padx=(10, 20), pady=20, sticky="nsew") 138 | 139 | btn_7 = ctk.CTkButton(self, text="Github", **link_btn, image=icons["github"], command=github_link) 140 | btn_7.grid(row=5, column=0, padx=20, pady=20, sticky="nsew") 141 | btn_8 = ctk.CTkButton(self, text="Help", **link_btn, image=icons["help"], command=help_link) 142 | btn_8.grid(row=5, column=1, padx=20, pady=20, sticky="nsew") 143 | 144 | 145 | app = Testing() 146 | app.mainloop() 147 | -------------------------------------------------------------------------------- /src/whisper/__init__.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import io 3 | import os 4 | import urllib.request 5 | import warnings 6 | from typing import List, Optional, Union 7 | 8 | import torch 9 | from tqdm import tqdm 10 | 11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim 12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language 13 | from .model import ModelDimensions, Whisper 14 | from .transcribe import transcribe 15 | from .version import __version__ 16 | 17 | _MODELS = { 18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt", 19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt", 20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt", 21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt", 22 | "small.en":
"https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt", 23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt", 24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt", 25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt", 26 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt", 27 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt", 28 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt", 30 | } 31 | 32 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are 33 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens. 34 | _ALIGNMENT_HEADS = { 35 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00", 36 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", 37 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00", 38 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00", 40 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", 43 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", 45 | "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 46 | "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", 47 | } 48 | 49 | 50 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: 51 | os.makedirs(root, exist_ok=True) 52 | 53 | expected_sha256 = url.split("/")[-2] 54 | download_target = os.path.join(root, os.path.basename(url)) 55 | 56 | if os.path.exists(download_target) and not os.path.isfile(download_target): 57 | raise RuntimeError(f"{download_target} exists and is not a regular file") 58 | 59 | if os.path.isfile(download_target): 60 | with open(download_target, "rb") as f: 61 | model_bytes = f.read() 62 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: 63 | return model_bytes if in_memory else download_target 64 | else: 65 | warnings.warn( 66 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 67 | ) 68 | 69 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 70 | with tqdm( 71 | total=int(source.info().get("Content-Length")), 72 | ncols=80, 73 | unit="iB", 74 | unit_scale=True, 75 | unit_divisor=1024, 76 | ) as loop: 77 | while True: 78 | buffer = source.read(8192) 79 | if not buffer: 80 | break 81 | 82 | output.write(buffer) 83 | loop.update(len(buffer)) 84 | 85 | model_bytes = open(download_target, "rb").read() 86 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: 87 | raise RuntimeError( 88 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." 
89 | ) 90 | 91 | return model_bytes if in_memory else download_target 92 | 93 | 94 | def available_models() -> List[str]: 95 | """Returns the names of available models""" 96 | return list(_MODELS.keys()) 97 | 98 | 99 | def load_model( 100 | name: str, 101 | device: Optional[Union[str, torch.device]] = None, 102 | download_root: str = None, 103 | in_memory: bool = False, 104 | ) -> Whisper: 105 | """ 106 | Load a Whisper ASR model 107 | 108 | Parameters 109 | ---------- 110 | name : str 111 | one of the official model names listed by `whisper.available_models()`, or 112 | path to a model checkpoint containing the model dimensions and the model state_dict. 113 | device : Union[str, torch.device] 114 | the PyTorch device to put the model into 115 | download_root: str 116 | path to download the model files; by default, it uses "~/.cache/whisper" 117 | in_memory: bool 118 | whether to preload the model weights into host memory 119 | 120 | Returns 121 | ------- 122 | model : Whisper 123 | The Whisper ASR model instance 124 | """ 125 | 126 | if device is None: 127 | device = "cuda" if torch.cuda.is_available() else "cpu" 128 | if download_root is None: 129 | default = os.path.join(os.path.expanduser("~"), ".cache") 130 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper") 131 | 132 | if name in _MODELS: 133 | checkpoint_file = _download(_MODELS[name], download_root, in_memory) 134 | alignment_heads = _ALIGNMENT_HEADS[name] 135 | elif os.path.isfile(name): 136 | checkpoint_file = open(name, "rb").read() if in_memory else name 137 | alignment_heads = None 138 | else: 139 | raise RuntimeError( 140 | f"Model {name} not found; available models = {available_models()}" 141 | ) 142 | 143 | with ( 144 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") 145 | ) as fp: 146 | checkpoint = torch.load(fp, map_location=device) 147 | del checkpoint_file 148 | 149 | dims = ModelDimensions(**checkpoint["dims"]) 150 | model = Whisper(dims) 151 | model.load_state_dict(checkpoint["model_state_dict"]) 152 | 153 | if alignment_heads is not None: 154 | model.set_alignment_heads(alignment_heads) 155 | 156 | return model.to(device) 157 | -------------------------------------------------------------------------------- /src/ui/ctk_tooltip.py: -------------------------------------------------------------------------------- 1 | """ 2 | CTkToolTip Widget 3 | version: 0.8 4 | """ 5 | 6 | import sys 7 | import time 8 | from tkinter import Toplevel, Frame 9 | 10 | import customtkinter 11 | 12 | 13 | class CTkToolTip(Toplevel): 14 | """ 15 | Creates a ToolTip (pop-up) widget for customtkinter. 
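    A minimal usage sketch (editor's illustration; the button and message are hypothetical):

        import customtkinter
        from src.ui.ctk_tooltip import CTkToolTip

        app = customtkinter.CTk()
        button = customtkinter.CTkButton(app, text="Transcribe")
        button.pack(padx=20, pady=20)
        CTkToolTip(button, message="Start transcribing the selected file")
        app.mainloop()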
16 | """ 17 | 18 | def __init__( 19 | self, 20 | widget: any = None, 21 | message: str = None, 22 | delay: float = 0.2, 23 | follow: bool = True, 24 | x_offset: int = +20, 25 | y_offset: int = +10, 26 | bg_color: str = None, 27 | corner_radius: int = 10, 28 | border_width: int = 0, 29 | border_color: str = None, 30 | alpha: float = 0.95, 31 | padding: tuple = (10, 2), 32 | **message_kwargs): 33 | 34 | super().__init__() 35 | 36 | self.widget = widget 37 | 38 | self.withdraw() 39 | 40 | # Disable ToolTip's title bar 41 | self.overrideredirect(True) 42 | 43 | if sys.platform.startswith("win"): 44 | self.transparent_color = self.widget._apply_appearance_mode( 45 | customtkinter.ThemeManager.theme["CTkToplevel"]["fg_color"]) 46 | self.attributes("-transparentcolor", self.transparent_color) 47 | self.transient() 48 | elif sys.platform.startswith("darwin"): 49 | self.transparent_color = 'systemTransparent' 50 | self.attributes("-transparent", True) 51 | self.transient(self.master) 52 | else: 53 | self.transparent_color = '#000001' 54 | corner_radius = 0 55 | self.transient() 56 | 57 | self.resizable(width=True, height=True) 58 | 59 | # Make the background transparent 60 | self.config(background=self.transparent_color) 61 | 62 | # StringVar instance for msg string 63 | self.messageVar = customtkinter.StringVar() 64 | self.message = message 65 | self.messageVar.set(self.message) 66 | 67 | self.delay = delay 68 | self.follow = follow 69 | self.x_offset = x_offset 70 | self.y_offset = y_offset 71 | self.corner_radius = corner_radius 72 | self.alpha = alpha 73 | self.border_width = border_width 74 | self.padding = padding 75 | self.bg_color = customtkinter.ThemeManager.theme["CTkFrame"]["fg_color"] if bg_color is None else bg_color 76 | self.border_color = border_color 77 | self.disable = False 78 | 79 | # visibility status of the ToolTip inside|outside|visible 80 | self.status = "outside" 81 | self.last_moved = 0 82 | self.attributes('-alpha', self.alpha) 83 | 84 | if sys.platform.startswith("win"): 85 | if self.widget._apply_appearance_mode(self.bg_color) == self.transparent_color: 86 | self.transparent_color = "#000001" 87 | self.config(background=self.transparent_color) 88 | self.attributes("-transparentcolor", self.transparent_color) 89 | 90 | # Add the message widget inside the tooltip 91 | self.transparent_frame = Frame(self, bg=self.transparent_color) 92 | self.transparent_frame.pack(padx=0, pady=0, fill="both", expand=True) 93 | 94 | self.frame = customtkinter.CTkFrame(self.transparent_frame, bg_color=self.transparent_color, 95 | corner_radius=self.corner_radius, 96 | border_width=self.border_width, fg_color=self.bg_color, 97 | border_color=self.border_color) 98 | self.frame.pack(padx=0, pady=0, fill="both", expand=True) 99 | 100 | self.message_label = customtkinter.CTkLabel(self.frame, textvariable=self.messageVar, **message_kwargs) 101 | self.message_label.pack(fill="both", padx=self.padding[0] + self.border_width, 102 | pady=self.padding[1] + self.border_width, expand=True) 103 | 104 | if self.widget.winfo_name() != "tk": 105 | if self.frame.cget("fg_color") == self.widget.cget("bg_color"): 106 | if not bg_color: 107 | self._top_fg_color = self.frame._apply_appearance_mode( 108 | customtkinter.ThemeManager.theme["CTkFrame"]["top_fg_color"]) 109 | if self._top_fg_color != self.transparent_color: 110 | self.frame.configure(fg_color=self._top_fg_color) 111 | 112 | # Add bindings to the widget without overriding the existing ones 113 | self.widget.bind("", self.on_enter, add="+") 114 | 
self.widget.bind("", self.on_leave, add="+") 115 | self.widget.bind("", self.on_enter, add="+") 116 | self.widget.bind("", self.on_enter, add="+") 117 | self.widget.bind("", lambda _: self.hide(), add="+") 118 | 119 | def show(self) -> None: 120 | """ 121 | Enable the widget. 122 | """ 123 | self.disable = False 124 | 125 | def on_enter(self, event) -> None: 126 | """ 127 | Processes motion within the widget including entering and moving. 128 | """ 129 | 130 | if self.disable: 131 | return 132 | self.last_moved = time.time() 133 | 134 | # Set the status as inside for the very first time 135 | if self.status == "outside": 136 | self.status = "inside" 137 | 138 | # If the follow flag is not set, motion within the widget will make the ToolTip dissapear 139 | if not self.follow: 140 | self.status = "inside" 141 | self.withdraw() 142 | 143 | # Calculate available space on the right side of the widget relative to the screen 144 | root_width = self.winfo_screenwidth() 145 | widget_x = event.x_root 146 | space_on_right = root_width - widget_x 147 | 148 | # Calculate the width of the tooltip's text based on the length of the message string 149 | text_width = self.message_label.winfo_reqwidth() 150 | 151 | # Calculate the offset based on available space and text width to avoid going off-screen on the right side 152 | offset_x = self.x_offset 153 | if space_on_right < text_width + 20: # Adjust the threshold as needed 154 | offset_x = -text_width - 20 # Negative offset when space is limited on the right side 155 | 156 | # Offsets the ToolTip using the coordinates od an event as an origin 157 | self.geometry(f"+{event.x_root + offset_x}+{event.y_root + self.y_offset}") 158 | 159 | # Time is in integer: milliseconds 160 | self.after(int(self.delay * 1000), self._show) 161 | 162 | def on_leave(self, event=None) -> None: 163 | """ 164 | Hides the ToolTip temporarily. 165 | """ 166 | 167 | if self.disable: return 168 | self.status = "outside" 169 | self.withdraw() 170 | 171 | def _show(self) -> None: 172 | """ 173 | Displays the ToolTip. 174 | """ 175 | 176 | if not self.widget.winfo_exists(): 177 | self.hide() 178 | self.destroy() 179 | 180 | if self.status == "inside" and time.time() - self.last_moved >= self.delay: 181 | self.status = "visible" 182 | self.deiconify() 183 | 184 | def hide(self) -> None: 185 | """ 186 | Disable the widget from appearing. 187 | """ 188 | if not self.winfo_exists(): 189 | return 190 | self.withdraw() 191 | self.disable = True 192 | 193 | def is_disabled(self) -> None: 194 | """ 195 | Return the window state 196 | """ 197 | return self.disable 198 | 199 | def get(self) -> None: 200 | """ 201 | Returns the text on the tooltip. 202 | """ 203 | return self.messageVar.get() 204 | 205 | def configure(self, message: str = None, delay: float = None, bg_color: str = None, **kwargs): 206 | """ 207 | Set new message or configure the label parameters. 
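        For example (editor's illustration; `tooltip` is an existing CTkToolTip instance, and a
        message should always be passed since it replaces the current text):

            tooltip.configure(message="Transcription finished", delay=0.5)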
208 | """ 209 | if delay: self.delay = delay 210 | if bg_color: self.frame.configure(fg_color=bg_color) 211 | 212 | self.messageVar.set(message) 213 | self.message_label.configure(**kwargs) 214 | -------------------------------------------------------------------------------- /src/whisper/model.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import gzip 3 | from dataclasses import dataclass 4 | from typing import Dict, Iterable, Optional 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import Tensor, nn 10 | 11 | from .decoding import decode as decode_function 12 | from .decoding import detect_language as detect_language_function 13 | from .transcribe import transcribe as transcribe_function 14 | 15 | 16 | @dataclass 17 | class ModelDimensions: 18 | n_mels: int 19 | n_audio_ctx: int 20 | n_audio_state: int 21 | n_audio_head: int 22 | n_audio_layer: int 23 | n_vocab: int 24 | n_text_ctx: int 25 | n_text_state: int 26 | n_text_head: int 27 | n_text_layer: int 28 | 29 | 30 | class LayerNorm(nn.LayerNorm): 31 | def forward(self, x: Tensor) -> Tensor: 32 | return super().forward(x.float()).type(x.dtype) 33 | 34 | 35 | class Linear(nn.Linear): 36 | def forward(self, x: Tensor) -> Tensor: 37 | return F.linear( 38 | x, 39 | self.weight.to(x.dtype), 40 | None if self.bias is None else self.bias.to(x.dtype), 41 | ) 42 | 43 | 44 | class Conv1d(nn.Conv1d): 45 | def _conv_forward( 46 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor] 47 | ) -> Tensor: 48 | return super()._conv_forward( 49 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype) 50 | ) 51 | 52 | 53 | def sinusoids(length, channels, max_timescale=10000): 54 | """Returns sinusoids for positional embedding""" 55 | assert channels % 2 == 0 56 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 57 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 58 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 59 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 60 | 61 | 62 | class MultiHeadAttention(nn.Module): 63 | def __init__(self, n_state: int, n_head: int): 64 | super().__init__() 65 | self.n_head = n_head 66 | self.query = Linear(n_state, n_state) 67 | self.key = Linear(n_state, n_state, bias=False) 68 | self.value = Linear(n_state, n_state) 69 | self.out = Linear(n_state, n_state) 70 | 71 | def forward( 72 | self, 73 | x: Tensor, 74 | xa: Optional[Tensor] = None, 75 | mask: Optional[Tensor] = None, 76 | kv_cache: Optional[dict] = None, 77 | ): 78 | q = self.query(x) 79 | 80 | if kv_cache is None or xa is None or self.key not in kv_cache: 81 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors; 82 | # otherwise, perform key/value projections for self- or cross-attention as usual. 83 | k = self.key(x if xa is None else xa) 84 | v = self.value(x if xa is None else xa) 85 | else: 86 | # for cross-attention, calculate keys and values once and reuse in subsequent calls. 
87 | k = kv_cache[self.key] 88 | v = kv_cache[self.value] 89 | 90 | wv, qk = self.qkv_attention(q, k, v, mask) 91 | return self.out(wv), qk 92 | 93 | def qkv_attention( 94 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None 95 | ): 96 | n_batch, n_ctx, n_state = q.shape 97 | scale = (n_state // self.n_head) ** -0.25 98 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale 99 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale 100 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) 101 | 102 | qk = q @ k 103 | if mask is not None: 104 | qk = qk + mask[:n_ctx, :n_ctx] 105 | qk = qk.float() 106 | 107 | w = F.softmax(qk, dim=-1).to(q.dtype) 108 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach() 109 | 110 | 111 | class ResidualAttentionBlock(nn.Module): 112 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False): 113 | super().__init__() 114 | 115 | self.attn = MultiHeadAttention(n_state, n_head) 116 | self.attn_ln = LayerNorm(n_state) 117 | 118 | self.cross_attn = ( 119 | MultiHeadAttention(n_state, n_head) if cross_attention else None 120 | ) 121 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None 122 | 123 | n_mlp = n_state * 4 124 | self.mlp = nn.Sequential( 125 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state) 126 | ) 127 | self.mlp_ln = LayerNorm(n_state) 128 | 129 | def forward( 130 | self, 131 | x: Tensor, 132 | xa: Optional[Tensor] = None, 133 | mask: Optional[Tensor] = None, 134 | kv_cache: Optional[dict] = None, 135 | ): 136 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0] 137 | if self.cross_attn: 138 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0] 139 | x = x + self.mlp(self.mlp_ln(x)) 140 | return x 141 | 142 | 143 | class AudioEncoder(nn.Module): 144 | def __init__( 145 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 146 | ): 147 | super().__init__() 148 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1) 149 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1) 150 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state)) 151 | 152 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 153 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)] 154 | ) 155 | self.ln_post = LayerNorm(n_state) 156 | 157 | def forward(self, x: Tensor): 158 | """ 159 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx) 160 | the mel spectrogram of the audio 161 | """ 162 | x = F.gelu(self.conv1(x)) 163 | x = F.gelu(self.conv2(x)) 164 | x = x.permute(0, 2, 1) 165 | 166 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape" 167 | x = (x + self.positional_embedding).to(x.dtype) 168 | 169 | for block in self.blocks: 170 | x = block(x) 171 | 172 | x = self.ln_post(x) 173 | return x 174 | 175 | 176 | class TextDecoder(nn.Module): 177 | def __init__( 178 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int 179 | ): 180 | super().__init__() 181 | 182 | self.token_embedding = nn.Embedding(n_vocab, n_state) 183 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state)) 184 | 185 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList( 186 | [ 187 | ResidualAttentionBlock(n_state, n_head, cross_attention=True) 188 | for _ in range(n_layer) 189 | ] 190 | ) 191 | self.ln = LayerNorm(n_state) 192 | 193 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) 
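        # Editor's note: this is an additive causal mask; -inf above the diagonal blocks attention
        # to future positions, while zeros elsewhere leave past positions untouched. For n_ctx = 3:
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # qkv_attention adds mask[:n_ctx, :n_ctx] to the attention logits before the softmax.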
194 | self.register_buffer("mask", mask, persistent=False) 195 | 196 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): 197 | """ 198 | x : torch.LongTensor, shape = (batch_size, <= n_ctx) 199 | the text tokens 200 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state) 201 | the encoded audio features to be attended on 202 | """ 203 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0 204 | x = ( 205 | self.token_embedding(x) 206 | + self.positional_embedding[offset : offset + x.shape[-1]] 207 | ) 208 | x = x.to(xa.dtype) 209 | 210 | for block in self.blocks: 211 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache) 212 | 213 | x = self.ln(x) 214 | logits = ( 215 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1) 216 | ).float() 217 | 218 | return logits 219 | 220 | 221 | class Whisper(nn.Module): 222 | def __init__(self, dims: ModelDimensions): 223 | super().__init__() 224 | self.dims = dims 225 | self.encoder = AudioEncoder( 226 | self.dims.n_mels, 227 | self.dims.n_audio_ctx, 228 | self.dims.n_audio_state, 229 | self.dims.n_audio_head, 230 | self.dims.n_audio_layer, 231 | ) 232 | self.decoder = TextDecoder( 233 | self.dims.n_vocab, 234 | self.dims.n_text_ctx, 235 | self.dims.n_text_state, 236 | self.dims.n_text_head, 237 | self.dims.n_text_layer, 238 | ) 239 | # use the last half among the decoder layers for time alignment by default; 240 | # to use a specific set of heads, see `set_alignment_heads()` below. 241 | all_heads = torch.zeros( 242 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool 243 | ) 244 | all_heads[self.dims.n_text_layer // 2 :] = True 245 | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False) 246 | 247 | def set_alignment_heads(self, dump: bytes): 248 | array = np.frombuffer( 249 | gzip.decompress(base64.b85decode(dump)), dtype=bool 250 | ).copy() 251 | mask = torch.from_numpy(array).reshape( 252 | self.dims.n_text_layer, self.dims.n_text_head 253 | ) 254 | self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False) 255 | 256 | def embed_audio(self, mel: torch.Tensor): 257 | return self.encoder(mel) 258 | 259 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor): 260 | return self.decoder(tokens, audio_features) 261 | 262 | def forward( 263 | self, mel: torch.Tensor, tokens: torch.Tensor 264 | ) -> Dict[str, torch.Tensor]: 265 | return self.decoder(tokens, self.encoder(mel)) 266 | 267 | @property 268 | def device(self): 269 | return next(self.parameters()).device 270 | 271 | @property 272 | def is_multilingual(self): 273 | return self.dims.n_vocab >= 51865 274 | 275 | @property 276 | def num_languages(self): 277 | return self.dims.n_vocab - 51765 - int(self.is_multilingual) 278 | 279 | def install_kv_cache_hooks(self, cache: Optional[dict] = None): 280 | """ 281 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value 282 | tensors calculated for the previous positions. This method returns a dictionary that stores 283 | all caches, and the necessary hooks for the key and value projection modules that save the 284 | intermediate tensors to be reused during later calculations. 
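        A minimal usage sketch (editor's illustration, mirroring how the decoding loop consumes the cache):

            cache, hooks = model.install_kv_cache_hooks()
            logits = model.decoder(tokens, audio_features, kv_cache=cache)      # fills the cache
            logits = model.decoder(next_token, audio_features, kv_cache=cache)  # reuses it
            for hook in hooks:
                hook.remove()  # uninstall the hooks when decoding is finished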
285 | 286 | Returns 287 | ------- 288 | cache : Dict[nn.Module, torch.Tensor] 289 | A dictionary object mapping the key/value projection modules to their cached tensors 290 | hooks : List[RemovableHandle] 291 | A list of PyTorch RemovableHandle objects; call `.remove()` on them to uninstall the hooks 292 | """ 293 | cache = {**cache} if cache is not None else {} 294 | hooks = [] 295 | 296 | def save_to_cache(module, _, output): 297 | if module not in cache or output.shape[1] > self.dims.n_text_ctx: 298 | # save as-is, for the first token or cross attention 299 | cache[module] = output 300 | else: 301 | cache[module] = torch.cat([cache[module], output], dim=1).detach() 302 | return cache[module] 303 | 304 | def install_hooks(layer: nn.Module): 305 | if isinstance(layer, MultiHeadAttention): 306 | hooks.append(layer.key.register_forward_hook(save_to_cache)) 307 | hooks.append(layer.value.register_forward_hook(save_to_cache)) 308 | 309 | self.decoder.apply(install_hooks) 310 | return cache, hooks 311 | 312 | detect_language = detect_language_function 313 | transcribe = transcribe_function 314 | decode = decode_function 315 | -------------------------------------------------------------------------------- /src/whisper/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | import zlib 6 | from typing import Callable, Optional, TextIO 7 | 8 | system_encoding = sys.getdefaultencoding() 9 | 10 | if system_encoding != "utf-8": 11 | 12 | def make_safe(string): 13 | # replaces any character not representable using the system default encoding with a '?', 14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729). 15 | return string.encode(system_encoding, errors="replace").decode(system_encoding) 16 | 17 | else: 18 | 19 | def make_safe(string): 20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding 21 | return string 22 | 23 | 24 | def exact_div(x, y): 25 | assert x % y == 0 26 | return x // y 27 | 28 | 29 | def str2bool(string): 30 | str2val = {"True": True, "False": False} 31 | if string in str2val: 32 | return str2val[string] 33 | else: 34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}") 35 | 36 | 37 | def optional_int(string): 38 | return None if string == "None" else int(string) 39 | 40 | 41 | def optional_float(string): 42 | return None if string == "None" else float(string) 43 | 44 | 45 | def compression_ratio(text) -> float: 46 | text_bytes = text.encode("utf-8") 47 | return len(text_bytes) / len(zlib.compress(text_bytes)) 48 | 49 | 50 | def format_timestamp( 51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
52 | ): 53 | assert seconds >= 0, "non-negative timestamp expected" 54 | milliseconds = round(seconds * 1000.0) 55 | 56 | hours = milliseconds // 3_600_000 57 | milliseconds -= hours * 3_600_000 58 | 59 | minutes = milliseconds // 60_000 60 | milliseconds -= minutes * 60_000 61 | 62 | seconds = milliseconds // 1_000 63 | milliseconds -= seconds * 1_000 64 | 65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" 66 | return ( 67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" 68 | ) 69 | 70 | 71 | class ResultWriter: 72 | extension: str 73 | 74 | def __init__(self, output_dir: str): 75 | self.output_dir = output_dir 76 | 77 | def __call__( 78 | self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs 79 | ): 80 | audio_basename = os.path.basename(audio_path) 81 | audio_basename = os.path.splitext(audio_basename)[0] 82 | output_path = os.path.join( 83 | self.output_dir, audio_basename + "." + self.extension 84 | ) 85 | 86 | with open(output_path, "w", encoding="utf-8") as f: 87 | self.write_result(result, file=f, options=options, **kwargs) 88 | 89 | def write_result( 90 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 91 | ): 92 | raise NotImplementedError 93 | 94 | 95 | class WriteTXT(ResultWriter): 96 | extension: str = "txt" 97 | 98 | def write_result( 99 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 100 | ): 101 | for segment in result["segments"]: 102 | print(segment["text"].strip(), file=file, flush=True) 103 | 104 | 105 | class SubtitlesWriter(ResultWriter): 106 | always_include_hours: bool 107 | decimal_marker: str 108 | 109 | def iterate_result( 110 | self, 111 | result: dict, 112 | options: Optional[dict] = None, 113 | *, 114 | max_line_width: Optional[int] = None, 115 | max_line_count: Optional[int] = None, 116 | highlight_words: bool = False, 117 | max_words_per_line: Optional[int] = None, 118 | ): 119 | options = options or {} 120 | max_line_width = max_line_width or options.get("max_line_width") 121 | max_line_count = max_line_count or options.get("max_line_count") 122 | highlight_words = highlight_words or options.get("highlight_words", False) 123 | max_words_per_line = max_words_per_line or options.get("max_words_per_line") 124 | preserve_segments = max_line_count is None or max_line_width is None 125 | max_line_width = max_line_width or 1000 126 | max_words_per_line = max_words_per_line or 1000 127 | 128 | def iterate_subtitles(): 129 | line_len = 0 130 | line_count = 1 131 | # the next subtitle to yield (a list of word timings with whitespace) 132 | subtitle: list[dict] = [] 133 | last = result["segments"][0]["words"][0]["start"] 134 | for segment in result["segments"]: 135 | chunk_index = 0 136 | words_count = max_words_per_line 137 | while chunk_index < len(segment["words"]): 138 | remaining_words = len(segment["words"]) - chunk_index 139 | if max_words_per_line > len(segment["words"]) - chunk_index: 140 | words_count = remaining_words 141 | for i, original_timing in enumerate( 142 | segment["words"][chunk_index : chunk_index + words_count] 143 | ): 144 | timing = original_timing.copy() 145 | long_pause = ( 146 | not preserve_segments and timing["start"] - last > 3.0 147 | ) 148 | has_room = line_len + len(timing["word"]) <= max_line_width 149 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments 150 | if ( 151 | line_len > 0 152 | and has_room 153 | and not long_pause 154 | and not seg_break 155 | ): 156 | # line 
continuation 157 | line_len += len(timing["word"]) 158 | else: 159 | # new line 160 | timing["word"] = timing["word"].strip() 161 | if ( 162 | len(subtitle) > 0 163 | and max_line_count is not None 164 | and (long_pause or line_count >= max_line_count) 165 | or seg_break 166 | ): 167 | # subtitle break 168 | yield subtitle 169 | subtitle = [] 170 | line_count = 1 171 | elif line_len > 0: 172 | # line break 173 | line_count += 1 174 | timing["word"] = "\n" + timing["word"] 175 | line_len = len(timing["word"].strip()) 176 | subtitle.append(timing) 177 | last = timing["start"] 178 | chunk_index += max_words_per_line 179 | if len(subtitle) > 0: 180 | yield subtitle 181 | 182 | if len(result["segments"]) > 0 and "words" in result["segments"][0]: 183 | for subtitle in iterate_subtitles(): 184 | subtitle_start = self.format_timestamp(subtitle[0]["start"]) 185 | subtitle_end = self.format_timestamp(subtitle[-1]["end"]) 186 | subtitle_text = "".join([word["word"] for word in subtitle]) 187 | if highlight_words: 188 | last = subtitle_start 189 | all_words = [timing["word"] for timing in subtitle] 190 | for i, this_word in enumerate(subtitle): 191 | start = self.format_timestamp(this_word["start"]) 192 | end = self.format_timestamp(this_word["end"]) 193 | if last != start: 194 | yield last, start, subtitle_text 195 | 196 | yield start, end, "".join( 197 | [ 198 | re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) 199 | if j == i 200 | else word 201 | for j, word in enumerate(all_words) 202 | ] 203 | ) 204 | last = end 205 | else: 206 | yield subtitle_start, subtitle_end, subtitle_text 207 | else: 208 | for segment in result["segments"]: 209 | segment_start = self.format_timestamp(segment["start"]) 210 | segment_end = self.format_timestamp(segment["end"]) 211 | segment_text = segment["text"].strip().replace("-->", "->") 212 | yield segment_start, segment_end, segment_text 213 | 214 | def format_timestamp(self, seconds: float): 215 | return format_timestamp( 216 | seconds=seconds, 217 | always_include_hours=self.always_include_hours, 218 | decimal_marker=self.decimal_marker, 219 | ) 220 | 221 | 222 | class WriteVTT(SubtitlesWriter): 223 | extension: str = "vtt" 224 | always_include_hours: bool = False 225 | decimal_marker: str = "." 226 | 227 | def write_result( 228 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 229 | ): 230 | print("WEBVTT\n", file=file) 231 | for start, end, text in self.iterate_result(result, options, **kwargs): 232 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True) 233 | 234 | 235 | class WriteSRT(SubtitlesWriter): 236 | extension: str = "srt" 237 | always_include_hours: bool = True 238 | decimal_marker: str = "," 239 | 240 | def write_result( 241 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 242 | ): 243 | for i, (start, end, text) in enumerate( 244 | self.iterate_result(result, options, **kwargs), start=1 245 | ): 246 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) 247 | 248 | 249 | class WriteTSV(ResultWriter): 250 | """ 251 | Write a transcript to a file in TSV (tab-separated values) format containing lines like: 252 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text> 253 | 254 | Using integer milliseconds as start and end times means there's no chance of interference from 255 | an environment setting a language encoding that causes the decimal in a floating point number 256 | to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
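    For example, the header row and a segment covering the first 2.5 seconds would be written as
    (editor's illustration):

        start\tend\ttext
        0\t2500\tHello world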
257 | """ 258 | 259 | extension: str = "tsv" 260 | 261 | def write_result( 262 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 263 | ): 264 | print("start", "end", "text", sep="\t", file=file) 265 | for segment in result["segments"]: 266 | print(round(1000 * segment["start"]), file=file, end="\t") 267 | print(round(1000 * segment["end"]), file=file, end="\t") 268 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True) 269 | 270 | 271 | class WriteJSON(ResultWriter): 272 | extension: str = "json" 273 | 274 | def write_result( 275 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 276 | ): 277 | json.dump(result, file) 278 | 279 | 280 | def get_writer( 281 | output_format: str, output_dir: str 282 | ) -> Callable[[dict, TextIO, dict], None]: 283 | writers = { 284 | "txt": WriteTXT, 285 | "vtt": WriteVTT, 286 | "srt": WriteSRT, 287 | "tsv": WriteTSV, 288 | "json": WriteJSON, 289 | } 290 | 291 | if output_format == "all": 292 | all_writers = [writer(output_dir) for writer in writers.values()] 293 | 294 | def write_all( 295 | result: dict, file: TextIO, options: Optional[dict] = None, **kwargs 296 | ): 297 | for writer in all_writers: 298 | writer(result, file, options, **kwargs) 299 | 300 | return write_all 301 | 302 | return writers[output_format](output_dir) 303 | -------------------------------------------------------------------------------- /src/ui/add_subtitles.py: -------------------------------------------------------------------------------- 1 | import os 2 | import queue 3 | import threading 4 | import time 5 | import subprocess 6 | 7 | import customtkinter as ctk 8 | from customtkinter import filedialog as fd 9 | from mutagen import File, MutagenError 10 | from pydub import AudioSegment, exceptions 11 | from tkinterdnd2 import TkinterDnD, DND_ALL 12 | from ..whisper.utils import get_writer 13 | 14 | from .ctkAlert import CTkAlert 15 | from .ctkLoader import CTkLoader 16 | from .icons import icons 17 | from .style import FONTS 18 | from ..logic import Transcriber 19 | 20 | 21 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper): 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.TkdndVersion = TkinterDnD._require(self) 25 | 26 | 27 | class AddSubtitlesUI(CTk): 28 | def __init__(self, parent): 29 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0) 30 | self.grid_propagate(False) 31 | self.grid_columnconfigure(0, weight=1) 32 | 33 | title = ctk.CTkLabel(self, text=" Add Subtitles", font=FONTS["title"], image=icons["subtitle"], 34 | compound="left") 35 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w") 36 | 37 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30, 38 | height=30, command=self.hide_transcribe_ui) 39 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e") 40 | 41 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent") 42 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2) 43 | self.main_frame.grid_columnconfigure(0, weight=1) 44 | 45 | self.master = parent 46 | self.drag_drop = None 47 | self.audio_path = None 48 | self.cancel_signal = False 49 | self.loader = None 50 | self.option_menu = None 51 | self.process = None 52 | self.queue = queue.Queue() 53 | 54 | self.default_widget() 55 | 56 | self.grid(row=0, column=0, sticky="nsew") 57 | 58 | def default_widget(self): 59 | label = 
ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"]) 60 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w") 61 | 62 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250, 63 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent", 64 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"), 65 | font=FONTS["normal"]) 66 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew") 67 | 68 | self.drag_drop.drop_target_register(DND_ALL) 69 | self.drag_drop.dnd_bind('<<Drop>>', self.drop) 70 | 71 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14)) 72 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew") 73 | 74 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40, 75 | command=self.select_file_callback, font=FONTS["normal"]) 76 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew") 77 | 78 | label_1 = ctk.CTkLabel(self.main_frame, text="Supported Formats: MP4, MOV, WMV, AVI", 79 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50, 80 | font=FONTS["small"]) 81 | label_1.grid(row=4, column=0, padx=0, pady=20, sticky="sew") 82 | 83 | def task_widget(self): 84 | file_name, duration, file_size = self.get_audio_info(self.audio_path) 85 | 86 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"]) 87 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w") 88 | 89 | frame = ctk.CTkFrame(self.main_frame) 90 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew") 91 | frame.grid_columnconfigure(0, weight=1) 92 | 93 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"]) 94 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w") 95 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"]) 96 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e") 97 | 98 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"]) 99 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w") 100 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"]) 101 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e") 102 | 103 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"]) 104 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w") 105 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"]) 106 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e") 107 | 108 | start_btn = ctk.CTkButton(self.main_frame, text="Start Transcribing", height=40, command=self.start_callback, 109 | font=FONTS["normal"]) 110 | start_btn.grid(row=2, column=0, padx=200, pady=20, sticky="nsew") 111 | 112 | def result_widget(self): 113 | result = self.queue.get() 114 | text = str(result["text"]).strip() 115 | 116 | result_label = ctk.CTkLabel(self.main_frame, text="Transcribed Text", font=FONTS["subtitle_bold"]) 117 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w") 118 | 119 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"]) 120 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2) 121 | textbox.insert("0.0", text=text) 122 | 123 | download_label = ctk.CTkLabel(self.main_frame, text="Download Video with Subtitles", font=FONTS["subtitle_bold"]) 124 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), 
sticky="w") 125 | 126 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.add_subtitle(result), 127 | font=FONTS["normal"], height=35) 128 | download_btn.grid(row=3, column=0, padx=10, pady=20, sticky="nsw") 129 | 130 | def add_subtitle(self, result): 131 | file_name = os.path.splitext(os.path.basename(self.audio_path))[0] 132 | 133 | output_video = fd.asksaveasfilename( 134 | parent=self, 135 | initialfile=f"{file_name}-subtitle", 136 | title="Export video with subtitle", 137 | defaultextension=".mp4", 138 | filetypes=[("Video files", "*.mp4 *.mov *.wmv *.avi")] 139 | ) 140 | if output_video: 141 | self.close_btn.configure(state="disabled") 142 | widgets = self.main_frame.winfo_children() 143 | for widget in widgets: 144 | widget.destroy() 145 | 146 | thread = threading.Thread(target=self.subtitle_handler, args=(output_video, result)) 147 | thread.start() 148 | 149 | def subtitle_handler(self, output_video, result): 150 | self.loader = CTkLoader(parent=self.master, title="Adding Subtitles", msg="Please wait...", 151 | cancel_func=self.kill_process) 152 | file_name = os.path.splitext(os.path.basename(self.audio_path))[0] 153 | input_video = self.audio_path 154 | writer = get_writer("srt", ".") 155 | subtitle_file_path = f"{file_name}.srt" 156 | writer(result, input_video, {"highlight_words": True, "max_line_count": 50, "max_line_width": 3}) 157 | 158 | ffmpeg_command = [ 159 | "ffmpeg", 160 | "-y", 161 | "-i", input_video, 162 | "-vf", f"subtitles={subtitle_file_path}", 163 | output_video 164 | ] 165 | 166 | self.process = subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, universal_newlines=True, shell=True) 167 | 168 | self.process.wait() 169 | 170 | self.loader.hide_loader() 171 | 172 | if self.process.returncode == 0: 173 | CTkAlert(parent=self.master, status="success", title="Success", 174 | msg="Subtitle added successfully.") 175 | self.close_btn.configure(state="normal") 176 | self.after(1000, self.default_widget) 177 | else: 178 | CTkAlert(parent=self.master, status="error", title="Error", 179 | msg="Error adding subtitles to the video.") 180 | self.close_btn.configure(state="normal") 181 | self.result_widget() 182 | 183 | os.remove(subtitle_file_path) 184 | 185 | def kill_process(self): 186 | pass 187 | 188 | def start_callback(self): 189 | self.close_btn.configure(state="disabled") 190 | widgets = self.main_frame.winfo_children() 191 | for widget in widgets: 192 | widget.destroy() 193 | 194 | self.loader = CTkLoader(parent=self.master, title="Transcribing...", msg="Please wait...", 195 | cancel_func=self.set_signal) 196 | thread = threading.Thread(target=self.start_transcribing, args=(self.audio_path, self.check_signal)) 197 | thread.start() 198 | 199 | def start_transcribing(self, audio_path, check_signal): 200 | transcriber = Transcriber(audio=audio_path) 201 | result = transcriber.audio_recognition(cancel_func=check_signal) 202 | self.loader.hide_loader() 203 | if result: 204 | self.queue.put(result) 205 | self.close_btn.configure(state="normal") 206 | self.result_widget() 207 | 208 | def set_signal(self): 209 | self.cancel_signal = True 210 | 211 | def check_signal(self): 212 | original_value = self.cancel_signal 213 | 214 | if self.cancel_signal: 215 | self.cancel_signal = False 216 | self.close_btn.configure(state="normal") 217 | self.after(1000, self.default_widget) 218 | 219 | return original_value 220 | 221 | def select_file_callback(self): 222 | file_path = fd.askopenfilename( 223 | filetypes=[("Video files", "*.mp4 *.mov *.wmv 
*.avi")]) 224 | if file_path: 225 | audio_path = os.path.abspath(file_path) 226 | if self.is_streamable_audio(audio_path): 227 | self.audio_path = audio_path 228 | widgets = self.main_frame.winfo_children() 229 | 230 | for widget in widgets: 231 | widget.destroy() 232 | self.after(1000, self.task_widget) 233 | else: 234 | CTkAlert(parent=self.master, status="error", title="Error", 235 | msg="The chosen audio file is not valid or streamable.") 236 | 237 | @staticmethod 238 | def is_streamable_audio(audio_path): 239 | if not os.path.isfile(audio_path): 240 | return False 241 | 242 | try: 243 | audio = AudioSegment.from_file(audio_path) 244 | audio_info = File(audio_path) 245 | return len(audio) > 0 and audio_info.info.length > 0 246 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError): 247 | return False 248 | 249 | @staticmethod 250 | def get_audio_info(file_path): 251 | try: 252 | if not os.path.isfile(file_path): 253 | return None, None 254 | 255 | file_name = os.path.basename(file_path) 256 | 257 | audio_info = File(file_path) 258 | if audio_info is None: 259 | return None, None 260 | 261 | duration_seconds = audio_info.info.length 262 | 263 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds)) 264 | 265 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024) 266 | 267 | return file_name, duration_formatted, file_size_mb 268 | except MutagenError: 269 | return None, None 270 | 271 | def drop(self, event): 272 | dropped_file = event.data.replace("{", "").replace("}", "") 273 | audio_path = os.path.abspath(dropped_file) 274 | if self.is_streamable_audio(audio_path): 275 | self.audio_path = audio_path 276 | widgets = self.main_frame.winfo_children() 277 | 278 | for widget in widgets: 279 | widget.destroy() 280 | 281 | self.after(1000, self.task_widget) 282 | else: 283 | CTkAlert(parent=self.master, status="error", title="Error", 284 | msg="The chosen audio file is not valid or streamable.") 285 | 286 | def hide_transcribe_ui(self): 287 | self.destroy() 288 | -------------------------------------------------------------------------------- /src/ui/translate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import queue 3 | import threading 4 | import time 5 | 6 | import customtkinter as ctk 7 | from customtkinter import filedialog as fd 8 | from mutagen import File, MutagenError 9 | from pydub import AudioSegment, exceptions 10 | from tkinterdnd2 import TkinterDnD, DND_ALL 11 | 12 | from .ctkAlert import CTkAlert 13 | from .ctkLoader import CTkLoader 14 | from .ctkdropdown import CTkScrollableDropdownFrame 15 | from .icons import icons 16 | from .style import FONTS, DROPDOWN 17 | from ..logic import Transcriber 18 | 19 | 20 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.TkdndVersion = TkinterDnD._require(self) 24 | 25 | 26 | languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese', 'aymara', 'azerbaijani', 27 | 'bambara', 'basque', 'belarusian', 'bengali', 'bhojpuri', 'bosnian', 'bulgarian', 'catalan', 28 | 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian', 29 | 'czech', 'danish', 'dhivehi', 'dogri', 'dutch', 'english', 'esperanto', 'estonian', 'ewe', 30 | 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'guarani', 31 | 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 
'hmong', 'hungarian', 32 | 'icelandic', 'igbo', 'ilocano', 'indonesian', 'irish', 'italian', 'japanese', 'javanese', 33 | 'kannada', 'kazakh', 'khmer', 'kinyarwanda', 'konkani', 'korean', 'krio', 'kurdish (kurmanji)', 34 | 'kurdish (sorani)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lingala', 'lithuanian', 'luganda', 35 | 'luxembourgish', 'macedonian', 'maithili', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori', 36 | 'marathi', 'meiteilon (manipuri)', 'mizo', 'mongolian', 'myanmar', 'nepali', 'norwegian', 37 | 'odia (oriya)', 'oromo', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'quechua', 38 | 'romanian', 'russian', 'samoan', 'sanskrit', 'scots gaelic', 'sepedi', 'serbian', 'sesotho', 39 | 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili', 40 | 'swedish', 'tajik', 'tamil', 'tatar', 'telugu', 'thai', 'tigrinya', 'tsonga', 'turkish', 'turkmen', 41 | 'twi', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba', 42 | 'zulu'] 43 | 44 | 45 | class TranslateUI(CTk): 46 | def __init__(self, parent): 47 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0) 48 | self.grid_propagate(False) 49 | self.grid_columnconfigure(0, weight=1) 50 | 51 | title = ctk.CTkLabel(self, text=" Translate Audio", font=FONTS["title"], image=icons["translate"], 52 | compound="left") 53 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w") 54 | 55 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30, 56 | height=30, command=self.hide_transcribe_ui) 57 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e") 58 | 59 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent") 60 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2) 61 | self.main_frame.grid_columnconfigure(0, weight=1) 62 | 63 | self.master = parent 64 | self.drag_drop = None 65 | self.audio_path = None 66 | self.cancel_signal = False 67 | self.loader = None 68 | self.lang_dropdown = None 69 | self.queue = queue.Queue() 70 | 71 | self.default_widget() 72 | 73 | self.grid(row=0, column=0, sticky="nsew") 74 | 75 | def default_widget(self): 76 | label = ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"]) 77 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w") 78 | 79 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250, 80 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent", 81 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"), 82 | font=FONTS["normal"]) 83 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew") 84 | 85 | self.drag_drop.drop_target_register(DND_ALL) 86 | self.drag_drop.dnd_bind('<<Drop>>', self.drop) 87 | 88 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14)) 89 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew") 90 | 91 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40, 92 | command=self.select_file_callback, font=FONTS["normal"]) 93 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew") 94 | 95 | label_1 = ctk.CTkLabel(self.main_frame, text="Supported Formats: WAV, MP3, OGG, FLAC, MP4, MOV, WMV, AVI", 96 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50, 97 | font=FONTS["small"]) 98 | label_1.grid(row=4, column=0, padx=0, pady=20, 
sticky="sew") 99 | 100 | def task_widget(self): 101 | file_name, duration, file_size = self.get_audio_info(self.audio_path) 102 | 103 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"]) 104 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w") 105 | 106 | frame = ctk.CTkFrame(self.main_frame) 107 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2) 108 | frame.grid_columnconfigure(0, weight=1) 109 | 110 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"]) 111 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w") 112 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"]) 113 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e") 114 | 115 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"]) 116 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w") 117 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"]) 118 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e") 119 | 120 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"]) 121 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w") 122 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"]) 123 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e") 124 | 125 | translate_to = ctk.CTkLabel(self.main_frame, text="Translate To", font=FONTS["normal"]) 126 | translate_to.grid(row=2, column=0, padx=40, pady=20, sticky="w") 127 | self.lang_dropdown = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"]) 128 | self.lang_dropdown.grid(row=2, column=1, padx=40, pady=20, sticky="e") 129 | CTkScrollableDropdownFrame(self.lang_dropdown, values=languages, **DROPDOWN) 130 | self.lang_dropdown.set(languages[0]) 131 | 132 | start_btn = ctk.CTkButton(self.main_frame, text="Start Translating", height=40, command=self.start_callback, 133 | font=FONTS["normal"]) 134 | start_btn.grid(row=3, column=0, padx=200, pady=40, sticky="nsew", columnspan=2) 135 | 136 | def result_widget(self): 137 | result = self.queue.get() 138 | 139 | result_label = ctk.CTkLabel(self.main_frame, text="Translated Text", font=FONTS["subtitle_bold"]) 140 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w") 141 | 142 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"]) 143 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2) 144 | textbox.insert("0.0", text=result) 145 | 146 | download_label = ctk.CTkLabel(self.main_frame, text="Download Text", font=FONTS["subtitle_bold"]) 147 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), sticky="w") 148 | 149 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.save_text(result), 150 | font=FONTS["normal"], height=35) 151 | download_btn.grid(row=3, column=0, padx=10, pady=20, sticky="nsw") 152 | 153 | def save_text(self, result): 154 | file_name = os.path.basename(self.audio_path) 155 | sep = "." 
156 | file_name = file_name.split(sep, 1)[0] 157 | 158 | file_path = fd.asksaveasfilename( 159 | parent=self, 160 | initialfile=file_name, 161 | title="Export text", 162 | defaultextension=".txt", 163 | filetypes=[("Text Files", "*.txt")] 164 | ) 165 | 166 | if file_path: 167 | with open(file_path, 'w', encoding="utf-8") as file: 168 | file.write(result) 169 | 170 | def start_callback(self): 171 | self.close_btn.configure(state="disabled") 172 | widgets = self.main_frame.winfo_children() 173 | for widget in widgets: 174 | widget.destroy() 175 | 176 | self.loader = CTkLoader(parent=self.master, title="Transcribing...", msg="Please wait...", 177 | cancel_func=self.set_signal) 178 | thread = threading.Thread(target=self.start_transcribing, args=(self.audio_path, self.check_signal)) 179 | thread.start() 180 | 181 | def start_transcribing(self, audio_path, check_signal): 182 | language = self.lang_dropdown.get() 183 | transcriber = Transcriber(audio=audio_path) 184 | result = transcriber.translate_audio(cancel_func=check_signal, to_language=language) 185 | self.loader.hide_loader() 186 | if result: 187 | self.queue.put(result) 188 | self.close_btn.configure(state="normal") 189 | self.result_widget() 190 | 191 | def set_signal(self): 192 | self.cancel_signal = True 193 | 194 | def check_signal(self): 195 | original_value = self.cancel_signal 196 | 197 | if self.cancel_signal: 198 | self.cancel_signal = False 199 | self.close_btn.configure(state="normal") 200 | self.after(1000, self.default_widget) 201 | 202 | return original_value 203 | 204 | def select_file_callback(self): 205 | file_path = fd.askopenfilename( 206 | filetypes=[("Audio files", "*.mp3 *.wav *.ogg *.flac"), ("Video files", "*.mp4 *.mov *.wmv *.avi")]) 207 | if file_path: 208 | audio_path = os.path.abspath(file_path) 209 | if self.is_streamable_audio(audio_path): 210 | self.audio_path = audio_path 211 | widgets = self.main_frame.winfo_children() 212 | 213 | for widget in widgets: 214 | widget.destroy() 215 | self.after(1000, self.task_widget) 216 | else: 217 | CTkAlert(parent=self.master, status="error", title="Error", 218 | msg="The chosen audio file is not valid or streamable.") 219 | 220 | @staticmethod 221 | def is_streamable_audio(audio_path): 222 | if not os.path.isfile(audio_path): 223 | return False 224 | 225 | try: 226 | audio = AudioSegment.from_file(audio_path) 227 | audio_info = File(audio_path) 228 | return len(audio) > 0 and audio_info.info.length > 0 229 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError): 230 | return False 231 | 232 | @staticmethod 233 | def get_audio_info(file_path): 234 | try: 235 | if not os.path.isfile(file_path): 236 | return None, None, None 237 | 238 | file_name = os.path.basename(file_path) 239 | 240 | audio_info = File(file_path) 241 | if audio_info is None: 242 | return None, None, None 243 | 244 | duration_seconds = audio_info.info.length 245 | 246 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds)) 247 | 248 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024) 249 | 250 | return file_name, duration_formatted, file_size_mb 251 | except MutagenError: 252 | return None, None, None 253 | 254 | def drop(self, event): 255 | dropped_file = event.data.replace("{", "").replace("}", "") 256 | audio_path = os.path.abspath(dropped_file) 257 | if self.is_streamable_audio(audio_path): 258 | self.audio_path = audio_path 259 | widgets = self.main_frame.winfo_children() 260 | 261 | for widget in widgets: 262 | widget.destroy() 263 | 264 | self.after(1000, 
self.task_widget)
265 | else:
266 | CTkAlert(parent=self.master, status="error", title="Error",
267 | msg="The chosen audio file is not valid or streamable.")
268 | 
269 | def hide_transcribe_ui(self):
270 | self.destroy()
271 | 
--------------------------------------------------------------------------------
/src/ui/live_transcribe.py:
--------------------------------------------------------------------------------
1 | import random
2 | import threading
3 | import tkinter as tk
4 | from datetime import datetime, timedelta
5 | from queue import Queue
6 | from time import sleep
7 | 
8 | import customtkinter as ctk
9 | import numpy as np
10 | import speech_recognition as sr
11 | import torch
12 | from .. import whisper
13 | 
14 | from .ctk_tooltip import CTkToolTip
15 | from .ctkdropdown import CTkScrollableDropdownFrame
16 | from .icons import icons
17 | from .style import FONTS, DROPDOWN
18 | 
19 | 
20 | class TkAudioVisualizer(tk.Frame):
21 | def __init__(self,
22 | master: any,
23 | gradient=None,
24 | bar_color: str = "white",
25 | bar_width: int = 7,
26 | **kwargs):
27 | tk.Frame.__init__(self, master)
28 | if gradient is None:
29 | gradient = ["cyan", "blue"]
30 | self.viz = DrawBars(self, gradient[0], gradient[1], bar_width, bar_color, relief="sunken", **kwargs)
31 | self.viz.pack(fill="both", expand=True)
32 | 
33 | def start(self):
34 | """ start the visualizer """
35 | if not self.viz._running:
36 | self.viz._running = True
37 | self.viz.update()
38 | 
39 | def stop(self):
40 | """ stop the visualizer """
41 | self.viz._running = False
42 | 
43 | 
44 | class DrawBars(tk.Canvas):
45 | '''A gradient frame which uses a canvas to draw the background'''
46 | 
47 | def __init__(self, parent, color1, color2, bar_width, bar_color, **kwargs):
48 | tk.Canvas.__init__(self, parent, bg=bar_color, bd=0, highlightthickness=0, **kwargs)
49 | self._color1 = color1
50 | self._color2 = color2
51 | self._bar_width = bar_width
52 | self._running = False
53 | self.after(100, lambda: self._draw_gradient())
54 | self.bind("<Configure>", lambda e: self._draw_gradient() if not self._running else None)
55 | 
56 | def _draw_gradient(self, event=None):
57 | '''Draw the gradient spectrum '''
58 | self.delete("gradient")
59 | width = self.winfo_width()
60 | height = self.winfo_height()
61 | limit = width + 10
62 | 
63 | (r1, g1, b1) = self.winfo_rgb(self._color1)
64 | (r2, g2, b2) = self.winfo_rgb(self._color2)
65 | r_ratio = float(r2 - r1) / limit
66 | g_ratio = float(g2 - g1) / limit
67 | b_ratio = float(b2 - b1) / limit
68 | 
69 | for i in range(0, limit, self._bar_width):
70 | bar_height = random.randint(int(limit / 8), int(limit / 2.5))
71 | if not self._running:
72 | bar_height = height
73 | nr = int(r1 + (r_ratio * i))
74 | ng = int(g1 + (g_ratio * i))
75 | nb = int(b1 + (b_ratio * i))
76 | 
77 | color = "#%4.4x%4.4x%4.4x" % (nr, ng, nb)
78 | self.create_line(i, 0, i, bar_height, tags=("gradient",), width=self._bar_width, fill=color)
79 | 
80 | self.lower("gradient")
81 | 
82 | if self._running:
83 | self.after(150, self._draw_gradient)
84 | 
85 | def update(self):
86 | self._draw_gradient()
87 | 
88 | 
89 | languages = [
90 | "English",
91 | "Chinese",
92 | "German",
93 | "Spanish",
94 | "Russian",
95 | "Korean",
96 | "French",
97 | "Japanese",
98 | "Portuguese",
99 | "Turkish",
100 | "Polish",
101 | "Catalan",
102 | "Dutch",
103 | "Arabic",
104 | "Swedish",
105 | "Italian",
106 | "Indonesian",
107 | "Hindi",
108 | "Finnish",
109 | "Vietnamese",
110 | "Hebrew",
111 | "Ukrainian",
112 | "Greek",
113 | "Malay",
114 | "Czech",
115 | 
"Romanian", 116 | "Danish", 117 | "Hungarian", 118 | "Tamil", 119 | "Norwegian", 120 | "Thai", 121 | "Urdu", 122 | "Croatian", 123 | "Bulgarian", 124 | "Lithuanian", 125 | "Latin", 126 | "Maori", 127 | "Malayalam", 128 | "Welsh", 129 | "Slovak", 130 | "Telugu", 131 | "Persian", 132 | "Latvian", 133 | "Bengali", 134 | "Serbian", 135 | "Azerbaijani", 136 | "Slovenian", 137 | "Kannada", 138 | "Estonian", 139 | "Macedonian", 140 | "Breton", 141 | "Basque", 142 | "Icelandic", 143 | "Armenian", 144 | "Nepali", 145 | "Mongolian", 146 | "Bosnian", 147 | "Kazakh", 148 | "Albanian", 149 | "Swahili", 150 | "Galician", 151 | "Marathi", 152 | "Punjabi", 153 | "Sinhala", 154 | "Khmer", 155 | "Shona", 156 | "Yoruba", 157 | "Somali", 158 | "Afrikaans", 159 | "Occitan", 160 | "Georgian", 161 | "Belarusian", 162 | "Tajik", 163 | "Sindhi", 164 | "Gujarati", 165 | "Amharic", 166 | "Yiddish", 167 | "Lao", 168 | "Uzbek", 169 | "Faroese", 170 | "Haitian creole", 171 | "Pashto", 172 | "Turkmen", 173 | "Nynorsk", 174 | "Maltese", 175 | "Sanskrit", 176 | "Luxembourgish", 177 | "Myanmar", 178 | "Tibetan", 179 | "Tagalog", 180 | "Malagasy", 181 | "Assamese", 182 | "Tatar", 183 | "Hawaiian", 184 | "Lingala", 185 | "Hausa", 186 | "Bashkir", 187 | "Javanese", 188 | "Sundanese", 189 | ] 190 | 191 | 192 | class LiveTranscribeUI(ctk.CTkFrame): 193 | def __init__(self, parent): 194 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0) 195 | self.energy_threshold_value = None 196 | self.language_value = None 197 | self.model_size_value = None 198 | self.tooltip = None 199 | self.language_dropdown = None 200 | self.model_dropdown = None 201 | self.mic_dropdown = None 202 | self.mic_btn = None 203 | self.recording = False 204 | self.close_btn = None 205 | self.main_frame = None 206 | self.input_value = None 207 | self.text_box = None 208 | self.record_button = None 209 | self.animation_frame = None 210 | self.transcribe_thread = None 211 | self.record_thread = None 212 | self.paa_thread = None 213 | 214 | self.selected_model = "base" 215 | self.selected_language = "english" 216 | 217 | self.audio_model = whisper.load_model(self.selected_model) 218 | self.phrase_time = None 219 | self.data_queue = Queue() 220 | self.recorder = sr.Recognizer() 221 | self.transcription = [''] 222 | 223 | self.energy_threshold = 500 224 | self.recorder.energy_threshold = self.energy_threshold 225 | self.recorder.dynamic_energy_threshold = False 226 | 227 | self.source = sr.Microphone(sample_rate=16000) 228 | 229 | self.record_timeout = 2 230 | self.phrase_timeout = 3 231 | 232 | self.grid_propagate(False) 233 | self.grid_columnconfigure(0, weight=1) 234 | self.ui() 235 | self.grid(row=0, column=0, sticky="nsew") 236 | 237 | def ui(self): 238 | title = ctk.CTkLabel(self, text=" Live Transcribe", font=FONTS["title"], image=icons["microphone"], 239 | compound="left") 240 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w") 241 | 242 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, 243 | width=30, 244 | height=30, command=self.hide_live_transcribe_ui) 245 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e") 246 | 247 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent") 248 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2) 249 | self.main_frame.grid_columnconfigure(0, weight=1) 250 | 251 | models = ["Tiny", "Base", "Small", "Medium"] 252 | model_size_label = 
ctk.CTkLabel(self.main_frame, text="Model Size", font=FONTS["normal"]) 253 | model_size_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w") 254 | self.model_size_value = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"]) 255 | self.model_size_value.grid(row=0, column=1, padx=20, pady=10, sticky="e") 256 | self.model_dropdown = CTkScrollableDropdownFrame(self.model_size_value, values=models, **DROPDOWN, 257 | command=self.model_option_callback) 258 | self.model_size_value.set(models[0]) 259 | 260 | language_label = ctk.CTkLabel(self.main_frame, text="Language", font=FONTS["normal"]) 261 | language_label.grid(row=1, column=0, padx=20, pady=10, sticky="w") 262 | self.language_value = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"]) 263 | self.language_value.grid(row=1, column=1, padx=20, pady=10, sticky="e") 264 | self.language_dropdown = CTkScrollableDropdownFrame(self.language_value, values=languages, **DROPDOWN, 265 | command=self.language_option_callback) 266 | self.language_value.set(languages[0]) 267 | 268 | energy_threshold_label = ctk.CTkLabel(self.main_frame, text="Energy Threshold", font=FONTS["normal"]) 269 | energy_threshold_label.grid(row=2, column=0, padx=20, pady=10, sticky="w") 270 | self.energy_threshold_value = ctk.CTkSlider(self.main_frame, from_=100, to=1000, number_of_steps=50, 271 | command=self.slider_event) 272 | self.energy_threshold_value.grid(row=2, column=1, padx=20, pady=10, sticky="e") 273 | self.energy_threshold_value.set(500) 274 | self.tooltip = CTkToolTip(widget=self.energy_threshold_value, message="500") 275 | 276 | self.record_button = ctk.CTkButton(self.main_frame, text="Start Transcribing", 277 | height=35, command=self.start_callback) 278 | self.record_button.grid(row=3, column=0, padx=100, pady=20, columnspan=2, sticky="nsew") 279 | 280 | ctk.CTkLabel(self.main_frame, text="Transcribed text", font=FONTS["subtitle"]).grid(row=4, column=0, sticky="w", 281 | padx=0, pady=(20, 5)) 282 | 283 | self.text_box = ctk.CTkTextbox(self.main_frame, height=260, font=FONTS["normal"]) 284 | self.text_box.grid(row=5, column=0, sticky="nsew", pady=(5, 20), columnspan=2) 285 | 286 | def start_callback(self): 287 | self.pause_play() 288 | 289 | self.record_thread = threading.Thread(target=self.start_recording) 290 | self.record_thread.start() 291 | 292 | self.paa_thread = threading.Thread(target=self.process_audio_data) 293 | self.paa_thread.start() 294 | 295 | def process_audio_data(self): 296 | while True: 297 | now = datetime.utcnow() 298 | 299 | if not self.data_queue.empty(): 300 | phrase_complete = False 301 | 302 | if self.phrase_time and now - self.phrase_time > timedelta(seconds=self.phrase_timeout): 303 | phrase_complete = True 304 | 305 | self.phrase_time = now 306 | 307 | audio_data = b''.join(self.data_queue.queue) 308 | self.data_queue.queue.clear() 309 | 310 | audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0 311 | 312 | result = self.audio_model.transcribe(audio_np, fp16=torch.cuda.is_available(), 313 | cancel_func=self.cancel_callback, language=self.selected_language) 314 | text = result['text'].strip() 315 | 316 | if phrase_complete: 317 | self.transcription.append(text) 318 | else: 319 | self.transcription[-1] = text 320 | 321 | self.update_text() 322 | 323 | sleep(0.25) 324 | 325 | def record_callback(self, _, audio: sr.AudioData) -> None: 326 | data = audio.get_raw_data() 327 | self.data_queue.put(data) 328 | 329 | def start_recording(self): 330 | self.recorder.listen_in_background(self.source, 
self.record_callback, phrase_time_limit=self.record_timeout) 331 | 332 | def update_text(self): 333 | text_to_display = ' '.join(self.transcription) 334 | self.text_box.delete("0.0", "end") 335 | self.text_box.insert("0.0", text_to_display) 336 | 337 | def pause_play(self): 338 | if not self.recording: 339 | self.recording = True 340 | self.record_button.configure(text="Stop Transcribing") 341 | else: 342 | self.recording = False 343 | self.record_button.configure(text="Start Transcribing") 344 | 345 | def slider_event(self, value): 346 | value = int(value) 347 | self.energy_threshold = value 348 | self.recorder.energy_threshold = value 349 | self.tooltip.configure(message=value) 350 | 351 | def model_option_callback(self, value): 352 | value1 = str(value).lower() 353 | self.selected_model = value1 354 | self.audio_model = whisper.load_model(self.selected_model) 355 | self.model_size_value.set(value) 356 | 357 | def language_option_callback(self, value): 358 | value1 = str(value).lower() 359 | self.selected_language = value1 360 | self.language_value.set(value) 361 | 362 | def cancel_callback(self): 363 | pass 364 | 365 | def hide_live_transcribe_ui(self): 366 | self.destroy() 367 | -------------------------------------------------------------------------------- /src/ui/transcribe.py: -------------------------------------------------------------------------------- 1 | import os 2 | import queue 3 | import threading 4 | import time 5 | 6 | import customtkinter as ctk 7 | from customtkinter import filedialog as fd 8 | from mutagen import File, MutagenError 9 | from pydub import AudioSegment, exceptions 10 | from tkinterdnd2 import TkinterDnD, DND_ALL 11 | from ..whisper.utils import get_writer 12 | from .ctkAlert import CTkAlert 13 | from .ctkLoader import CTkLoader 14 | from .ctkdropdown import CTkScrollableDropdownFrame 15 | from .icons import icons 16 | from .style import FONTS, DROPDOWN 17 | from ..logic import Transcriber 18 | 19 | 20 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.TkdndVersion = TkinterDnD._require(self) 24 | 25 | 26 | class TranscribeUI(CTk): 27 | def __init__(self, parent): 28 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0) 29 | self.grid_propagate(False) 30 | self.grid_columnconfigure(0, weight=1) 31 | 32 | title = ctk.CTkLabel(self, text=" Transcribe Audio", font=FONTS["title"], image=icons["audio_file"], 33 | compound="left") 34 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w") 35 | 36 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30, 37 | height=30, command=self.hide_transcribe_ui) 38 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e") 39 | 40 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent") 41 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2) 42 | self.main_frame.grid_columnconfigure(0, weight=1) 43 | 44 | self.master = parent 45 | self.drag_drop = None 46 | self.audio_path = None 47 | self.cancel_signal = False 48 | self.loader = None 49 | self.option_menu = None 50 | self.queue = queue.Queue() 51 | 52 | self.default_widget() 53 | 54 | self.grid(row=0, column=0, sticky="nsew") 55 | 56 | def default_widget(self): 57 | label = ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"]) 58 | label.grid(row=0, column=0, padx=0, pady=(20, 5), 
sticky="w") 59 | 60 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250, 61 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent", 62 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"), 63 | font=FONTS["normal"]) 64 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew") 65 | 66 | self.drag_drop.drop_target_register(DND_ALL) 67 | self.drag_drop.dnd_bind('<>', self.drop) 68 | 69 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14)) 70 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew") 71 | 72 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40, 73 | command=self.select_file_callback, font=FONTS["normal"]) 74 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew") 75 | 76 | label_1 = ctk.CTkLabel(self.main_frame, text="Support Formats: WAV, MP3, OGG, FLAC, MP4, MOV, WMV, AVI", 77 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50, 78 | font=FONTS["small"]) 79 | label_1.grid(row=4, column=0, padx=0, pady=20, sticky="sew") 80 | 81 | def task_widget(self): 82 | file_name, duration, file_size = self.get_audio_info(self.audio_path) 83 | 84 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"]) 85 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w") 86 | 87 | frame = ctk.CTkFrame(self.main_frame) 88 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew") 89 | frame.grid_columnconfigure(0, weight=1) 90 | 91 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"]) 92 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w") 93 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"]) 94 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e") 95 | 96 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"]) 97 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w") 98 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"]) 99 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e") 100 | 101 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"]) 102 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w") 103 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"]) 104 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e") 105 | 106 | start_btn = ctk.CTkButton(self.main_frame, text="Start Transcribing", height=40, command=self.start_callback, 107 | font=FONTS["normal"]) 108 | start_btn.grid(row=2, column=0, padx=200, pady=20, sticky="nsew") 109 | 110 | def result_widget(self): 111 | result = self.queue.get() 112 | text = str(result["text"]).strip() 113 | 114 | result_label = ctk.CTkLabel(self.main_frame, text="Transcribed Text", font=FONTS["subtitle_bold"]) 115 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w") 116 | 117 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"]) 118 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2) 119 | textbox.insert("0.0", text=text) 120 | 121 | download_label = ctk.CTkLabel(self.main_frame, text="Download Text and Subtitles", font=FONTS["subtitle_bold"]) 122 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), sticky="w") 123 | 124 | self.option_menu = ctk.CTkOptionMenu(self.main_frame, width=200) 125 | self.option_menu.grid(row=3, 
column=0, padx=10, pady=(5, 20), sticky="w") 126 | format_values = ["Text File (txt)", 127 | "Subtitles (SRT)", 128 | "WebVTT (VTT)", 129 | "Tab-Separated Values (TSV)", 130 | "JSON File (json)", 131 | "Save as all extensions"] 132 | CTkScrollableDropdownFrame(self.option_menu, values=format_values, **DROPDOWN) 133 | self.option_menu.set(format_values[0]) 134 | 135 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.save_text(result), 136 | font=FONTS["normal"], height=35) 137 | download_btn.grid(row=4, column=0, padx=10, pady=20, sticky="nsw") 138 | 139 | def save_text(self, result): 140 | file_name = os.path.basename(self.audio_path) 141 | sep = "." 142 | file_name = file_name.split(sep, 1)[0] 143 | 144 | selected_extension = self.option_menu.get() 145 | 146 | file_extension_map = { 147 | "Text File (txt)": ".txt", 148 | "Subtitles (SRT)": ".srt", 149 | "WebVTT (VTT)": ".vtt", 150 | "Tab-Separated Values (TSV)": ".tsv", 151 | "JSON File (json)": ".json", 152 | "Save as all extensions": ".all", 153 | } 154 | 155 | file_extension = file_extension_map[selected_extension] 156 | 157 | file_path = fd.asksaveasfilename( 158 | parent=self, 159 | initialfile=file_name, 160 | title="Export subtitle", 161 | defaultextension=file_extension, 162 | filetypes=[(f"{selected_extension} Files", "*" + file_extension)] 163 | ) 164 | 165 | if file_path: 166 | dir_name, get_file_name = os.path.split(file_path) 167 | 168 | if file_extension == ".srt": 169 | writer = get_writer("srt", dir_name) 170 | writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, "max_line_width": 3}) 171 | elif file_extension == ".txt": 172 | txt_writer = get_writer("txt", dir_name) 173 | txt_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, 174 | "max_line_width": 3}) 175 | elif file_extension == ".vtt": 176 | vtt_writer = get_writer("vtt", dir_name) 177 | vtt_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, 178 | "max_line_width": 3}) 179 | elif file_extension == ".tsv": 180 | tsv_writer = get_writer("tsv", dir_name) 181 | tsv_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, 182 | "max_line_width": 3}) 183 | elif file_extension == ".json": 184 | json_writer = get_writer("json", dir_name) 185 | json_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, 186 | "max_line_width": 3}) 187 | elif file_extension == ".all": 188 | all_writer = get_writer("all", dir_name) 189 | all_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, 190 | "max_line_width": 3}) 191 | 192 | def start_callback(self): 193 | self.close_btn.configure(state="disabled") 194 | widgets = self.main_frame.winfo_children() 195 | for widget in widgets: 196 | widget.destroy() 197 | 198 | self.loader = CTkLoader(parent=self.master, title="Transcribing...", msg="Please wait...", 199 | cancel_func=self.set_signal) 200 | thread = threading.Thread(target=self.start_transcribing, args=(self.audio_path, self.check_signal)) 201 | thread.start() 202 | 203 | def start_transcribing(self, audio_path, check_signal): 204 | transcriber = Transcriber(audio=audio_path) 205 | result = transcriber.audio_recognition(cancel_func=check_signal) 206 | self.loader.hide_loader() 207 | if result: 208 | self.queue.put(result) 209 | self.close_btn.configure(state="normal") 210 | self.result_widget() 211 | 212 | def set_signal(self): 213 | self.cancel_signal = True 214 | 215 | def check_signal(self): 216 
| original_value = self.cancel_signal
217 | 
218 | if self.cancel_signal:
219 | self.cancel_signal = False
220 | self.close_btn.configure(state="normal")
221 | self.after(1000, self.default_widget)
222 | 
223 | return original_value
224 | 
225 | def select_file_callback(self):
226 | file_path = fd.askopenfilename(
227 | filetypes=[("Audio files", "*.mp3 *.wav *.ogg *.flac"), ("Video files", "*.mp4 *.mov *.wmv *.avi")])
228 | if file_path:
229 | audio_path = os.path.abspath(file_path)
230 | if self.is_streamable_audio(audio_path):
231 | self.audio_path = audio_path
232 | widgets = self.main_frame.winfo_children()
233 | 
234 | for widget in widgets:
235 | widget.destroy()
236 | self.after(1000, self.task_widget)
237 | else:
238 | CTkAlert(parent=self.master, status="error", title="Error",
239 | msg="The chosen audio file is not valid or streamable.")
240 | 
241 | @staticmethod
242 | def is_streamable_audio(audio_path):
243 | if not os.path.isfile(audio_path):
244 | return False
245 | 
246 | try:
247 | audio = AudioSegment.from_file(audio_path)
248 | audio_info = File(audio_path)
249 | return len(audio) > 0 and audio_info.info.length > 0
250 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError):
251 | return False
252 | 
253 | @staticmethod
254 | def get_audio_info(file_path):
255 | try:
256 | if not os.path.isfile(file_path):
257 | return None, None, None  # the caller unpacks three values, so error paths must return a 3-tuple too
258 | 
259 | file_name = os.path.basename(file_path)
260 | 
261 | audio_info = File(file_path)
262 | if audio_info is None:
263 | return None, None, None
264 | 
265 | duration_seconds = audio_info.info.length
266 | 
267 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds))
268 | 
269 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
270 | 
271 | return file_name, duration_formatted, file_size_mb
272 | except MutagenError:
273 | return None, None, None
274 | 
275 | def drop(self, event):
276 | dropped_file = event.data.replace("{", "").replace("}", "")
277 | audio_path = os.path.abspath(dropped_file)
278 | if self.is_streamable_audio(audio_path):
279 | self.audio_path = audio_path
280 | widgets = self.main_frame.winfo_children()
281 | 
282 | for widget in widgets:
283 | widget.destroy()
284 | 
285 | self.after(1000, self.task_widget)
286 | else:
287 | CTkAlert(parent=self.master, status="error", title="Error",
288 | msg="The chosen audio file is not valid or streamable.")
289 | 
290 | def hide_transcribe_ui(self):
291 | self.destroy()
292 | 
--------------------------------------------------------------------------------
/src/whisper/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import string
4 | from dataclasses import dataclass, field
5 | from functools import cached_property, lru_cache
6 | from typing import Dict, List, Optional, Tuple
7 | 
8 | import tiktoken
9 | 
10 | LANGUAGES = {
11 | "en": "english",
12 | "zh": "chinese",
13 | "de": "german",
14 | "es": "spanish",
15 | "ru": "russian",
16 | "ko": "korean",
17 | "fr": "french",
18 | "ja": "japanese",
19 | "pt": "portuguese",
20 | "tr": "turkish",
21 | "pl": "polish",
22 | "ca": "catalan",
23 | "nl": "dutch",
24 | "ar": "arabic",
25 | "sv": "swedish",
26 | "it": "italian",
27 | "id": "indonesian",
28 | "hi": "hindi",
29 | "fi": "finnish",
30 | "vi": "vietnamese",
31 | "he": "hebrew",
32 | "uk": "ukrainian",
33 | "el": "greek",
34 | "ms": "malay",
35 | "cs": "czech",
36 | "ro": "romanian",
37 | "da": "danish",
38 | "hu": "hungarian",
39 | "ta": "tamil",
40 | "no": "norwegian",
41 | 
"th": "thai", 42 | "ur": "urdu", 43 | "hr": "croatian", 44 | "bg": "bulgarian", 45 | "lt": "lithuanian", 46 | "la": "latin", 47 | "mi": "maori", 48 | "ml": "malayalam", 49 | "cy": "welsh", 50 | "sk": "slovak", 51 | "te": "telugu", 52 | "fa": "persian", 53 | "lv": "latvian", 54 | "bn": "bengali", 55 | "sr": "serbian", 56 | "az": "azerbaijani", 57 | "sl": "slovenian", 58 | "kn": "kannada", 59 | "et": "estonian", 60 | "mk": "macedonian", 61 | "br": "breton", 62 | "eu": "basque", 63 | "is": "icelandic", 64 | "hy": "armenian", 65 | "ne": "nepali", 66 | "mn": "mongolian", 67 | "bs": "bosnian", 68 | "kk": "kazakh", 69 | "sq": "albanian", 70 | "sw": "swahili", 71 | "gl": "galician", 72 | "mr": "marathi", 73 | "pa": "punjabi", 74 | "si": "sinhala", 75 | "km": "khmer", 76 | "sn": "shona", 77 | "yo": "yoruba", 78 | "so": "somali", 79 | "af": "afrikaans", 80 | "oc": "occitan", 81 | "ka": "georgian", 82 | "be": "belarusian", 83 | "tg": "tajik", 84 | "sd": "sindhi", 85 | "gu": "gujarati", 86 | "am": "amharic", 87 | "yi": "yiddish", 88 | "lo": "lao", 89 | "uz": "uzbek", 90 | "fo": "faroese", 91 | "ht": "haitian creole", 92 | "ps": "pashto", 93 | "tk": "turkmen", 94 | "nn": "nynorsk", 95 | "mt": "maltese", 96 | "sa": "sanskrit", 97 | "lb": "luxembourgish", 98 | "my": "myanmar", 99 | "bo": "tibetan", 100 | "tl": "tagalog", 101 | "mg": "malagasy", 102 | "as": "assamese", 103 | "tt": "tatar", 104 | "haw": "hawaiian", 105 | "ln": "lingala", 106 | "ha": "hausa", 107 | "ba": "bashkir", 108 | "jw": "javanese", 109 | "su": "sundanese", 110 | "yue": "cantonese", 111 | } 112 | 113 | # language code lookup by name, with a few language aliases 114 | TO_LANGUAGE_CODE = { 115 | **{language: code for code, language in LANGUAGES.items()}, 116 | "burmese": "my", 117 | "valencian": "ca", 118 | "flemish": "nl", 119 | "haitian": "ht", 120 | "letzeburgesch": "lb", 121 | "pushto": "ps", 122 | "panjabi": "pa", 123 | "moldavian": "ro", 124 | "moldovan": "ro", 125 | "sinhalese": "si", 126 | "castilian": "es", 127 | "mandarin": "zh", 128 | } 129 | 130 | 131 | @dataclass 132 | class Tokenizer: 133 | """A thin wrapper around `tiktoken` providing quick access to special tokens""" 134 | 135 | encoding: tiktoken.Encoding 136 | num_languages: int 137 | language: Optional[str] = None 138 | task: Optional[str] = None 139 | sot_sequence: Tuple[int] = () 140 | special_tokens: Dict[str, int] = field(default_factory=dict) 141 | 142 | def __post_init__(self): 143 | for special in self.encoding.special_tokens_set: 144 | special_token = self.encoding.encode_single_token(special) 145 | self.special_tokens[special] = special_token 146 | 147 | sot: int = self.special_tokens["<|startoftranscript|>"] 148 | translate: int = self.special_tokens["<|translate|>"] 149 | transcribe: int = self.special_tokens["<|transcribe|>"] 150 | 151 | langs = tuple(LANGUAGES.keys())[: self.num_languages] 152 | sot_sequence = [sot] 153 | if self.language is not None: 154 | sot_sequence.append(sot + 1 + langs.index(self.language)) 155 | if self.task is not None: 156 | task_token: int = transcribe if self.task == "transcribe" else translate 157 | sot_sequence.append(task_token) 158 | 159 | self.sot_sequence = tuple(sot_sequence) 160 | 161 | def encode(self, text, **kwargs): 162 | return self.encoding.encode(text, **kwargs) 163 | 164 | def decode(self, token_ids: List[int], **kwargs) -> str: 165 | token_ids = [t for t in token_ids if t < self.timestamp_begin] 166 | return self.encoding.decode(token_ids, **kwargs) 167 | 168 | def decode_with_timestamps(self, token_ids: 
List[int], **kwargs) -> str: 169 | """ 170 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`. 171 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>". 172 | """ 173 | return self.encoding.decode(token_ids, **kwargs) 174 | 175 | @cached_property 176 | def eot(self) -> int: 177 | return self.encoding.eot_token 178 | 179 | @cached_property 180 | def transcribe(self) -> int: 181 | return self.special_tokens["<|transcribe|>"] 182 | 183 | @cached_property 184 | def translate(self) -> int: 185 | return self.special_tokens["<|translate|>"] 186 | 187 | @cached_property 188 | def sot(self) -> int: 189 | return self.special_tokens["<|startoftranscript|>"] 190 | 191 | @cached_property 192 | def sot_lm(self) -> int: 193 | return self.special_tokens["<|startoflm|>"] 194 | 195 | @cached_property 196 | def sot_prev(self) -> int: 197 | return self.special_tokens["<|startofprev|>"] 198 | 199 | @cached_property 200 | def no_speech(self) -> int: 201 | return self.special_tokens["<|nospeech|>"] 202 | 203 | @cached_property 204 | def no_timestamps(self) -> int: 205 | return self.special_tokens["<|notimestamps|>"] 206 | 207 | @cached_property 208 | def timestamp_begin(self) -> int: 209 | return self.special_tokens["<|0.00|>"] 210 | 211 | @cached_property 212 | def language_token(self) -> int: 213 | """Returns the token id corresponding to the value of the `language` field""" 214 | if self.language is None: 215 | raise ValueError("This tokenizer does not have language token configured") 216 | 217 | return self.to_language_token(self.language) 218 | 219 | def to_language_token(self, language): 220 | if token := self.special_tokens.get(f"<|{language}|>", None): 221 | return token 222 | 223 | raise KeyError(f"Language {language} not found in tokenizer.") 224 | 225 | @cached_property 226 | def all_language_tokens(self) -> Tuple[int]: 227 | result = [] 228 | for token, token_id in self.special_tokens.items(): 229 | if token.strip("<|>") in LANGUAGES: 230 | result.append(token_id) 231 | return tuple(result)[: self.num_languages] 232 | 233 | @cached_property 234 | def all_language_codes(self) -> Tuple[str]: 235 | return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens) 236 | 237 | @cached_property 238 | def sot_sequence_including_notimestamps(self) -> Tuple[int]: 239 | return tuple(list(self.sot_sequence) + [self.no_timestamps]) 240 | 241 | @cached_property 242 | def non_speech_tokens(self) -> Tuple[int]: 243 | """ 244 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech 245 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g. 246 | 247 | - ♪♪♪ 248 | - ( SPEAKING FOREIGN LANGUAGE ) 249 | - [DAVID] Hey there, 250 | 251 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc. 252 | """ 253 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』') 254 | symbols += ( 255 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split() 256 | ) 257 | 258 | # symbols that may be a single token or multiple tokens depending on the tokenizer. 259 | # In case they're multiple tokens, suppress the first token, which is safe because: 260 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress 261 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes. 
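# (for example "♪" is U+266A with UTF-8 bytes E2 99 AA and "♩" is U+2669 with bytes
# E2 99 A9, so however the tokenizer splits them, the shared first token is suppressed)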
262 | miscellaneous = set("♩♪♫♬♭♮♯") 263 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous) 264 | 265 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word 266 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]} 267 | for symbol in symbols + list(miscellaneous): 268 | for tokens in [ 269 | self.encoding.encode(symbol), 270 | self.encoding.encode(" " + symbol), 271 | ]: 272 | if len(tokens) == 1 or symbol in miscellaneous: 273 | result.add(tokens[0]) 274 | 275 | return tuple(sorted(result)) 276 | 277 | def split_to_word_tokens(self, tokens: List[int]): 278 | if self.language in {"zh", "ja", "th", "lo", "my", "yue"}: 279 | # These languages don't typically use spaces, so it is difficult to split words 280 | # without morpheme analysis. Here, we instead split words at any 281 | # position where the tokens are decoded as valid unicode points 282 | return self.split_tokens_on_unicode(tokens) 283 | 284 | return self.split_tokens_on_spaces(tokens) 285 | 286 | def split_tokens_on_unicode(self, tokens: List[int]): 287 | decoded_full = self.decode_with_timestamps(tokens) 288 | replacement_char = "\ufffd" 289 | 290 | words = [] 291 | word_tokens = [] 292 | current_tokens = [] 293 | unicode_offset = 0 294 | 295 | for token in tokens: 296 | current_tokens.append(token) 297 | decoded = self.decode_with_timestamps(current_tokens) 298 | 299 | if ( 300 | replacement_char not in decoded 301 | or decoded_full[unicode_offset + decoded.index(replacement_char)] 302 | == replacement_char 303 | ): 304 | words.append(decoded) 305 | word_tokens.append(current_tokens) 306 | current_tokens = [] 307 | unicode_offset += len(decoded) 308 | 309 | return words, word_tokens 310 | 311 | def split_tokens_on_spaces(self, tokens: List[int]): 312 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens) 313 | words = [] 314 | word_tokens = [] 315 | 316 | for subword, subword_tokens in zip(subwords, subword_tokens_list): 317 | special = subword_tokens[0] >= self.eot 318 | with_space = subword.startswith(" ") 319 | punctuation = subword.strip() in string.punctuation 320 | if special or with_space or punctuation or len(words) == 0: 321 | words.append(subword) 322 | word_tokens.append(subword_tokens) 323 | else: 324 | words[-1] = words[-1] + subword 325 | word_tokens[-1].extend(subword_tokens) 326 | 327 | return words, word_tokens 328 | 329 | 330 | @lru_cache(maxsize=None) 331 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 332 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 333 | ranks = { 334 | base64.b64decode(token): int(rank) 335 | for token, rank in (line.split() for line in open(vocab_path) if line) 336 | } 337 | n_vocab = len(ranks) 338 | special_tokens = {} 339 | 340 | specials = [ 341 | "<|endoftext|>", 342 | "<|startoftranscript|>", 343 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 344 | "<|translate|>", 345 | "<|transcribe|>", 346 | "<|startoflm|>", 347 | "<|startofprev|>", 348 | "<|nospeech|>", 349 | "<|notimestamps|>", 350 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 351 | ] 352 | 353 | for token in specials: 354 | special_tokens[token] = n_vocab 355 | n_vocab += 1 356 | 357 | return tiktoken.Encoding( 358 | name=os.path.basename(vocab_path), 359 | explicit_n_vocab=n_vocab, 360 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 361 | mergeable_ranks=ranks, 362 | special_tokens=special_tokens, 363 | ) 364 | 
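# The specials list above is positional: the language tokens sit immediately after
# <|startoftranscript|>, which is the invariant Tokenizer.__post_init__ relies on via
# `sot + 1 + langs.index(self.language)`. A minimal illustration of that layout
# (a sketch only, assuming the "multilingual" .tiktoken vocabulary is present under
# assets; only mel_filters.npz is shown in this tree):
#
#     enc = get_encoding("multilingual")
#     sot = enc.encode_single_token("<|startoftranscript|>")
#     assert enc.encode_single_token("<|en|>") == sot + 1  # "en" is first in LANGUAGES
#     assert enc.encode_single_token("<|zh|>") == sot + 2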
365 | 366 | @lru_cache(maxsize=None) 367 | def get_tokenizer( 368 | multilingual: bool, 369 | *, 370 | num_languages: int = 99, 371 | language: Optional[str] = None, 372 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 373 | ) -> Tokenizer: 374 | if language is not None: 375 | language = language.lower() 376 | if language not in LANGUAGES: 377 | if language in TO_LANGUAGE_CODE: 378 | language = TO_LANGUAGE_CODE[language] 379 | else: 380 | raise ValueError(f"Unsupported language: {language}") 381 | 382 | if multilingual: 383 | encoding_name = "multilingual" 384 | language = language or "en" 385 | task = task or "transcribe" 386 | else: 387 | encoding_name = "gpt2" 388 | language = None 389 | task = None 390 | 391 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 392 | 393 | return Tokenizer( 394 | encoding=encoding, num_languages=num_languages, language=language, task=task 395 | ) 396 | -------------------------------------------------------------------------------- /src/whisper/timing.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import subprocess 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import TYPE_CHECKING, List 6 | 7 | import numba 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND 13 | from .tokenizer import Tokenizer 14 | 15 | if TYPE_CHECKING: 16 | from .model import Whisper 17 | 18 | 19 | def median_filter(x: torch.Tensor, filter_width: int): 20 | """Apply a median filter of width `filter_width` along the last dimension of `x`""" 21 | pad_width = filter_width // 2 22 | if x.shape[-1] <= pad_width: 23 | # F.pad requires the padding width to be smaller than the input dimension 24 | return x 25 | 26 | if (ndim := x.ndim) <= 2: 27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D 28 | x = x[None, None, :] 29 | 30 | assert ( 31 | filter_width > 0 and filter_width % 2 == 1 32 | ), "`filter_width` should be an odd number" 33 | 34 | result = None 35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") 36 | if x.is_cuda: 37 | try: 38 | from .triton_ops import median_filter_cuda 39 | 40 | result = median_filter_cuda(x, filter_width) 41 | except (RuntimeError, subprocess.CalledProcessError): 42 | warnings.warn( 43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 44 | "falling back to a slower median kernel implementation..." 
45 | ) 46 | 47 | if result is None: 48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) 49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2] 50 | 51 | if ndim <= 2: 52 | result = result[0, 0] 53 | 54 | return result 55 | 56 | 57 | @numba.jit(nopython=True) 58 | def backtrace(trace: np.ndarray): 59 | i = trace.shape[0] - 1 60 | j = trace.shape[1] - 1 61 | trace[0, :] = 2 62 | trace[:, 0] = 1 63 | 64 | result = [] 65 | while i > 0 or j > 0: 66 | result.append((i - 1, j - 1)) 67 | 68 | if trace[i, j] == 0: 69 | i -= 1 70 | j -= 1 71 | elif trace[i, j] == 1: 72 | i -= 1 73 | elif trace[i, j] == 2: 74 | j -= 1 75 | else: 76 | raise ValueError("Unexpected trace[i, j]") 77 | 78 | result = np.array(result) 79 | return result[::-1, :].T 80 | 81 | 82 | @numba.jit(nopython=True, parallel=True) 83 | def dtw_cpu(x: np.ndarray): 84 | N, M = x.shape 85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf 86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32) 87 | 88 | cost[0, 0] = 0 89 | for j in range(1, M + 1): 90 | for i in range(1, N + 1): 91 | c0 = cost[i - 1, j - 1] 92 | c1 = cost[i - 1, j] 93 | c2 = cost[i, j - 1] 94 | 95 | if c0 < c1 and c0 < c2: 96 | c, t = c0, 0 97 | elif c1 < c0 and c1 < c2: 98 | c, t = c1, 1 99 | else: 100 | c, t = c2, 2 101 | 102 | cost[i, j] = x[i - 1, j - 1] + c 103 | trace[i, j] = t 104 | 105 | return backtrace(trace) 106 | 107 | 108 | def dtw_cuda(x, BLOCK_SIZE=1024): 109 | from .triton_ops import dtw_kernel 110 | 111 | M, N = x.shape 112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}" 113 | 114 | x_skew = ( 115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M) 116 | ) 117 | x_skew = x_skew.T.contiguous() 118 | cost = torch.ones(N + M + 2, M + 2) * np.inf 119 | cost[0, 0] = 0 120 | cost = cost.cuda() 121 | trace = torch.zeros_like(cost, dtype=torch.int32) 122 | 123 | dtw_kernel[(1,)]( 124 | cost, 125 | trace, 126 | x_skew, 127 | x_skew.stride(0), 128 | cost.stride(0), 129 | trace.stride(0), 130 | N, 131 | M, 132 | BLOCK_SIZE=BLOCK_SIZE, 133 | ) 134 | 135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[ 136 | :, : N + 1 137 | ] 138 | return backtrace(trace.cpu().numpy()) 139 | 140 | 141 | def dtw(x: torch.Tensor) -> np.ndarray: 142 | if x.is_cuda: 143 | try: 144 | return dtw_cuda(x) 145 | except (RuntimeError, subprocess.CalledProcessError): 146 | warnings.warn( 147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; " 148 | "falling back to a slower DTW implementation..." 
149 | ) 150 | 151 | return dtw_cpu(x.double().cpu().numpy()) 152 | 153 | 154 | @dataclass 155 | class WordTiming: 156 | word: str 157 | tokens: List[int] 158 | start: float 159 | end: float 160 | probability: float 161 | 162 | 163 | def find_alignment( 164 | model: "Whisper", 165 | tokenizer: Tokenizer, 166 | text_tokens: List[int], 167 | mel: torch.Tensor, 168 | num_frames: int, 169 | *, 170 | medfilt_width: int = 7, 171 | qk_scale: float = 1.0, 172 | ) -> List[WordTiming]: 173 | if len(text_tokens) == 0: 174 | return [] 175 | 176 | tokens = torch.tensor( 177 | [ 178 | *tokenizer.sot_sequence, 179 | tokenizer.no_timestamps, 180 | *text_tokens, 181 | tokenizer.eot, 182 | ] 183 | ).to(model.device) 184 | 185 | # install hooks on the cross attention layers to retrieve the attention weights 186 | QKs = [None] * model.dims.n_text_layer 187 | hooks = [ 188 | block.cross_attn.register_forward_hook( 189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0]) 190 | ) 191 | for i, block in enumerate(model.decoder.blocks) 192 | ] 193 | 194 | with torch.no_grad(): 195 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0] 196 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot] 197 | token_probs = sampled_logits.softmax(dim=-1) 198 | text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens] 199 | text_token_probs = text_token_probs.tolist() 200 | 201 | for hook in hooks: 202 | hook.remove() 203 | 204 | # heads * tokens * frames 205 | weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T]) 206 | weights = weights[:, :, : num_frames // 2] 207 | weights = (weights * qk_scale).softmax(dim=-1) 208 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False) 209 | weights = (weights - mean) / std 210 | weights = median_filter(weights, medfilt_width) 211 | 212 | matrix = weights.mean(axis=0) 213 | matrix = matrix[len(tokenizer.sot_sequence) : -1] 214 | text_indices, time_indices = dtw(-matrix) 215 | 216 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot]) 217 | if len(word_tokens) <= 1: 218 | # return on eot only 219 | # >>> np.pad([], (1, 0)) 220 | # array([0.]) 221 | # This results in crashes when we lookup jump_times with float, like 222 | # IndexError: arrays used as indices must be of integer (or boolean) type 223 | return [] 224 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0)) 225 | 226 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) 227 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND 228 | start_times = jump_times[word_boundaries[:-1]] 229 | end_times = jump_times[word_boundaries[1:]] 230 | word_probabilities = [ 231 | np.mean(text_token_probs[i:j]) 232 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:]) 233 | ] 234 | 235 | return [ 236 | WordTiming(word, tokens, start, end, probability) 237 | for word, tokens, start, end, probability in zip( 238 | words, word_tokens, start_times, end_times, word_probabilities 239 | ) 240 | ] 241 | 242 | 243 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str): 244 | # merge prepended punctuations 245 | i = len(alignment) - 2 246 | j = len(alignment) - 1 247 | while i >= 0: 248 | previous = alignment[i] 249 | following = alignment[j] 250 | if previous.word.startswith(" ") and previous.word.strip() in prepended: 251 | # prepend it to the following word 252 | following.word = previous.word + following.word 253 | following.tokens = 
previous.tokens + following.tokens 254 | previous.word = "" 255 | previous.tokens = [] 256 | else: 257 | j = i 258 | i -= 1 259 | 260 | # merge appended punctuations 261 | i = 0 262 | j = 1 263 | while j < len(alignment): 264 | previous = alignment[i] 265 | following = alignment[j] 266 | if not previous.word.endswith(" ") and following.word in appended: 267 | # append it to the previous word 268 | previous.word = previous.word + following.word 269 | previous.tokens = previous.tokens + following.tokens 270 | following.word = "" 271 | following.tokens = [] 272 | else: 273 | i = j 274 | j += 1 275 | 276 | 277 | def add_word_timestamps( 278 | *, 279 | segments: List[dict], 280 | model: "Whisper", 281 | tokenizer: Tokenizer, 282 | mel: torch.Tensor, 283 | num_frames: int, 284 | prepend_punctuations: str = "\"'“¿([{-", 285 | append_punctuations: str = "\"'.。,,!!??::”)]}、", 286 | last_speech_timestamp: float, 287 | **kwargs, 288 | ): 289 | if len(segments) == 0: 290 | return 291 | 292 | text_tokens_per_segment = [ 293 | [token for token in segment["tokens"] if token < tokenizer.eot] 294 | for segment in segments 295 | ] 296 | 297 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment)) 298 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs) 299 | word_durations = np.array([t.end - t.start for t in alignment]) 300 | word_durations = word_durations[word_durations.nonzero()] 301 | median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0 302 | max_duration = median_duration * 2 303 | 304 | # hack: truncate long words at sentence boundaries. 305 | # a better segmentation algorithm based on VAD should be able to replace this. 306 | if len(word_durations) > 0: 307 | sentence_end_marks = ".。!!??" 308 | # ensure words at sentence boundaries are not longer than twice the median word duration. 309 | for i in range(1, len(alignment)): 310 | if alignment[i].end - alignment[i].start > max_duration: 311 | if alignment[i].word in sentence_end_marks: 312 | alignment[i].end = alignment[i].start + max_duration 313 | elif alignment[i - 1].word in sentence_end_marks: 314 | alignment[i].start = alignment[i].end - max_duration 315 | 316 | merge_punctuations(alignment, prepend_punctuations, append_punctuations) 317 | 318 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE 319 | word_index = 0 320 | 321 | for segment, text_tokens in zip(segments, text_tokens_per_segment): 322 | saved_tokens = 0 323 | words = [] 324 | 325 | while word_index < len(alignment) and saved_tokens < len(text_tokens): 326 | timing = alignment[word_index] 327 | 328 | if timing.word: 329 | words.append( 330 | dict( 331 | word=timing.word, 332 | start=round(time_offset + timing.start, 2), 333 | end=round(time_offset + timing.end, 2), 334 | probability=timing.probability, 335 | ) 336 | ) 337 | 338 | saved_tokens += len(timing.tokens) 339 | word_index += 1 340 | 341 | # hack: truncate long words at segment boundaries. 342 | # a better segmentation algorithm based on VAD should be able to replace this. 343 | if len(words) > 0: 344 | # ensure the first and second word after a pause is not longer than 345 | # twice the median word duration. 
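# ("after a pause" below means the first word ends more than four median word
# durations past last_speech_timestamp; max_duration itself is twice the median)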
346 | if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
347 | words[0]["end"] - words[0]["start"] > max_duration
348 | or (
349 | len(words) > 1
350 | and words[1]["end"] - words[0]["start"] > max_duration * 2
351 | )
352 | ):
353 | if (
354 | len(words) > 1
355 | and words[1]["end"] - words[1]["start"] > max_duration
356 | ):
357 | boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
358 | words[0]["end"] = words[1]["start"] = boundary
359 | words[0]["start"] = max(0, words[0]["end"] - max_duration)
360 | 
361 | # prefer the segment-level start timestamp if the first word is too long.
362 | if (
363 | segment["start"] < words[0]["end"]
364 | and segment["start"] - 0.5 > words[0]["start"]
365 | ):
366 | words[0]["start"] = max(
367 | 0, min(words[0]["end"] - median_duration, segment["start"])
368 | )
369 | else:
370 | segment["start"] = words[0]["start"]
371 | 
372 | # prefer the segment-level end timestamp if the last word is too long.
373 | if (
374 | segment["end"] > words[-1]["start"]
375 | and segment["end"] + 0.5 < words[-1]["end"]
376 | ):
377 | words[-1]["end"] = max(
378 | words[-1]["start"] + median_duration, segment["end"]
379 | )
380 | else:
381 | segment["end"] = words[-1]["end"]
382 | 
383 | last_speech_timestamp = segment["end"]
384 | 
385 | segment["words"] = words
386 | 
--------------------------------------------------------------------------------
/src/ui/ctkdropdown.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import sys
3 | import time
4 | 
5 | import customtkinter as ctk
6 | 
7 | 
8 | class CTkScrollableDropdownFrame(ctk.CTkToplevel):
9 | 
10 | def __init__(self, attach, x=None, y=None, button_color=None, height: int = 200, width: int = None,
11 | fg_color=None, button_height: int = 20, justify="center", scrollbar_button_color=None,
12 | scrollbar=True, scrollbar_button_hover_color=None, frame_border_width=2, values=[],
13 | command=None, image_values=[], alpha: float = 0.97, frame_corner_radius=20, double_click=False,
14 | resize=True, frame_border_color=None, text_color=None, autocomplete=False, **button_kwargs):
15 | 
16 | super().__init__(takefocus=1)
17 | 
18 | self.focus()
19 | self.lift()
20 | self.alpha = alpha
21 | self.attach = attach
22 | self.corner = frame_corner_radius
23 | self.padding = 0
24 | self.focus_something = False
25 | self.disable = True
26 | self.update()
27 | 
28 | if sys.platform.startswith("win"):
29 | self.after(100, lambda: self.overrideredirect(True))
30 | self.transparent_color = self._apply_appearance_mode(self._fg_color)
31 | self.attributes("-transparentcolor", self.transparent_color)
32 | elif sys.platform.startswith("darwin"):
33 | self.overrideredirect(True)
34 | self.transparent_color = 'systemTransparent'
35 | self.attributes("-transparent", True)
36 | self.focus_something = True
37 | else:
38 | self.overrideredirect(True)
39 | self.transparent_color = '#000001'
40 | self.corner = 0
41 | self.padding = 18
42 | self.withdraw()
43 | 
44 | self.hide = True
45 | self.attach.bind('<Configure>', lambda e: self._withdraw() if not self.disable else None, add="+")
46 | self.attach.winfo_toplevel().bind('<Configure>', lambda e: self._withdraw() if not self.disable else None,
47 | add="+")
48 | self.attach.winfo_toplevel().bind("<ButtonPress>", lambda e: self._withdraw() if not self.disable else None,
49 | add="+")
50 | 
51 | self.attributes('-alpha', 0)
52 | self.disable = False
53 | self.fg_color = ctk.ThemeManager.theme["CTkFrame"]["fg_color"] if fg_color is None else fg_color
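# The remaining colors follow the same pattern: each falls back to the active
# CustomTkinter theme value when the corresponding constructor argument is None.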
54 | self.scroll_button_color = ctk.ThemeManager.theme["CTkScrollbar"][
55 | "button_color"] if scrollbar_button_color is None else scrollbar_button_color
56 | self.scroll_hover_color = ctk.ThemeManager.theme["CTkScrollbar"][
57 | "button_hover_color"] if scrollbar_button_hover_color is None else scrollbar_button_hover_color
58 | self.frame_border_color = ctk.ThemeManager.theme["CTkFrame"][
59 | "border_color"] if frame_border_color is None else frame_border_color
60 | self.button_color = ctk.ThemeManager.theme["CTkFrame"][
61 | "top_fg_color"] if button_color is None else button_color
62 | self.text_color = ctk.ThemeManager.theme["CTkLabel"][
63 | "text_color"] if text_color is None else text_color
64 | 
65 | if scrollbar is False:
66 | self.scroll_button_color = self.fg_color
67 | self.scroll_hover_color = self.fg_color
68 | 
69 | self.frame = ctk.CTkScrollableFrame(self, bg_color=self.transparent_color, fg_color=self.fg_color,
70 | scrollbar_button_hover_color=self.scroll_hover_color,
71 | corner_radius=self.corner, border_width=frame_border_width,
72 | scrollbar_button_color=self.scroll_button_color,
73 | border_color=self.frame_border_color)
74 | self.frame._scrollbar.grid_configure(padx=3)
75 | self.frame.pack(expand=True, fill="both")
76 | self.dummy_entry = ctk.CTkEntry(self.frame, fg_color="transparent", border_width=0, height=1, width=1)
77 | self.no_match = ctk.CTkLabel(self.frame, text="No Match")
78 | self.height = height
79 | self.height_new = height
80 | self.width = width
81 | self.command = command
82 | self.fade = False
83 | self.resize = resize
84 | self.autocomplete = autocomplete
85 | self.var_update = ctk.StringVar()
86 | self.appear = False
87 | 
88 | if justify.lower() == "left":
89 | self.justify = "w"
90 | elif justify.lower() == "right":
91 | self.justify = "e"
92 | else:
93 | self.justify = "c"
94 | 
95 | self.button_height = button_height
96 | self.values = values
97 | self.button_num = len(self.values)
98 | self.image_values = None if len(image_values) != len(self.values) else image_values
99 | 
100 | self.resizable(width=False, height=False)
101 | self.transient(self.master)
102 | self._init_buttons(**button_kwargs)
103 | 
104 | # Add binding for different ctk widgets
105 | if double_click or self.attach.winfo_name().startswith("!ctkentry") or self.attach.winfo_name().startswith(
106 | "!ctkcombobox"):
107 | self.attach.bind('<Double-Button-1>', lambda e: self._iconify(), add="+")
108 | else:
109 | self.attach.bind('<Button-1>', lambda e: self._iconify(), add="+")
110 | 
111 | if self.attach.winfo_name().startswith("!ctkcombobox"):
112 | self.attach._canvas.tag_bind("right_parts", "<Button-1>", lambda e: self._iconify())
113 | self.attach._canvas.tag_bind("dropdown_arrow", "<Button-1>", lambda e: self._iconify())
114 | if self.command is None:
115 | self.command = self.attach.set
116 | 
117 | if self.attach.winfo_name().startswith("!ctkoptionmenu"):
118 | self.attach._canvas.bind("<Button-1>", lambda e: self._iconify())
119 | self.attach._text_label.bind("<Button-1>", lambda e: self._iconify())
120 | if self.command is None:
121 | self.command = self.attach.set
122 | 
123 | self.attach.bind("<Destroy>", lambda _: self._destroy(), add="+")
124 | 
125 | self.update_idletasks()
126 | self.x = x
127 | self.y = y
128 | 
129 | if self.autocomplete:
130 | self.bind_autocomplete()
131 | 
132 | self.deiconify()
133 | self.withdraw()
134 | 
135 | self.attributes("-alpha", self.alpha)
136 | 
137 | def _destroy(self):
138 | self.after(500, self.destroy_popup)
139 | 
140 | def _withdraw(self):
141 | if self.winfo_viewable() and self.hide:
142 | self.withdraw()
143 | 
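# note: the <<Closed>> event below is generated unconditionally, even when the
# popup was already withdrawn, so listeners see it on every close request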
144 | self.event_generate("<<Closed>>")
145 | self.hide = True
146 | 
147 | def _update(self, a, b, c):
148 | self.live_update(self.attach._entry.get())
149 | 
150 | def bind_autocomplete(self):
151 | def appear(x):
152 | self.appear = True
153 | 
154 | if self.attach.winfo_name().startswith("!ctkcombobox"):
155 | self.attach._entry.configure(textvariable=self.var_update)
156 | self.attach._entry.bind("<Key>", appear)
157 | self.attach.set(self.values[0])
158 | self.var_update.trace_add('write', self._update)
159 | 
160 | if self.attach.winfo_name().startswith("!ctkentry"):
161 | self.attach.configure(textvariable=self.var_update)
162 | self.attach.bind("<Key>", appear)
163 | self.var_update.trace_add('write', self._update)
164 | 
165 | def fade_out(self):
166 | for i in range(100, 0, -10):
167 | if not self.winfo_exists():
168 | break
169 | self.attributes("-alpha", i / 100)
170 | self.update()
171 | time.sleep(1 / 100)
172 | 
173 | def fade_in(self):
174 | for i in range(0, 100, 10):
175 | if not self.winfo_exists():
176 | break
177 | self.attributes("-alpha", i / 100)
178 | self.update()
179 | time.sleep(1 / 100)
180 | 
181 | def _init_buttons(self, **button_kwargs):
182 | self.i = 0
183 | self.widgets = {}
184 | for row in self.values:
185 | self.widgets[self.i] = ctk.CTkButton(self.frame,
186 | text=row,
187 | height=self.button_height,
188 | fg_color=self.button_color,
189 | text_color=self.text_color,
190 | image=self.image_values[
191 | self.i] if self.image_values is not None else None,
192 | anchor=self.justify,
193 | command=lambda k=row: self._attach_key_press(k),
194 | **button_kwargs)
195 | self.widgets[self.i].pack(fill="x", pady=2, padx=(self.padding, 0))
196 | self.i += 1
197 | 
198 | self.hide = False
199 | 
200 | def destroy_popup(self):
201 | self.disable = True
202 | self.destroy()
203 | 
204 | def place_dropdown(self):
205 | self.x_pos = self.attach.winfo_rootx() if self.x is None else self.x + self.attach.winfo_rootx()
206 | self.y_pos = self.attach.winfo_rooty() + self.attach.winfo_reqheight() + 5 if self.y is None else self.y + self.attach.winfo_rooty()
207 | self.width_new = self.attach.winfo_width() if self.width is None else self.width
208 | 
209 | if self.resize:
210 | if self.button_num <= 5:
211 | self.height_new = self.button_height * self.button_num + 55
212 | else:
213 | self.height_new = self.button_height * self.button_num + 35
214 | if self.height_new > self.height:
215 | self.height_new = self.height
216 | 
217 | self.geometry('{}x{}+{}+{}'.format(self.width_new, self.height_new,
218 | self.x_pos, self.y_pos))
219 | self.fade_in()
220 | self.attributes('-alpha', self.alpha)
221 | self.attach.focus()
222 | 
223 | def _iconify(self):
224 | if self.disable: return
225 | if self.hide:
226 | self.event_generate("<<Opened>>")
227 | self._deiconify()
228 | self.focus()
229 | self.hide = False
230 | self.place_dropdown()
231 | if self.focus_something:
232 | self.dummy_entry.pack()
233 | self.dummy_entry.focus_set()
234 | self.after(100, self.dummy_entry.pack_forget)
235 | else:
236 | self.withdraw()
237 | self.hide = True
238 | 
239 | def _attach_key_press(self, k):
240 | self.event_generate("<<Selected>>")
241 | self.fade = True
242 | if self.command:
243 | self.command(k)
244 | self.fade = False
245 | self.fade_out()
246 | self.withdraw()
247 | self.hide = True
248 | 
249 | def live_update(self, string=None):
250 | if not self.appear: return
251 | if self.disable: return
252 | if self.fade: return
253 | if string:
254 | string = string.lower()
255 | self._deiconify()
256 | i = 1
257 | for key in
    def live_update(self, string=None):
        if not self.appear:
            return
        if self.disable:
            return
        if self.fade:
            return
        if string:
            string = string.lower()
            self._deiconify()
            i = 1
            for key in self.widgets.keys():
                s = self.widgets[key].cget("text").lower()
                text_similarity = difflib.SequenceMatcher(None, s[0:len(string)], string).ratio()
                similar = s.startswith(string) or text_similarity > 0.75
                if not similar:
                    self.widgets[key].pack_forget()
                else:
                    self.widgets[key].pack(fill="x", pady=2, padx=(self.padding, 0))
                    i += 1

            if i == 1:
                self.no_match.pack(fill="x", pady=2, padx=(self.padding, 0))
            else:
                self.no_match.pack_forget()
            self.button_num = i
            self.place_dropdown()

        else:
            self.no_match.pack_forget()
            self.button_num = len(self.values)
            for key in self.widgets.keys():
                self.widgets[key].destroy()
            self._init_buttons()
            self.place_dropdown()

        self.frame._parent_canvas.yview_moveto(0.0)
        self.appear = False

    def insert(self, value, **kwargs):
        self.widgets[self.i] = ctk.CTkButton(self.frame,
                                             text=value,
                                             height=self.button_height,
                                             fg_color=self.button_color,
                                             text_color=self.text_color,
                                             anchor=self.justify,
                                             command=lambda k=value: self._attach_key_press(k), **kwargs)
        self.widgets[self.i].pack(fill="x", pady=2, padx=(self.padding, 0))
        self.i += 1
        self.values.append(value)

    def _deiconify(self):
        if len(self.values) > 0:
            self.deiconify()

    def popup(self, x=None, y=None):
        self.x = x
        self.y = y
        self.hide = True
        self._iconify()

    def configure(self, **kwargs):
        if "height" in kwargs:
            self.height = kwargs.pop("height")
            self.height_new = self.height

        if "alpha" in kwargs:
            self.alpha = kwargs.pop("alpha")

        if "width" in kwargs:
            self.width = kwargs.pop("width")

        if "fg_color" in kwargs:
            self.frame.configure(fg_color=kwargs.pop("fg_color"))

        if "values" in kwargs:
            self.values = kwargs.pop("values")
            self.image_values = None
            self.button_num = len(self.values)
            for key in self.widgets.keys():
                self.widgets[key].destroy()
            self._init_buttons()

        if "image_values" in kwargs:
            self.image_values = kwargs.pop("image_values")
            self.image_values = None if len(self.image_values) != len(self.values) else self.image_values
            if self.image_values is not None:
                i = 0
                for key in self.widgets.keys():
                    self.widgets[key].configure(image=self.image_values[i])
                    i += 1

        if "button_color" in kwargs:
            # Pop once before the loop; popping inside it would raise KeyError
            # on the second widget.
            button_color = kwargs.pop("button_color")
            for key in self.widgets.keys():
                self.widgets[key].configure(fg_color=button_color)

        for key in self.widgets.keys():
            self.widgets[key].configure(**kwargs)
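    # Illustrative usage (a minimal sketch, not part of the original file):
    # attached to a CTkOptionMenu, command may be omitted because __init__
    # falls back to attach.set, as bound above.
    #
    #     menu = ctk.CTkOptionMenu(root)
    #     menu.pack(padx=20, pady=20)
    #     dropdown = CTkScrollableDropdownFrame(menu, values=["Base", "Small", "Medium"])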
"Korean", 26 | "French", 27 | "Japanese", 28 | "Portuguese", 29 | "Turkish", 30 | "Polish", 31 | "Catalan", 32 | "Dutch", 33 | "Arabic", 34 | "Swedish", 35 | "Italian", 36 | "Indonesian", 37 | "Hindi", 38 | "Finnish", 39 | "Vietnamese", 40 | "Hebrew", 41 | "Ukrainian", 42 | "Greek", 43 | "Malay", 44 | "Czech", 45 | "Romanian", 46 | "Danish", 47 | "Hungarian", 48 | "Tamil", 49 | "Norwegian", 50 | "Thai", 51 | "Urdu", 52 | "Croatian", 53 | "Bulgarian", 54 | "Lithuanian", 55 | "Latin", 56 | "Maori", 57 | "Malayalam", 58 | "Welsh", 59 | "Slovak", 60 | "Telugu", 61 | "Persian", 62 | "Latvian", 63 | "Bengali", 64 | "Serbian", 65 | "Azerbaijani", 66 | "Slovenian", 67 | "Kannada", 68 | "Estonian", 69 | "Macedonian", 70 | "Breton", 71 | "Basque", 72 | "Icelandic", 73 | "Armenian", 74 | "Nepali", 75 | "Mongolian", 76 | "Bosnian", 77 | "Kazakh", 78 | "Albanian", 79 | "Swahili", 80 | "Galician", 81 | "Marathi", 82 | "Punjabi", 83 | "Sinhala", 84 | "Khmer", 85 | "Shona", 86 | "Yoruba", 87 | "Somali", 88 | "Afrikaans", 89 | "Occitan", 90 | "Georgian", 91 | "Belarusian", 92 | "Tajik", 93 | "Sindhi", 94 | "Gujarati", 95 | "Amharic", 96 | "Yiddish", 97 | "Lao", 98 | "Uzbek", 99 | "Faroese", 100 | "Haitian creole", 101 | "Pashto", 102 | "Turkmen", 103 | "Nynorsk", 104 | "Maltese", 105 | "Sanskrit", 106 | "Luxembourgish", 107 | "Myanmar", 108 | "Tibetan", 109 | "Tagalog", 110 | "Malagasy", 111 | "Assamese", 112 | "Tatar", 113 | "Hawaiian", 114 | "Lingala", 115 | "Hausa", 116 | "Bashkir", 117 | "Javanese", 118 | "Sundanese", 119 | ] 120 | 121 | 122 | def help_link(): 123 | webbrowser.open("https://github.com/rudymohammadbali/Whisper-Transcriber/discussions/categories/q-a") 124 | 125 | 126 | class SettingsUI(ctk.CTkFrame): 127 | def __init__(self, parent): 128 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0) 129 | self.grid_propagate(False) 130 | self.grid_columnconfigure(0, weight=1) 131 | 132 | self.master = parent 133 | self.settings_handler = SettingsHandler() 134 | 135 | self.theme_btn = None 136 | self.color_theme_btn = None 137 | self.device_dropdown = None 138 | self.language_dropdown = None 139 | self.model_dropdown = None 140 | self.theme_dropdown = None 141 | self.color_theme_dropdown = None 142 | self.general_frame = None 143 | self.model_frame = None 144 | self.gpu_frame = None 145 | self.reset_btn = None 146 | self.mic_dropdown = None 147 | self.mic_btn = None 148 | 149 | self.loader = None 150 | self.process = None 151 | self.thread = None 152 | 153 | title = ctk.CTkLabel(self, text=" Settings", font=FONTS["title"], image=icons["settings"], compound="left") 154 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w") 155 | 156 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, 157 | width=30, 158 | height=30, command=self.hide_settings_ui) 159 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e") 160 | 161 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent") 162 | self.main_frame.grid(row=1, column=0, padx=20, pady=0, sticky="nsew", columnspan=2) 163 | self.main_frame.grid_columnconfigure(0, weight=1) 164 | 165 | self.default_widget() 166 | 167 | self.reset_btn = ctk.CTkButton(self, text="Reset", height=35, command=self.reset_callback, font=FONTS["normal"]) 168 | self.reset_btn.grid(row=3, column=0, padx=20, pady=0, sticky="w") 169 | 170 | self.grid(row=0, column=0, sticky="nsew") 171 | 172 | def default_widget(self): 173 | segmented_button = 
    def default_widget(self):
        segmented_button = ctk.CTkSegmentedButton(self.main_frame, values=["General", "Model Settings", "GPU Info"],
                                                  width=150, height=35, command=self.update_view, border_width=0,
                                                  font=FONTS["normal"])
        segmented_button.grid(row=0, column=0, padx=0, pady=20, sticky="nsew", columnspan=3)
        segmented_button.set("General")

        self.general_frame = ctk.CTkFrame(self.main_frame)
        self.general_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
        self.general_frame.grid_columnconfigure(0, weight=1)
        self.general_widget()

    def update_view(self, frame):
        try:
            self.general_frame.grid_forget()
            self.model_frame.grid_forget()
            self.gpu_frame.grid_forget()
        except AttributeError:
            pass

        if frame == "General":
            self.general_frame = ctk.CTkFrame(self.main_frame)
            self.general_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
            self.general_frame.grid_columnconfigure(0, weight=1)
            self.general_widget()

        elif frame == "Model Settings":
            self.model_frame = ctk.CTkFrame(self.main_frame)
            self.model_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
            self.model_frame.grid_columnconfigure(0, weight=1)
            self.model_frame.grid_columnconfigure(1, weight=1)
            self.model_widget()

        elif frame == "GPU Info":
            self.gpu_frame = ctk.CTkFrame(self.main_frame)
            self.gpu_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
            self.gpu_frame.grid_columnconfigure(0, weight=1)
            self.gpu_widget()

    def gpu_widget(self):
        get_info = GPUInfo()
        info = get_info.load_gpu_info()
        # A missing key now reads as "False" instead of the truthy "N/A" default
        cuda_available = "True" if info.get("cuda_available") else "False"

        label_1 = ctk.CTkLabel(self.gpu_frame, text="CUDA Available", font=FONTS["normal"])
        label_1.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
        label_1_value = ctk.CTkLabel(self.gpu_frame, text=cuda_available, font=FONTS["small"])
        label_1_value.grid(row=0, column=1, padx=20, pady=(20, 10), sticky="e")

        label_2 = ctk.CTkLabel(self.gpu_frame, text="GPU Count", font=FONTS["normal"])
        label_2.grid(row=1, column=0, padx=20, pady=10, sticky="w")
        label_2_value = ctk.CTkLabel(self.gpu_frame, text=info.get("gpu_count", "N/A"), font=FONTS["small"])
        label_2_value.grid(row=1, column=1, padx=20, pady=10, sticky="e")

        label_3 = ctk.CTkLabel(self.gpu_frame, text="Current GPU", font=FONTS["normal"])
        label_3.grid(row=2, column=0, padx=20, pady=10, sticky="w")
        label_3_value = ctk.CTkLabel(self.gpu_frame, text=info.get("current_gpu", "N/A"), font=FONTS["small"])
        label_3_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")

        label_4 = ctk.CTkLabel(self.gpu_frame, text="GPU Name", font=FONTS["normal"])
        label_4.grid(row=3, column=0, padx=20, pady=10, sticky="w")
        label_4_value = ctk.CTkLabel(self.gpu_frame, text=info.get("gpu_name", "N/A"), font=FONTS["small"])
        label_4_value.grid(row=3, column=1, padx=20, pady=10, sticky="e")

        label_5 = ctk.CTkLabel(self.gpu_frame, text="Total Memory", font=FONTS["normal"])
        label_5.grid(row=4, column=0, padx=20, pady=(10, 20), sticky="w")
        label_5_value = ctk.CTkLabel(self.gpu_frame, text=f"{info.get('total_memory', 'N/A')} GB", font=FONTS["small"])
        label_5_value.grid(row=4, column=1, padx=20, pady=(10, 20), sticky="e")
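    # model_widget() below pairs each CTkOptionMenu with a
    # CTkScrollableDropdownFrame (replacing the stock dropdown) and persists
    # every selection through change_settings_value().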
    def model_widget(self):
        settings = self.settings_handler.load_settings()

        model_size = settings.get("model_size", "Base").capitalize()
        language = settings.get("language", "Auto Detection").capitalize()
        device = settings.get("device", "CPU").upper()

        model_requirements = ModelRequirements()
        model_values, device_values = model_requirements.update_model_requirements()

        model_size_label = ctk.CTkLabel(self.model_frame, text="Model Size", font=FONTS["normal"])
        model_size_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
        model_size_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
        model_size_value.grid(row=0, column=1, padx=20, pady=10, sticky="e")
        self.model_dropdown = CTkScrollableDropdownFrame(model_size_value, values=model_values, **DROPDOWN,
                                                         command=lambda new_value: self.change_settings_value(
                                                             key_name="model_size", new_value=new_value))
        model_size_value.set(model_size)

        language_label = ctk.CTkLabel(self.model_frame, text="Language", font=FONTS["normal"])
        language_label.grid(row=1, column=0, padx=20, pady=10, sticky="w")
        language_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
        language_value.grid(row=1, column=1, padx=20, pady=10, sticky="e")
        self.language_dropdown = CTkScrollableDropdownFrame(language_value, values=languages, **DROPDOWN,
                                                            command=lambda new_value: self.change_settings_value(
                                                                key_name="language", new_value=new_value))
        language_value.set(language)

        device_label = ctk.CTkLabel(self.model_frame, text="Device", font=FONTS["normal"])
        device_label.grid(row=2, column=0, padx=20, pady=10, sticky="w")
        device_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
        device_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")
        self.device_dropdown = CTkScrollableDropdownFrame(device_value, values=device_values, **DROPDOWN,
                                                          command=lambda new_value: self.change_settings_value(
                                                              key_name="device", new_value=new_value))
        device_value.set(device)

        download_btn = ctk.CTkButton(self.model_frame, text="Download All Models", command=self.download_callback,
                                     height=50, font=FONTS["small"])
        download_btn.grid(row=3, column=1, padx=20, pady=(0, 20), sticky="e")

        btns_frame = ctk.CTkScrollableFrame(self.model_frame, fg_color="transparent", label_text="Installed Models",
                                            label_font=FONTS["normal"])
        btns_frame.grid(row=3, column=0, padx=20, pady=(0, 20), sticky="w")

        installed_models = []
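        # openai-whisper stores downloaded checkpoints (base.pt, small.pt, ...)
        # under ~/.cache/whisper by default; listing that folder and stripping
        # the extension yields the installed model names.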
sticky="nsew") 311 | 312 | def general_widget(self): 313 | settings = self.settings_handler.load_settings() 314 | theme = settings.get("theme", "System").capitalize() 315 | color_theme = settings.get("color_theme", "Blue").capitalize() 316 | 317 | theme_label = ctk.CTkLabel(self.general_frame, text="Theme", font=FONTS["normal"]) 318 | theme_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w") 319 | self.theme_btn = ctk.CTkOptionMenu(self.general_frame, font=FONTS["small"]) 320 | values = ["System", "Light", "Dark"] 321 | self.theme_btn.grid(row=0, column=1, padx=20, pady=(20, 10), sticky="e") 322 | self.theme_dropdown = CTkScrollableDropdownFrame(self.theme_btn, values=values, **DROPDOWN, 323 | command=self.change_theme) 324 | self.theme_btn.set(theme) 325 | 326 | color_theme_label = ctk.CTkLabel(self.general_frame, text="Color Theme", font=FONTS["normal"]) 327 | color_theme_label.grid(row=1, column=0, padx=20, pady=(20, 10), sticky="w") 328 | self.color_theme_btn = ctk.CTkOptionMenu(self.general_frame, font=FONTS["small"]) 329 | color_values = ["Blue", "Dark-Blue", "Green"] 330 | self.color_theme_btn.grid(row=1, column=1, padx=20, pady=(20, 10), sticky="e") 331 | self.color_theme_dropdown = CTkScrollableDropdownFrame(self.color_theme_btn, values=color_values, 332 | **DROPDOWN, 333 | command=self.change_color_theme) 334 | self.color_theme_btn.set(color_theme) 335 | 336 | developer_label = ctk.CTkLabel(self.general_frame, text="Developer", font=FONTS["normal"]) 337 | developer_label.grid(row=2, column=0, padx=20, pady=10, sticky="w") 338 | developer_value = ctk.CTkLabel(self.general_frame, text="@iamironman", font=FONTS["small"]) 339 | developer_value.grid(row=2, column=1, padx=20, pady=10, sticky="e") 340 | 341 | released_label = ctk.CTkLabel(self.general_frame, text="Released", font=FONTS["normal"]) 342 | released_label.grid(row=3, column=0, padx=20, pady=10, sticky="w") 343 | released_value = ctk.CTkLabel(self.general_frame, text="12/2/2023", font=FONTS["small"]) 344 | released_value.grid(row=3, column=1, padx=20, pady=10, sticky="e") 345 | 346 | help_label = ctk.CTkLabel(self.general_frame, text="FAQ or Help Center", font=FONTS["normal"]) 347 | help_label.grid(row=4, column=0, padx=20, pady=10, sticky="w") 348 | help_value = ctk.CTkButton(self.general_frame, text="Get help", font=FONTS["small"], command=help_link) 349 | help_value.grid(row=4, column=1, padx=20, pady=10, sticky="e") 350 | 351 | def download_callback(self): 352 | self.close_btn.configure(state="disabled") 353 | self.reset_btn.destroy() 354 | widgets = self.main_frame.winfo_children() 355 | 356 | for widget in widgets: 357 | widget.grid_forget() 358 | 359 | self.loader = CTkLoader(parent=self.master, title="Downloading Models", msg="Please wait...", 360 | cancel_func=self.cancel_downloads) 361 | 362 | self.thread = threading.Thread(target=self.download_thread) 363 | self.thread.start() 364 | 365 | def download_thread(self): 366 | interpreter_path = sys.executable 367 | script_path = f"{current_path}{os.path.sep}download_models.py" 368 | 369 | command = [interpreter_path, script_path] 370 | self.process = subprocess.Popen(command) 371 | 372 | return_code = self.process.wait() 373 | if return_code == 0: 374 | CTkAlert(parent=self.master, status="success", title="Download Complete", 375 | msg="All available models are installed successfully.") 376 | self.cancel_downloads() 377 | else: 378 | CTkAlert(parent=self.master, status="error", title="Download Failed", 379 | msg="Failed to download models. 
    def download_thread(self):
        interpreter_path = sys.executable
        script_path = os.path.join(current_path, "download_models.py")

        command = [interpreter_path, script_path]
        self.process = subprocess.Popen(command)

        return_code = self.process.wait()
        if return_code == 0:
            CTkAlert(parent=self.master, status="success", title="Download Complete",
                     msg="All available models were installed successfully.")
        else:
            CTkAlert(parent=self.master, status="error", title="Download Failed",
                     msg="Failed to download models. Please check your internet connection and try again.")

        self.cancel_downloads()

    def cancel_downloads(self):
        self.loader.hide_loader()
        self.process.terminate()
        self.close_btn.configure(state="normal")
        self.main_frame.grid(row=1, column=0, padx=20, pady=0, sticky="nsew", columnspan=2)
        self.after(1000, self.default_widget)

    def reset_callback(self):
        message = self.settings_handler.reset_settings()
        for widget in self.main_frame.winfo_children():
            widget.grid_forget()

        self.default_widget()

        CTkAlert(parent=self.master, status="success", title="Success", msg=message)

    def change_settings_value(self, key_name: str, new_value: str):
        self.settings_handler.save_settings(**{key_name: new_value.lower()})
        for widget in self.model_frame.winfo_children():
            widget.grid_forget()

        self.update_view(frame="Model Settings")

    def change_theme(self, theme):
        new_theme = str(theme).lower()
        ctk.set_appearance_mode(new_theme)
        for widget in self.main_frame.winfo_children():
            widget.grid_forget()

        self.settings_handler.save_settings(theme=new_theme)

        self.default_widget()

    def change_color_theme(self, color_theme):
        new_color_theme = str(color_theme).lower()
        ctk.set_default_color_theme(new_color_theme)

        for widget in self.main_frame.winfo_children():
            widget.grid_forget()

        self.settings_handler.save_settings(color_theme=new_color_theme)

        self.default_widget()

    def hide_settings_ui(self):
        self.destroy()
--------------------------------------------------------------------------------