├── src
│   ├── __init__.py
│   ├── whisper
│   │   ├── version.py
│   │   ├── __main__.py
│   │   ├── assets
│   │   │   └── mel_filters.npz
│   │   ├── normalizers
│   │   │   ├── __init__.py
│   │   │   └── basic.py
│   │   ├── triton_ops.py
│   │   ├── audio.py
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── utils.py
│   │   ├── tokenizer.py
│   │   └── timing.py
│   ├── ui
│   │   ├── ctkAlert
│   │   │   ├── __init__.py
│   │   │   ├── icons
│   │   │   │   ├── info_dark.png
│   │   │   │   ├── error_dark.png
│   │   │   │   ├── error_light.png
│   │   │   │   ├── info_light.png
│   │   │   │   ├── success_dark.png
│   │   │   │   ├── warning_dark.png
│   │   │   │   ├── success_light.png
│   │   │   │   └── warning_light.png
│   │   │   └── ctkalert.py
│   │   ├── ctkLoader
│   │   │   ├── __init__.py
│   │   │   └── ctkloader.py
│   │   ├── icons
│   │   │   ├── close_dark.png
│   │   │   ├── close_light.png
│   │   │   ├── settings_dark.png
│   │   │   ├── audio_file_dark.png
│   │   │   ├── audio_file_light.png
│   │   │   ├── microphone_dark.png
│   │   │   ├── microphone_light.png
│   │   │   ├── settings_light.png
│   │   │   ├── subtitle_light.png
│   │   │   ├── subtitles_dark.png
│   │   │   ├── translate_light.png
│   │   │   └── translation_dark.png
│   │   ├── download_models.py
│   │   ├── style.py
│   │   ├── icons.py
│   │   ├── ctk_tooltip.py
│   │   ├── add_subtitles.py
│   │   ├── translate.py
│   │   ├── live_transcribe.py
│   │   ├── transcribe.py
│   │   ├── ctkdropdown.py
│   │   └── settings.py
│   └── logic
│       ├── __init__.py
│       ├── settings.py
│       ├── transcriber.py
│       ├── gpu_details.py
│       ├── model_requirements.py
│       └── live_transcriber.py
├── demo
│   ├── 1.PNG
│   ├── 2.PNG
│   ├── 3.PNG
│   ├── 4.PNG
│   ├── 5.PNG
│   ├── 6.PNG
│   ├── 7.PNG
│   ├── 8.PNG
│   ├── 9.PNG
│   ├── 10.PNG
│   └── 11.PNG
├── requirements.txt
├── assets
│   └── icons
│       ├── logo.ico
│       ├── logo.png
│       ├── close_dark.png
│       ├── close_light.png
│       ├── github_dark.png
│       ├── help_dark.png
│       ├── help_light.png
│       ├── paypal_dark.png
│       ├── github_light.png
│       ├── paypal_light.png
│       ├── settings_dark.png
│       ├── audio_file_dark.png
│       ├── audio_file_light.png
│       ├── microphone_dark.png
│       ├── microphone_light.png
│       ├── settings_light.png
│       ├── subtitles_dark.png
│       ├── subtitles_light.png
│       ├── translation_dark.png
│       └── translation_light.png
├── README.md
├── setup.py
└── main.py
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/whisper/version.py:
--------------------------------------------------------------------------------
1 | __version__ = "20231117"
2 |
--------------------------------------------------------------------------------
/src/ui/ctkAlert/__init__.py:
--------------------------------------------------------------------------------
1 | from .ctkalert import CTkAlert
2 |
--------------------------------------------------------------------------------
/src/ui/ctkLoader/__init__.py:
--------------------------------------------------------------------------------
1 | from .ctkloader import CTkLoader
2 |
--------------------------------------------------------------------------------
/src/whisper/__main__.py:
--------------------------------------------------------------------------------
1 | from .transcribe import cli
2 |
3 | cli()
4 |
--------------------------------------------------------------------------------
/demo/1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/1.PNG
--------------------------------------------------------------------------------
/demo/2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/2.PNG
--------------------------------------------------------------------------------
/demo/3.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/3.PNG
--------------------------------------------------------------------------------
/demo/4.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/4.PNG
--------------------------------------------------------------------------------
/demo/5.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/5.PNG
--------------------------------------------------------------------------------
/demo/6.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/6.PNG
--------------------------------------------------------------------------------
/demo/7.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/7.PNG
--------------------------------------------------------------------------------
/demo/8.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/8.PNG
--------------------------------------------------------------------------------
/demo/9.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/9.PNG
--------------------------------------------------------------------------------
/demo/10.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/10.PNG
--------------------------------------------------------------------------------
/demo/11.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/demo/11.PNG
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/requirements.txt
--------------------------------------------------------------------------------
/assets/icons/logo.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/logo.ico
--------------------------------------------------------------------------------
/assets/icons/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/logo.png
--------------------------------------------------------------------------------
/assets/icons/close_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/close_dark.png
--------------------------------------------------------------------------------
/assets/icons/close_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/close_light.png
--------------------------------------------------------------------------------
/assets/icons/github_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/github_dark.png
--------------------------------------------------------------------------------
/assets/icons/help_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/help_dark.png
--------------------------------------------------------------------------------
/assets/icons/help_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/help_light.png
--------------------------------------------------------------------------------
/assets/icons/paypal_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/paypal_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/close_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/close_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/close_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/close_light.png
--------------------------------------------------------------------------------
/assets/icons/github_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/github_light.png
--------------------------------------------------------------------------------
/assets/icons/paypal_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/paypal_light.png
--------------------------------------------------------------------------------
/assets/icons/settings_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/settings_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/settings_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/settings_dark.png
--------------------------------------------------------------------------------
/assets/icons/audio_file_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/audio_file_dark.png
--------------------------------------------------------------------------------
/assets/icons/audio_file_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/audio_file_light.png
--------------------------------------------------------------------------------
/assets/icons/microphone_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/microphone_dark.png
--------------------------------------------------------------------------------
/assets/icons/microphone_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/microphone_light.png
--------------------------------------------------------------------------------
/assets/icons/settings_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/settings_light.png
--------------------------------------------------------------------------------
/assets/icons/subtitles_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/subtitles_dark.png
--------------------------------------------------------------------------------
/assets/icons/subtitles_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/subtitles_light.png
--------------------------------------------------------------------------------
/assets/icons/translation_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/translation_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/audio_file_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/audio_file_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/audio_file_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/audio_file_light.png
--------------------------------------------------------------------------------
/src/ui/icons/microphone_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/microphone_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/microphone_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/microphone_light.png
--------------------------------------------------------------------------------
/src/ui/icons/settings_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/settings_light.png
--------------------------------------------------------------------------------
/src/ui/icons/subtitle_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/subtitle_light.png
--------------------------------------------------------------------------------
/src/ui/icons/subtitles_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/subtitles_dark.png
--------------------------------------------------------------------------------
/src/ui/icons/translate_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/translate_light.png
--------------------------------------------------------------------------------
/src/ui/icons/translation_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/icons/translation_dark.png
--------------------------------------------------------------------------------
/assets/icons/translation_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/assets/icons/translation_light.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/info_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/info_dark.png
--------------------------------------------------------------------------------
/src/whisper/assets/mel_filters.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/whisper/assets/mel_filters.npz
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/error_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/error_dark.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/error_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/error_light.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/info_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/info_light.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/success_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/success_dark.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/warning_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/warning_dark.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/success_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/success_light.png
--------------------------------------------------------------------------------
/src/ui/ctkAlert/icons/warning_light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rudymohammadbali/Whisper-Transcriber/HEAD/src/ui/ctkAlert/icons/warning_light.png
--------------------------------------------------------------------------------
/src/whisper/normalizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import BasicTextNormalizer as BasicTextNormalizer
2 | from .english import EnglishTextNormalizer as EnglishTextNormalizer
3 |
--------------------------------------------------------------------------------
/src/logic/__init__.py:
--------------------------------------------------------------------------------
1 | from .gpu_details import GPUInfo
2 | from .settings import SettingsHandler
3 | from .model_requirements import ModelRequirements
4 | from .transcriber import Transcriber
5 |
--------------------------------------------------------------------------------
/src/ui/download_models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .. import whisper
4 |
5 |
6 | def download_model():
7 | models = whisper.available_models()
8 | cache_folder = os.path.join(os.path.expanduser('~'), f'.cache{os.path.sep}whisper{os.path.sep}')
9 |
10 | for model in models:
11 | print(f"Downloading {model}")
12 | whisper._download(url=whisper._MODELS[model], root=cache_folder, in_memory=False)
13 |
14 | if __name__ == "__main__":
15 |     download_model()
16 |
--------------------------------------------------------------------------------
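
As written, running this script downloads every available checkpoint (several gigabytes) through `whisper._download` and `whisper._MODELS`, which are private helpers of the vendored whisper package. If only a single model is needed, the public API fetches it on demand into the same `~/.cache/whisper` folder by default; a minimal sketch, assuming the repository root as the working directory:

```
from src import whisper

# load_model() downloads the checkpoint into ~/.cache/whisper when it is
# not already cached, so the private _download helper is not required
model = whisper.load_model(name="base", device="cpu")
```
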
/src/ui/style.py:
--------------------------------------------------------------------------------
1 | FONTS = {
2 | "title_bold": ("Inter", 24, "bold"),
3 | "title": ("Inter", 22, "normal"),
4 | "subtitle_bold": ("Inter", 18, "bold"),
5 | "subtitle": ("Inter", 18, "normal"),
6 | "normal_bold": ("Inter", 15, "bold"),
7 | "normal": ("Inter", 15, "normal"),
8 | "small": ("Inter", 13, "normal"),
9 | }
10 |
11 | DROPDOWN = {
12 | "corner_radius": 2,
13 | "alpha": 1.0,
14 | "frame_corner_radius": 5,
15 | "x": 0,
16 | "hover": False,
17 | "justify": "left"
18 | }
19 |
--------------------------------------------------------------------------------
/src/ui/icons.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import customtkinter as ctk
4 | from PIL import Image
5 |
6 | current_path = os.path.dirname(os.path.realpath(__file__))
7 | icon_path = f"{current_path}{os.path.sep}icons{os.path.sep}"
8 |
9 | icons = {
10 | "close": ctk.CTkImage(dark_image=Image.open(f"{icon_path}close_dark.png"),
11 | light_image=Image.open(f"{icon_path}close_light.png"), size=(30, 30)),
12 | "audio_file": ctk.CTkImage(dark_image=Image.open(f"{icon_path}audio_file_dark.png"),
13 | light_image=Image.open(f"{icon_path}audio_file_light.png"), size=(30, 30)),
14 | "translate": ctk.CTkImage(dark_image=Image.open(f"{icon_path}translation_dark.png"),
15 | light_image=Image.open(f"{icon_path}translate_light.png"), size=(30, 30)),
16 | "microphone": ctk.CTkImage(dark_image=Image.open(f"{icon_path}microphone_dark.png"),
17 | light_image=Image.open(f"{icon_path}microphone_light.png"), size=(30, 30)),
18 | "subtitle": ctk.CTkImage(dark_image=Image.open(f"{icon_path}subtitles_dark.png"),
19 | light_image=Image.open(f"{icon_path}subtitle_light.png"), size=(30, 30)),
20 | "settings": ctk.CTkImage(dark_image=Image.open(f"{icon_path}settings_dark.png"),
21 | light_image=Image.open(f"{icon_path}settings_light.png"), size=(30, 30))
22 | }
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Whisper Transcriber GUI
2 |
3 | Modern desktop application offering a suite of tools for audio/video text recognition and a variety of other useful utilities.
4 |
5 | ### Demo
6 |
7 | See the demo folder for the rest of the screenshots.
8 |
9 | ### Installation
10 |
11 | ### Setup
12 |
13 | ```
14 | git clone https://github.com/rudymohammadbali/Whisper-Transcriber.git
15 | ```
16 | ```
17 | cd Whisper-Transcriber
18 | ```
19 | ```
20 | python setup.py
21 | ```
22 | Finally, run it:
23 | ```
24 | python main.py
25 | ```
26 |
27 | Special thanks to @Akascape for bringing this app to life.
28 |
29 | ### Support
30 |
31 | If you'd like to support my ongoing efforts in sharing fantastic open-source projects, you can contribute by making a donation via PayPal.
--------------------------------------------------------------------------------
/src/logic/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 |
5 | DEFAULT_CONFIG = {
6 | "theme": "system",
7 | "color_theme": "blue",
8 | "model_size": "base",
9 | "language": "auto detection",
10 | "device": "cpu"
11 | }
12 |
13 |
14 | class SettingsHandler:
15 | def __init__(self):
16 | current_dir = os.path.dirname(__file__)
17 | file_path = os.path.join(current_dir, '..', '..', f'assets{os.path.sep}config')
18 | self.config_path = os.path.normpath(file_path)
19 | self.settings_file_path = os.path.join(self.config_path, "settings.json")
20 |
21 |     def save_settings(self, **new_settings):
22 | try:
23 | existing_settings = self.load_settings()
24 | existing_settings.update(new_settings)
25 |             # create the config directory (and any missing parents)
26 |             os.makedirs(self.config_path, exist_ok=True)
27 |
28 | with open(self.settings_file_path, 'w') as file:
29 | json.dump(existing_settings, file)
30 | return "Your settings have been saved successfully."
31 | except FileNotFoundError:
32 | pass
33 | except PermissionError:
34 | pass
35 |
36 | def load_settings(self):
37 | try:
38 | if os.path.exists(self.settings_file_path) and os.path.getsize(self.settings_file_path) > 0:
39 | with open(self.settings_file_path, "r") as file:
40 | loaded_settings = json.load(file)
41 | return loaded_settings
42 |             else:
43 |                 os.makedirs(self.config_path, exist_ok=True)
44 |                 with open(self.settings_file_path, 'w') as file: json.dump(DEFAULT_CONFIG, file)
45 | return DEFAULT_CONFIG
46 | except FileNotFoundError:
47 | return DEFAULT_CONFIG
48 |
49 | def reset_settings(self):
50 | try:
51 |             # create the config directory (and any missing parents)
52 |             os.makedirs(self.config_path, exist_ok=True)
53 |
54 | with open(self.settings_file_path, 'w') as file:
55 | json.dump(DEFAULT_CONFIG, file)
56 | return "Your settings have been reset to default values."
57 | except FileNotFoundError:
58 | pass
59 | except PermissionError:
60 | pass
61 |
--------------------------------------------------------------------------------
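
A minimal usage sketch of `SettingsHandler` (the `assets/config` location is resolved relative to the module file, so this works from any working directory):

```
from src.logic.settings import SettingsHandler

handler = SettingsHandler()
print(handler.load_settings())                      # first call writes settings.json with DEFAULT_CONFIG
handler.save_settings(theme="dark", device="gpu")   # merges the new keys into the stored settings
handler.reset_settings()                            # rewrites the file with DEFAULT_CONFIG
```
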
/src/whisper/normalizers/basic.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unicodedata
3 |
4 | import regex
5 |
6 | # non-ASCII letters that are not separated by "NFKD" normalization
7 | ADDITIONAL_DIACRITICS = {
8 | "œ": "oe",
9 | "Œ": "OE",
10 | "ø": "o",
11 | "Ø": "O",
12 | "æ": "ae",
13 | "Æ": "AE",
14 | "ß": "ss",
15 | "ẞ": "SS",
16 | "đ": "d",
17 | "Đ": "D",
18 | "ð": "d",
19 | "Ð": "D",
20 | "þ": "th",
21 | "Þ": "th",
22 | "ł": "l",
23 | "Ł": "L",
24 | }
25 |
26 |
27 | def remove_symbols_and_diacritics(s: str, keep=""):
28 | """
29 | Replace any other markers, symbols, and punctuations with a space,
30 | and drop any diacritics (category 'Mn' and some manual mappings)
31 | """
32 | return "".join(
33 | c
34 | if c in keep
35 | else ADDITIONAL_DIACRITICS[c]
36 | if c in ADDITIONAL_DIACRITICS
37 | else ""
38 | if unicodedata.category(c) == "Mn"
39 | else " "
40 | if unicodedata.category(c)[0] in "MSP"
41 | else c
42 | for c in unicodedata.normalize("NFKD", s)
43 | )
44 |
45 |
46 | def remove_symbols(s: str):
47 | """
48 | Replace any other markers, symbols, punctuations with a space, keeping diacritics
49 | """
50 | return "".join(
51 | " " if unicodedata.category(c)[0] in "MSP" else c
52 | for c in unicodedata.normalize("NFKC", s)
53 | )
54 |
55 |
56 | class BasicTextNormalizer:
57 | def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58 | self.clean = (
59 | remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60 | )
61 | self.split_letters = split_letters
62 |
63 | def __call__(self, s: str):
64 | s = s.lower()
65 | s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66 |         s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
67 | s = self.clean(s).lower()
68 |
69 | if self.split_letters:
70 | s = " ".join(regex.findall(r"\X", s, regex.U))
71 |
72 | s = re.sub(
73 | r"\s+", " ", s
74 | ) # replace any successive whitespace characters with a space
75 |
76 | return s
77 |
--------------------------------------------------------------------------------
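
The normalizer lower-cases, strips bracketed and parenthesized spans, replaces symbols and punctuation with spaces, and collapses whitespace; with `remove_diacritics=True` it additionally folds accented characters using the table above. A quick sketch:

```
from src.whisper.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer()
print(normalizer("Hello, [inaudible] World! (laughs)").strip())  # -> "hello world"

folding = BasicTextNormalizer(remove_diacritics=True)
print(folding("Œuvre déjà vu"))  # -> "oeuvre deja vu"
```
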
/src/logic/transcriber.py:
--------------------------------------------------------------------------------
1 | from .. import whisper
2 |
3 | from .settings import SettingsHandler
4 | from deep_translator import GoogleTranslator
5 |
6 |
7 | class Transcriber:
8 | def __init__(self, audio: str):
9 | self.audio = audio
10 | get_config = SettingsHandler()
11 | config = get_config.load_settings()
12 |
13 | model = config.get("model_size", "base").lower()
14 |         self.language = config.get("language", "auto detection").lower()  # default must match the check in audio_recognition
15 | device = config.get("device", "cpu").lower()
16 | self.fp16 = False
17 |
18 |         if self.language == "english" and model not in ["large", "large-v1", "large-v2", "large-v3"]:
19 |             model += ".en"  # the large checkpoints have no English-only variant
20 |
21 | if device == "gpu":
22 | device = "cuda"
23 | self.fp16 = True
24 |
25 | self.load_model = whisper.load_model(name=model, device=device)
26 |
27 | self.options = {
28 | "task": "transcribe",
29 | "fp16": self.fp16
30 | }
31 |
32 |     def audio_recognition(self, cancel_func=lambda: False):  # default is a no-op cancel check
33 | print("Audio Recognition started...")
34 | if self.language == "auto detection":
35 | audio = whisper.load_audio(self.audio)
36 | audio = whisper.pad_or_trim(audio)
37 |
38 | mel = whisper.log_mel_spectrogram(audio).to(self.load_model.device)
39 |
40 | _, probs = self.load_model.detect_language(mel)
41 | detected_language = max(probs, key=probs.get)
42 | result = whisper.transcribe(model=self.load_model, audio=self.audio,
43 | language=detected_language, **self.options, cancel_func=cancel_func)
44 | else:
45 | result = whisper.transcribe(model=self.load_model, audio=self.audio,
46 | **self.options, language=self.language, cancel_func=cancel_func)
47 |
48 | return result
49 |
50 | def translate_audio(self, cancel_func, to_language):
51 | print("Audio Translate started...")
52 | result = self.audio_recognition(cancel_func=cancel_func)
53 | text = str(result["text"]).strip()
54 | language = to_language
55 | translated_text = GoogleTranslator(source='auto', target=language).translate(text=text)
56 | # langs_list = GoogleTranslator().get_supported_languages()
57 | return translated_text
58 |
--------------------------------------------------------------------------------
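
A minimal sketch of driving `Transcriber` outside the UI (the file name is a placeholder, the `cancel_func` here is a stand-in that never cancels, and model/language/device come from the saved settings):

```
from src.logic.transcriber import Transcriber

transcriber = Transcriber(audio="sample.wav")  # placeholder input file
result = transcriber.audio_recognition(cancel_func=lambda: False)
print(result["text"])

# recognize, then translate the text with deep-translator
print(transcriber.translate_audio(cancel_func=lambda: False, to_language="german"))
```
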
/src/ui/ctkLoader/ctkloader.py:
--------------------------------------------------------------------------------
1 | import customtkinter as ctk
2 |
3 | CTkFrame = {
4 | "width": 500,
5 | "height": 200,
6 | "corner_radius": 5,
7 | "border_width": 1,
8 | "fg_color": ["#F7F8FA", "#2B2D30"],
9 | "border_color": ["#D3D5DB", "#4E5157"]
10 | }
11 |
12 | CTkButton = {
13 | "width": 150,
14 | "height": 40,
15 | "corner_radius": 5,
16 | "border_width": 1,
17 | "fg_color": ["#3574F0", "#2B2D30"],
18 | "hover_color": ["#3369D6", "#4E5157"],
19 | "border_color": ["#C2D6FC", "#4E5157"],
20 | "text_color": ["#FFFFFF", "#DFE1E5"],
21 | "text_color_disabled": ["#A8ADBD", "#6F737A"]
22 | }
23 |
24 | CTkProgressBar = {
25 | "width": 460,
26 | "height": 5,
27 | "corner_radius": 5,
28 | "border_width": 0,
29 | "fg_color": ["#DFE1E5", "#43454A"],
30 | "progress_color": ["#3574F0", "#3574F0"],
31 | "border_color": ["#3574F0", "#4E5157"]
32 | }
33 |
34 | font_title = ("Inter", 20, "bold")
35 | font_normal = ("Inter", 14, "normal")
36 |
37 |
38 | class CTkLoader(ctk.CTkFrame):
39 | def __init__(self, parent: any, title: str, msg: str, cancel_func: any):
40 | super().__init__(master=parent, **CTkFrame)
41 | self.title = title.capitalize()
42 | self.msg = msg.capitalize()
43 | self.grid_propagate(False)
44 | self.grid_columnconfigure(0, weight=1)
45 |
46 | self.cancel_func = cancel_func
47 |
48 | title = ctk.CTkLabel(self, text=self.title, text_color=("#000000", "#DFE1E5"), font=font_title)
49 | title.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="nw")
50 |
51 | msg = ctk.CTkLabel(self, text=self.msg, text_color=("#000000", "#DFE1E5"), font=font_normal)
52 | msg.grid(row=1, column=0, padx=20, pady=10, sticky="nsw")
53 |
54 | self.progressbar = ctk.CTkProgressBar(self, mode="indeterminate", **CTkProgressBar)
55 | self.progressbar.grid(row=2, column=0, padx=20, pady=10, sticky="sw")
56 | self.progressbar.start()
57 |
58 | button = ctk.CTkButton(self, text="Cancel", command=self.cancel_callback, **CTkButton)
59 | button.grid(row=3, column=0, padx=20, pady=(10, 20), sticky="nsew")
60 |
61 | # self.pack(anchor="center", expand=True)
62 | self.grid(row=0, column=0, sticky="ew", padx=20)
63 |
64 | def hide_loader(self):
65 | self.progressbar.stop()
66 | self.destroy()
67 |
68 | def cancel_callback(self):
69 | self.cancel_func()
70 | self.after(100, self.hide_loader)
71 |
--------------------------------------------------------------------------------
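
`CTkLoader` grids itself into its parent on construction; the caller keeps a reference and removes it with `hide_loader()` once the background work is done. A minimal sketch with a stand-in cancel callback:

```
import customtkinter as ctk

from src.ui.ctkLoader import CTkLoader

app = ctk.CTk()
app.grid_columnconfigure(0, weight=1)

loader = CTkLoader(parent=app, title="transcribing", msg="this may take a while",
                   cancel_func=lambda: print("cancel requested"))
app.after(3000, loader.hide_loader)  # simulate the work finishing after 3 seconds
app.mainloop()
```
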
/src/logic/gpu_details.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import GPUtil
5 | import torch
6 |
7 |
8 | class GPUInfo:
9 | def __init__(self):
10 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 | self.cuda_available = torch.cuda.is_available()
12 | current_dir = os.path.dirname(__file__)
13 | file_path = os.path.join(current_dir, '..', '..', f"assets{os.path.sep}config")
14 | self.config_path = os.path.normpath(file_path)
15 | self.gpu_info_file_path = os.path.join(self.config_path, "gpu_info.json")
16 |
17 | def get_gpu_info(self):
18 | if self.cuda_available:
19 | gpu_count = torch.cuda.device_count()
20 | current_device = torch.cuda.current_device()
21 | gpu_name = torch.cuda.get_device_name(current_device)
22 | total_memory = round(torch.cuda.get_device_properties(current_device).total_memory / (1024 ** 3))
23 | else:
24 | gpus = GPUtil.getGPUs()
25 | gpu_count = len(gpus)
26 |
27 | if gpu_count == 0:
28 | current_device = "No GPU found."
29 | gpu_name = "N/A"
30 | total_memory = 0
31 | else:
32 | current_gpu = gpus[0]
33 | current_device = f"{current_gpu.name} ({current_gpu.id})"
34 | gpu_name = current_gpu.name
35 | total_memory = current_gpu.memoryTotal / 1024
36 |
37 | return {"cuda_available": self.cuda_available,
38 | "gpu_count" : gpu_count,
39 | "current_gpu" : current_device,
40 | "gpu_name" : gpu_name,
41 | "total_memory" : total_memory}
42 |
43 | def save_gpu_info(self):
44 | settings = self.get_gpu_info()
45 |
46 | try:
47 | if not os.path.exists(self.config_path):
48 | os.mkdir(self.config_path)
49 |
50 | with open(self.gpu_info_file_path, 'w') as file: json.dump(settings, file)
51 |         except (FileNotFoundError, PermissionError):
52 | pass
53 |
54 | def load_gpu_info(self):
55 | if not os.path.exists(self.gpu_info_file_path) or os.path.getsize(self.gpu_info_file_path) == 0:
56 | self.save_gpu_info()
57 | try:
58 | with open(self.gpu_info_file_path, 'r') as file: loaded_settings = json.load(file)
59 | except FileNotFoundError:
60 | self.save_gpu_info()
61 | with open(self.gpu_info_file_path, 'r') as file: loaded_settings = json.load(file)
62 | return loaded_settings
63 |
--------------------------------------------------------------------------------
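
`load_gpu_info()` computes the details once and caches them in `assets/config/gpu_info.json`; subsequent calls read the file. A quick sketch:

```
from src.logic.gpu_details import GPUInfo

info = GPUInfo().load_gpu_info()
print(info["cuda_available"], info["gpu_name"], f"{info['total_memory']} GB")
```
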
/src/logic/model_requirements.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from .gpu_details import GPUInfo
5 |
6 | MODEL_SETTINGS = {
7 | "models": ["Tiny", "Base", "Small", "Medium", "Large", "Large-v1", "Large-v2", "Large-v3"],
8 | "device": ["GPU", "CPU"]
9 | }
10 |
11 |
12 | class ModelRequirements:
13 | def __init__(self):
14 | current_dir = os.path.dirname(__file__)
15 | file_path = os.path.join(current_dir, '..', '..', f'assets{os.path.sep}config')
16 | config_path = os.path.normpath(file_path)
17 | self.filename = os.path.join(config_path, "recommended_models.json")
18 |
19 | @staticmethod
20 | def model_requirements(available_memory, required_memory):
21 | recommended_models = [model for model, memory in required_memory.items() if available_memory >= memory]
22 | return recommended_models
23 |
24 | def read_json_file(self):
25 | try:
26 | with open(self.filename, 'r') as json_file:
27 | data = json.load(json_file)
28 | return data
29 | except FileNotFoundError:
30 | with open(self.filename, 'w') as json_file:
31 | json.dump(MODEL_SETTINGS, json_file, indent=2)
32 | return MODEL_SETTINGS
33 |
34 | def write_json_file(self, data):
35 | with open(self.filename, 'w') as json_file:
36 | json.dump(data, json_file, indent=2)
37 |
38 | def update_model_requirements(self):
39 | gpu_monitor = GPUInfo()
40 | gpu_info = gpu_monitor.load_gpu_info()
41 | available_memory = gpu_info["total_memory"]
42 | cuda = gpu_info["cuda_available"]
43 |
44 |         required_memory = {  # approximate VRAM required per model, in GB
45 | "Tiny": 1,
46 | "Base": 1,
47 | "Small": 2,
48 | "Medium": 5,
49 | "Large": 10,
50 | "Large-v1": 10,
51 | "Large-v2": 10,
52 | "Large-v3": 10
53 | }
54 |
55 | json_data = self.read_json_file()
56 |
57 | if json_data:
58 | if cuda:
59 | device = ["CPU", "GPU"]
60 | recommended_models = self.model_requirements(available_memory, required_memory)
61 | json_data["models"] = recommended_models
62 | json_data["device"] = device
63 | else:
64 | recommended_models = ["Tiny", "Base", "Small", "Medium", "Large", "Large-v1", "Large-v2", "Large-v3"]
65 | device = ["CPU"]
66 | json_data["models"] = recommended_models
67 | json_data["device"] = device
68 |
69 | self.write_json_file(json_data)
70 | return recommended_models, device
71 |
--------------------------------------------------------------------------------
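
The gate in `model_requirements` is a plain dictionary filter: every model whose memory requirement fits into the available VRAM is recommended. For example, with the table above an 8 GB GPU yields:

```
from src.logic.model_requirements import ModelRequirements

required_memory = {"Tiny": 1, "Base": 1, "Small": 2, "Medium": 5, "Large": 10}
print(ModelRequirements.model_requirements(8, required_memory))
# -> ['Tiny', 'Base', 'Small', 'Medium']
```
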
/src/logic/live_transcriber.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import threading
4 | import queue
5 |
6 | import sounddevice as sd
7 | from .. import whisper
8 | from scipy.io.wavfile import write
9 |
10 |
11 | class LiveTranscriber:
12 | def __init__(self):
13 | self.audio_file_name = "temp.wav"
14 | self.model = whisper.load_model("base")
15 | self.language = "english"
16 | self.transcription_queue = queue.Queue()
17 | self.recording_complete = threading.Event()
18 | self.stop_recording = threading.Event()
19 |
20 | def start_transcription_thread(self):
21 | transcription_thread = threading.Thread(target=self.transcribe_loop)
22 | transcription_thread.start()
23 |
24 |     def transcribe_loop(self):
25 |         while not self.stop_recording.is_set():
26 |             self.recording_complete.wait()
27 |             if self.stop_recording.is_set(): break  # woken only so the thread can exit
28 |             self.transcription_queue.put(self.transcribe())
29 |             self.recording_complete.clear()
30 |
31 | def start_recording_thread(self):
32 | recording_thread = threading.Thread(target=self.record_loop)
33 | recording_thread.start()
34 |
35 | def record_loop(self):
36 | while not self.stop_recording.is_set():
37 | self.recorder()
38 | self.recording_complete.set()
39 | self.stop_recording.wait(2)
40 |
41 | def start_live_transcription(self):
42 | self.start_recording_thread()
43 | self.start_transcription_thread()
44 |
45 |     def stop_live_transcription(self):
46 |         self.stop_recording.set()
47 |         self.recording_complete.set()  # unblock transcribe_loop before deleting the temp file
48 |         if os.path.exists(self.audio_file_name): os.remove(self.audio_file_name)
49 |
50 | def transcribe(self):
51 | print("Transcribing...")
52 |
53 | result = self.model.transcribe(audio=self.audio_file_name, cancel_func=self.cancel_callback)
54 | cleaned_result = self.clean_transcription(result['text'])
55 | print(f"Result: {cleaned_result}")
56 |
57 | os.remove(self.audio_file_name)
58 |
59 | return cleaned_result
60 |
61 | def recorder(self, duration=5):
62 | print("Say something...")
63 | sample_rate = 44100
64 | recording = sd.rec(int(duration * sample_rate), samplerate=sample_rate, channels=2)
65 | sd.wait()
66 |
67 | write(self.audio_file_name, sample_rate, recording)
68 |
69 | print(f"Recording saved as {self.audio_file_name}")
70 |
71 | def cancel_callback(self):
72 | pass
73 |
74 | @staticmethod
75 | def clean_transcription(transcription):
76 | cleaned_text = re.sub(r'[^a-zA-Z\s]', '', transcription).strip()
77 | return cleaned_text
78 |
79 |
80 | # transcriber = LiveTranscriber()
81 | # transcriber.start_live_transcription()
82 | # transcriber.stop_live_transcription()
83 |
--------------------------------------------------------------------------------
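
The commented-out lines above start and immediately stop the transcriber; a real caller polls `transcription_queue` between the two calls. A sketch of a consumer loop, assuming the 5-second record/transcribe cadence above:

```
import queue

from src.logic.live_transcriber import LiveTranscriber

transcriber = LiveTranscriber()
transcriber.start_live_transcription()

try:
    while True:
        try:
            # each item is the cleaned text of one ~5-second recording
            print(transcriber.transcription_queue.get(timeout=1))
        except queue.Empty:
            continue
except KeyboardInterrupt:
    transcriber.stop_live_transcription()
```
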
/src/ui/ctkAlert/ctkalert.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import customtkinter as ctk
4 | from PIL import Image
5 |
6 | current_path = os.path.dirname(os.path.realpath(__file__))
7 | icon_path = f"{current_path}{os.path.sep}icons{os.path.sep}"
8 |
9 | icons = {
10 | "info": ctk.CTkImage(dark_image=Image.open(f"{icon_path}info_dark.png"),
11 | light_image=Image.open(f"{icon_path}info_light.png"), size=(28, 28)),
12 | "success": ctk.CTkImage(dark_image=Image.open(f"{icon_path}success_dark.png"),
13 | light_image=Image.open(f"{icon_path}success_light.png"), size=(28, 28)),
14 | "error": ctk.CTkImage(dark_image=Image.open(f"{icon_path}error_dark.png"),
15 | light_image=Image.open(f"{icon_path}error_light.png"), size=(28, 28)),
16 | "warning": ctk.CTkImage(dark_image=Image.open(f"{icon_path}warning_dark.png"),
17 | light_image=Image.open(f"{icon_path}warning_light.png"), size=(28, 28))
18 | }
19 |
20 | CTkFrame = {
21 | "width": 500,
22 | "height": 200,
23 | "corner_radius": 5,
24 | "border_width": 1,
25 | "fg_color": ["#F7F8FA", "#2B2D30"],
26 | "border_color": ["#D3D5DB", "#4E5157"]
27 | }
28 |
29 | CTkButton = {
30 | "width": 100,
31 | "height": 30,
32 | "corner_radius": 5,
33 | "border_width": 0,
34 | "fg_color": ["#3574F0", "#3574F0"],
35 | "hover_color": ["#3369D6", "#366ACF"],
36 | "border_color": ["#C2D6FC", "#4E5157"],
37 | "text_color": ["#FFFFFF", "#DFE1E5"],
38 | "text_color_disabled": ["#A8ADBD", "#6F737A"]
39 | }
40 |
41 | font_title = ("Inter", 20, "bold")
42 | font_normal = ("Inter", 13, "normal")
43 |
44 |
45 | class CTkAlert(ctk.CTkFrame):
46 | def __init__(self, parent: any, status: str, title: str, msg: str):
47 | super().__init__(master=parent, **CTkFrame)
48 | self.status = status
49 | self.title = title.capitalize()
50 | self.msg = msg.capitalize()
51 | self.grid_propagate(False)
52 | self.grid_columnconfigure(0, weight=1)
53 |
54 | title = ctk.CTkLabel(self, text=f" {self.title}", image=icons.get(self.status, icons["info"]), compound="left",
55 | text_color=("#000000", "#DFE1E5"), font=font_title)
56 | title.grid(row=0, column=0, padx=20, pady=20, sticky="nw")
57 |
58 | msg = ctk.CTkLabel(self, text=self.msg, text_color=("#000000", "#DFE1E5"), font=font_normal)
59 | msg.grid(row=1, column=0, padx=40, pady=20, sticky="nsew")
60 |
61 | button = ctk.CTkButton(self, text="OK", command=self.hide_alert, **CTkButton)
62 | button.grid(row=2, column=0, padx=20, pady=20, sticky="se")
63 |
64 | # self.pack(anchor="center", expand=True)
65 | self.grid(row=0, column=0, sticky="ew", padx=10)
66 |
67 | def hide_alert(self):
68 | self.destroy()
69 |
--------------------------------------------------------------------------------
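
Like the loader, `CTkAlert` grids itself on construction and is dismissed by its own OK button. A minimal sketch:

```
import customtkinter as ctk

from src.ui.ctkAlert import CTkAlert

app = ctk.CTk()
app.grid_columnconfigure(0, weight=1)

# status selects the icon pair: "info", "success", "error" or "warning"
CTkAlert(parent=app, status="success", title="done", msg="transcription finished")
app.mainloop()
```
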
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | import time
5 |
6 | try:
7 | import pkg_resources
8 | except ImportError:
9 | if sys.platform.startswith("win"):
10 | subprocess.call('python -m pip install setuptools', shell=True)
11 | else:
12 | subprocess.call('python3 -m pip install setuptools', shell=True)
13 | import pkg_resources
14 |
15 | DIR_PATH = os.path.dirname(os.path.realpath(__file__))
16 |
17 | # Checking the required folders
18 | folders = ["assets", "src"]
19 | missing_folder = []
20 | for i in folders:
21 |     if not os.path.exists(os.path.join(DIR_PATH, i)):
22 | missing_folder.append(i)
23 | if missing_folder:
24 |     print("These folder(s) are not available: " + str(missing_folder))
25 |     print("Please download them from the repository.")
26 | sys.exit()
27 | else:
28 | print("All folders available!")
29 |
30 | # Checking required modules
31 | required = {"Pillow==10.1.0", "sounddevice==0.4.6", "scipy==1.11.4", "deep-translator==1.11.4", "mutagen==1.47.0", "pydub==0.25.1", "tkinterdnd2==0.3.0", "numpy==1.26.2",
32 | "SpeechRecognition==3.10.0", "customtkinter==5.2.1", "tqdm==4.66.1", "numba==0.58.1", "tiktoken==0.5.1", "more-itertools==10.1.0", "GPUtil==1.4.0", "PyAudio==0.2.14", "packaging==23.2"}
33 | installed = {pkg.key for pkg in pkg_resources.working_set}
34 | # pkg_resources keys are lowercase distribution names without version pins, so compare by name
35 | missing = {req for req in required if req.split("==")[0].lower() not in installed}
36 | pytorch_win = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
37 | pytorch_linux = "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
38 | pytorch_mac = "torch torchvision torchaudio"
39 |
40 | # Download the modules if not installed
41 | if missing:
42 | try:
43 | print("Installing modules...")
44 |         # install each missing requirement with its pinned version
45 |         for y in sorted(missing):
46 | if sys.platform.startswith("win"):
47 | subprocess.call('python -m pip install ' + y, shell=True)
48 | else:
49 | subprocess.call('python3 -m pip install ' + y, shell=True)
50 |     except Exception:
51 |         print("Unable to download! \nThese are the required ones: " + str(
52 |             required) + "\nUse 'pip install module_name' to install the modules one by one.")
53 | time.sleep(3)
54 | sys.exit()
55 |
56 | try:
57 | print("Installing Pytorch...")
58 | if sys.platform.startswith("win"):
59 | subprocess.call('python -m pip install ' + pytorch_win, shell=True)
60 | elif sys.platform.startswith("linux"):
61 | subprocess.call('python3 -m pip install ' + pytorch_linux, shell=True)
62 | elif sys.platform.startswith("darwin"):
63 | subprocess.call('python3 -m pip install ' + pytorch_mac, shell=True)
64 | except Exception:
65 |     print("Unable to download! \nThese are the required ones: " + str(required) + "\nUse 'pip/pip3 install module_name' to install the modules one by one.")
66 | sys.exit()
67 | else:
68 | print("All required modules installed!")
69 |
70 | # Everything done!
71 | print("Setup Complete!")
72 | time.sleep(5)
73 | sys.exit()
74 |
--------------------------------------------------------------------------------
/src/whisper/triton_ops.py:
--------------------------------------------------------------------------------
1 | from functools import lru_cache
2 |
3 | import numpy as np
4 | import torch
5 |
6 | try:
7 | import triton
8 | import triton.language as tl
9 | except ImportError:
10 | raise RuntimeError("triton import failed; try `pip install --pre triton`")
11 |
12 |
13 | @triton.jit
14 | def dtw_kernel(
15 | cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
16 | ):
17 | offsets = tl.arange(0, BLOCK_SIZE)
18 | mask = offsets < M
19 |
20 | for k in range(1, N + M + 1): # k = i + j
21 | tl.debug_barrier()
22 |
23 | p0 = cost + (k - 1) * cost_stride
24 | p1 = cost + k * cost_stride
25 | p2 = cost + k * cost_stride + 1
26 |
27 | c0 = tl.load(p0 + offsets, mask=mask)
28 | c1 = tl.load(p1 + offsets, mask=mask)
29 | c2 = tl.load(p2 + offsets, mask=mask)
30 |
31 | x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
32 | cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
33 |
34 | cost_ptr = cost + (k + 1) * cost_stride + 1
35 | tl.store(cost_ptr + offsets, cost_row, mask=mask)
36 |
37 | trace_ptr = trace + (k + 1) * trace_stride + 1
38 | tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
39 | tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
40 | tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
41 |
42 |
43 | @lru_cache(maxsize=None)
44 | def median_kernel(filter_width: int):
45 | @triton.jit
46 | def kernel(
47 | y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
48 | ): # x.shape[-1] == filter_width
49 | row_idx = tl.program_id(0)
50 | offsets = tl.arange(0, BLOCK_SIZE)
51 | mask = offsets < y_stride
52 |
53 | x_ptr = x + row_idx * x_stride # noqa: F841
54 | y_ptr = y + row_idx * y_stride
55 |
56 | LOAD_ALL_ROWS_HERE # noqa: F821
57 |
58 | BUBBLESORT_HERE # noqa: F821
59 |
60 | tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
61 |
62 | kernel = triton.JITFunction(kernel.fn)
63 | kernel.src = kernel.src.replace(
64 | " LOAD_ALL_ROWS_HERE",
65 | "\n".join(
66 | [
67 | f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
68 | for i in range(filter_width)
69 | ]
70 | ),
71 | )
72 | kernel.src = kernel.src.replace(
73 | " BUBBLESORT_HERE",
74 | "\n\n".join(
75 | [
76 | "\n\n".join(
77 | [
78 | "\n".join(
79 | [
80 | f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
81 | f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
82 | f" row{j} = smaller",
83 | f" row{j + 1} = larger",
84 | ]
85 | )
86 | for j in range(filter_width - i - 1)
87 | ]
88 | )
89 | for i in range(filter_width // 2 + 1)
90 | ]
91 | ),
92 | )
93 | kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94 |
95 | return kernel
96 |
97 |
98 | def median_filter_cuda(x: torch.Tensor, filter_width: int):
99 | """Apply a median filter of given width along the last dimension of x"""
100 | slices = x.contiguous().unfold(-1, filter_width, 1)
101 | grid = np.prod(slices.shape[:-2])
102 |
103 | kernel = median_kernel(filter_width)
104 | y = torch.empty_like(slices[..., 0])
105 |
106 | BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
107 | kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
108 |
109 | return y
110 |
--------------------------------------------------------------------------------
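
The generated kernel bubble-sorts the `filter_width` unfolded rows and stores the middle one, i.e. a sliding-window median. For sanity-checking the CUDA path, an equivalent CPU reference (for the odd filter widths used here) is a short `unfold` plus `median`; this is an illustrative sketch, not part of the package:

```
import torch

def median_filter_reference(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """Median over a sliding window along the last dimension, mirroring median_filter_cuda."""
    slices = x.contiguous().unfold(-1, filter_width, 1)  # (..., n_windows, filter_width)
    return slices.median(dim=-1).values

x = torch.randn(4, 32)
print(median_filter_reference(x, filter_width=7).shape)  # torch.Size([4, 26])
```
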
/src/whisper/audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | from functools import lru_cache
3 | from subprocess import CalledProcessError, run
4 | from typing import Optional, Union
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .utils import exact_div
11 |
12 | # hard-coded audio hyperparameters
13 | SAMPLE_RATE = 16000
14 | N_FFT = 400
15 | HOP_LENGTH = 160
16 | CHUNK_LENGTH = 30
17 | N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
18 | N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
19 |
20 | N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
21 | FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
22 | TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
23 |
24 |
25 | def load_audio(file: str, sr: int = SAMPLE_RATE):
26 | """
27 | Open an audio file and read as mono waveform, resampling as necessary
28 |
29 | Parameters
30 | ----------
31 | file: str
32 | The audio file to open
33 |
34 | sr: int
35 | The sample rate to resample the audio if necessary
36 |
37 | Returns
38 | -------
39 | A NumPy array containing the audio waveform, in float32 dtype.
40 | """
41 |
42 | # This launches a subprocess to decode audio while down-mixing
43 | # and resampling as necessary. Requires the ffmpeg CLI in PATH.
44 | # fmt: off
45 | cmd = [
46 | "ffmpeg",
47 | "-nostdin",
48 | "-threads", "0",
49 | "-i", file,
50 | "-f", "s16le",
51 | "-ac", "1",
52 | "-acodec", "pcm_s16le",
53 | "-ar", str(sr),
54 | "-"
55 | ]
56 | # fmt: on
57 | try:
58 | out = run(cmd, capture_output=True, check=True).stdout
59 | except CalledProcessError as e:
60 | raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
61 |
62 | return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
63 |
64 |
65 | def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
66 | """
67 |     Pad or trim the audio array to `length` (N_SAMPLES by default), as expected by the encoder.
68 | """
69 | if torch.is_tensor(array):
70 | if array.shape[axis] > length:
71 | array = array.index_select(
72 | dim=axis, index=torch.arange(length, device=array.device)
73 | )
74 |
75 | if array.shape[axis] < length:
76 | pad_widths = [(0, 0)] * array.ndim
77 | pad_widths[axis] = (0, length - array.shape[axis])
78 | array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
79 | else:
80 | if array.shape[axis] > length:
81 | array = array.take(indices=range(length), axis=axis)
82 |
83 | if array.shape[axis] < length:
84 | pad_widths = [(0, 0)] * array.ndim
85 | pad_widths[axis] = (0, length - array.shape[axis])
86 | array = np.pad(array, pad_widths)
87 |
88 | return array
89 |
90 |
91 | @lru_cache(maxsize=None)
92 | def mel_filters(device, n_mels: int) -> torch.Tensor:
93 | """
94 | load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
95 | Allows decoupling librosa dependency; saved using:
96 |
97 | np.savez_compressed(
98 | "mel_filters.npz",
99 | mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
100 | mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
101 | )
102 | """
103 | assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
104 |
105 | filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
106 | with np.load(filters_path, allow_pickle=False) as f:
107 | return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
108 |
109 |
110 | def log_mel_spectrogram(
111 | audio: Union[str, np.ndarray, torch.Tensor],
112 | n_mels: int = 80,
113 | padding: int = 0,
114 | device: Optional[Union[str, torch.device]] = None,
115 | ):
116 | """
117 |     Compute the log-Mel spectrogram of the given audio.
118 |
119 | Parameters
120 | ----------
121 | audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
122 |         The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
123 |
124 | n_mels: int
125 |         The number of Mel-frequency filters, only 80 and 128 are supported
126 |
127 | padding: int
128 | Number of zero samples to pad to the right
129 |
130 | device: Optional[Union[str, torch.device]]
131 | If given, the audio tensor is moved to this device before STFT
132 |
133 | Returns
134 | -------
135 |     torch.Tensor, shape = (n_mels, n_frames)
136 | A Tensor that contains the Mel spectrogram
137 | """
138 | if not torch.is_tensor(audio):
139 | if isinstance(audio, str):
140 | audio = load_audio(audio)
141 | audio = torch.from_numpy(audio)
142 |
143 | if device is not None:
144 | audio = audio.to(device)
145 | if padding > 0:
146 | audio = F.pad(audio, (0, padding))
147 | window = torch.hann_window(N_FFT).to(audio.device)
148 | stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
149 | magnitudes = stft[..., :-1].abs() ** 2
150 |
151 | filters = mel_filters(audio.device, n_mels)
152 | mel_spec = filters @ magnitudes
153 |
154 | log_spec = torch.clamp(mel_spec, min=1e-10).log10()
155 | log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
156 | log_spec = (log_spec + 4.0) / 4.0
157 | return log_spec
158 |
--------------------------------------------------------------------------------
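
Together these helpers form the model's preprocessing pipeline; a minimal sketch (requires the ffmpeg CLI on PATH; the file name is a placeholder):

```
from src.whisper.audio import N_FRAMES, load_audio, log_mel_spectrogram, pad_or_trim

audio = load_audio("sample.wav")    # float32 mono waveform at 16 kHz
audio = pad_or_trim(audio)          # pad/trim to exactly 30 s (480000 samples)
mel = log_mel_spectrogram(audio)    # clamped, scaled log-Mel features
assert mel.shape == (80, N_FRAMES)  # (n_mels, 3000)
```
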
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import webbrowser
3 |
4 | import customtkinter as ctk
5 | from PIL import Image
6 |
7 | from src.logic.settings import SettingsHandler
8 | from src.ui.add_subtitles import AddSubtitlesUI
9 | from src.ui.live_transcribe import LiveTranscribeUI
10 | from src.ui.settings import SettingsUI
11 | from src.ui.style import FONTS
12 | from src.ui.transcribe import TranscribeUI
13 | from src.ui.translate import TranslateUI
14 |
15 | current_path = os.path.dirname(os.path.realpath(__file__))
16 | icon_path = f"{current_path}{os.path.sep}assets{os.path.sep}icons{os.path.sep}"
17 |
18 | icons = {
19 | "logo": ctk.CTkImage(dark_image=Image.open(f"{icon_path}logo.png"),
20 | light_image=Image.open(f"{icon_path}logo.png"), size=(50, 50)),
21 | "close": ctk.CTkImage(dark_image=Image.open(f"{icon_path}close_dark.png"),
22 | light_image=Image.open(f"{icon_path}close_light.png"), size=(30, 30)),
23 | "audio_file": ctk.CTkImage(dark_image=Image.open(f"{icon_path}audio_file_dark.png"),
24 | light_image=Image.open(f"{icon_path}audio_file_light.png"),
25 | size=(30, 30)),
26 | "translation": ctk.CTkImage(dark_image=Image.open(f"{icon_path}translation_dark.png"),
27 | light_image=Image.open(f"{icon_path}translation_light.png"),
28 | size=(30, 30)),
29 | "microphone": ctk.CTkImage(dark_image=Image.open(f"{icon_path}microphone_dark.png"),
30 | light_image=Image.open(f"{icon_path}microphone_light.png"),
31 | size=(30, 30)),
32 | "subtitles": ctk.CTkImage(dark_image=Image.open(f"{icon_path}subtitles_dark.png"),
33 | light_image=Image.open(f"{icon_path}subtitles_light.png"),
34 | size=(30, 30)),
35 | "paypal": ctk.CTkImage(dark_image=Image.open(f"{icon_path}paypal_dark.png"),
36 | light_image=Image.open(f"{icon_path}paypal_light.png"), size=(30, 30)),
37 | "settings": ctk.CTkImage(dark_image=Image.open(f"{icon_path}settings_dark.png"),
38 | light_image=Image.open(f"{icon_path}settings_light.png"),
39 | size=(30, 30)),
40 | "help": ctk.CTkImage(dark_image=Image.open(f"{icon_path}help_dark.png"),
41 | light_image=Image.open(f"{icon_path}help_light.png"), size=(30, 30)),
42 | "github": ctk.CTkImage(dark_image=Image.open(f"{icon_path}github_dark.png"),
43 | light_image=Image.open(f"{icon_path}github_light.png"), size=(30, 30))
44 | }
45 |
46 | logo = f"{icon_path}logo.ico"
47 |
48 | btn = {
49 | "width": 280,
50 | "height": 116,
51 | "text_color": ("#FFFFFF", "#DFE1E5"),
52 | "compound": "left",
53 | "font": ("Inter", 16)
54 | }
55 |
56 | secondary_btn = {
57 | "width": 280,
58 | "height": 100,
59 | "fg_color": ("#221D21", "#2B2D30"),
60 | "hover": False,
61 | "border_width": 0,
62 | "text_color": ("#FFFFFF", "#DFE1E5"),
63 | "compound": "left",
64 | "font": ("Inter", 16)
65 | }
66 |
67 | link_btn = {
68 | "width": 140,
69 | "height": 80,
70 | "fg_color": "transparent",
71 | "hover_color": ("#D3D5DB", "#2B2D30"),
72 | "border_color": ("#D3D5DB", "#2B2D30"),
73 | "border_width": 2,
74 | "text_color": ("#000000", "#DFE1E5"),
75 | "compound": "left",
76 | "font": ("Inter", 16)
77 | }
78 |
79 |
80 | def help_link():
81 | webbrowser.open("https://github.com/rudymohammadbali/Whisper-Transcriber/discussions/categories/q-a")
82 |
83 |
84 | def github_link():
85 | webbrowser.open("https://github.com/rudymohammadbali")
86 |
87 |
88 | def paypal_link():
89 | webbrowser.open("https://www.paypal.com/paypalme/iamironman0")
90 |
91 |
92 | class Testing(ctk.CTk):
93 | def __init__(self):
94 | super().__init__()
95 | self.geometry("620x720")
96 | self.resizable(False, False)
97 | self.iconbitmap(logo)
98 | self.title("Whisper Transcriber")
99 |
100 | settings_handler = SettingsHandler()
101 | settings = settings_handler.load_settings()
102 | theme = settings.get("theme")
103 | color_theme = settings.get("color_theme")
104 |
105 | ctk.set_appearance_mode(theme)
106 | ctk.set_default_color_theme(color_theme)
107 |
108 | self.main_ui()
109 |
110 | def main_ui(self):
111 | title = ctk.CTkLabel(self, text="Welcome to Whisper Transcriber", text_color=("#000000", "#DFE1E5"),
112 | font=FONTS["title_bold"], image=icons["logo"], compound="top")
113 | title.grid(row=0, column=0, padx=20, pady=20, sticky="nsew", columnspan=2)
114 |
115 | label = ctk.CTkLabel(self, text="Select a Service", text_color=("#000000", "#DFE1E5"), font=FONTS["subtitle"])
116 | label.grid(row=1, column=0, padx=20, pady=20, sticky="w")
117 |
118 | btn_1 = ctk.CTkButton(self, text="Transcribe Audio", **btn, image=icons["audio_file"],
119 | command=lambda: TranscribeUI(parent=self))
120 | btn_1.grid(row=2, column=0, padx=(20, 10), pady=10, sticky="nsew")
121 | btn_2 = ctk.CTkButton(self, text="Translate Audio", **btn, image=icons["translation"],
122 | command=lambda: TranslateUI(parent=self))
123 | btn_2.grid(row=2, column=1, padx=(10, 20), pady=10, sticky="nsew")
124 |
125 | btn_3 = ctk.CTkButton(self, text="Live Transcriber", **btn, image=icons["microphone"],
126 | command=lambda: LiveTranscribeUI(parent=self))
127 | btn_3.grid(row=3, column=0, padx=(20, 10), pady=10, sticky="nsew")
128 | btn_4 = ctk.CTkButton(self, text="Add Subtitle", **btn, image=icons["subtitles"],
129 | command=lambda: AddSubtitlesUI(parent=self))
130 | btn_4.grid(row=3, column=1, padx=(10, 20), pady=10, sticky="nsew")
131 |
132 | btn_5 = ctk.CTkButton(self, text="Settings", font=("Inter", 16), text_color=("#FFFFFF", "#DFE1E5"), width=280,
133 | height=100, image=icons["settings"], command=lambda: SettingsUI(parent=self))
134 | btn_5.grid(row=4, column=0, padx=(20, 10), pady=20, sticky="nsew")
135 | btn_6 = ctk.CTkButton(self, text="Support on PayPal", **secondary_btn, image=icons["paypal"],
136 | command=paypal_link)
137 | btn_6.grid(row=4, column=1, padx=(10, 20), pady=20, sticky="nsew")
138 |
139 | btn_7 = ctk.CTkButton(self, text="GitHub", **link_btn, image=icons["github"], command=github_link)
140 | btn_7.grid(row=5, column=0, padx=20, pady=20, sticky="nsew")
141 | btn_8 = ctk.CTkButton(self, text="Help", **link_btn, image=icons["help"], command=help_link)
142 | btn_8.grid(row=5, column=1, padx=20, pady=20, sticky="nsew")
143 |
144 |
145 | if __name__ == "__main__":
146 | app = Testing()
147 | app.mainloop()
--------------------------------------------------------------------------------
/src/whisper/__init__.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import io
3 | import os
4 | import urllib
5 | import warnings
6 | from typing import List, Optional, Union
7 |
8 | import torch
9 | from tqdm import tqdm
10 |
11 | from .audio import load_audio, log_mel_spectrogram, pad_or_trim
12 | from .decoding import DecodingOptions, DecodingResult, decode, detect_language
13 | from .model import ModelDimensions, Whisper
14 | from .transcribe import transcribe
15 | from .version import __version__
16 |
17 | _MODELS = {
18 | "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
19 | "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
20 | "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
21 | "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
22 | "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
23 | "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
24 | "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
25 | "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
26 | "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
27 | "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
28 | "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
29 | "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
30 | }
31 |
32 | # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
33 | # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
34 | _ALIGNMENT_HEADS = {
35 | "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
36 | "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
37 | "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
38 | "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
40 | "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P%R7%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
43 | "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
45 | "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
46 | "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
47 | }
48 |
49 |
50 | def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
51 | os.makedirs(root, exist_ok=True)
52 |
53 | expected_sha256 = url.split("/")[-2]
54 | download_target = os.path.join(root, os.path.basename(url))
55 |
56 | if os.path.exists(download_target) and not os.path.isfile(download_target):
57 | raise RuntimeError(f"{download_target} exists and is not a regular file")
58 |
59 | if os.path.isfile(download_target):
60 | with open(download_target, "rb") as f:
61 | model_bytes = f.read()
62 | if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
63 | return model_bytes if in_memory else download_target
64 | else:
65 | warnings.warn(
66 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
67 | )
68 |
69 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
70 | with tqdm(
71 | total=int(source.info().get("Content-Length")),
72 | ncols=80,
73 | unit="iB",
74 | unit_scale=True,
75 | unit_divisor=1024,
76 | ) as loop:
77 | while True:
78 | buffer = source.read(8192)
79 | if not buffer:
80 | break
81 |
82 | output.write(buffer)
83 | loop.update(len(buffer))
84 |
85 | model_bytes = open(download_target, "rb").read()
86 | if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
87 | raise RuntimeError(
88 | "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
89 | )
90 |
91 | return model_bytes if in_memory else download_target
92 |
93 |
94 | def available_models() -> List[str]:
95 | """Returns the names of available models"""
96 | return list(_MODELS.keys())
97 |
98 |
99 | def load_model(
100 | name: str,
101 | device: Optional[Union[str, torch.device]] = None,
102 | download_root: str = None,
103 | in_memory: bool = False,
104 | ) -> Whisper:
105 | """
106 | Load a Whisper ASR model
107 |
108 | Parameters
109 | ----------
110 | name : str
111 | one of the official model names listed by `whisper.available_models()`, or
112 | path to a model checkpoint containing the model dimensions and the model state_dict.
113 | device : Union[str, torch.device]
114 | the PyTorch device to put the model into
115 | download_root: str
116 | path to download the model files; by default, it uses "~/.cache/whisper"
117 | in_memory: bool
118 | whether to preload the model weights into host memory
119 |
120 | Returns
121 | -------
122 | model : Whisper
123 | The Whisper ASR model instance
124 | """
125 |
126 | if device is None:
127 | device = "cuda" if torch.cuda.is_available() else "cpu"
128 | if download_root is None:
129 | default = os.path.join(os.path.expanduser("~"), ".cache")
130 | download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
131 |
132 | if name in _MODELS:
133 | checkpoint_file = _download(_MODELS[name], download_root, in_memory)
134 | alignment_heads = _ALIGNMENT_HEADS[name]
135 | elif os.path.isfile(name):
136 | checkpoint_file = open(name, "rb").read() if in_memory else name
137 | alignment_heads = None
138 | else:
139 | raise RuntimeError(
140 | f"Model {name} not found; available models = {available_models()}"
141 | )
142 |
143 | with (
144 | io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
145 | ) as fp:
146 | checkpoint = torch.load(fp, map_location=device)
147 | del checkpoint_file
148 |
149 | dims = ModelDimensions(**checkpoint["dims"])
150 | model = Whisper(dims)
151 | model.load_state_dict(checkpoint["model_state_dict"])
152 |
153 | if alignment_heads is not None:
154 | model.set_alignment_heads(alignment_heads)
155 |
156 | return model.to(device)
157 |
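158 | # Minimal usage sketch; the first call downloads the checkpoint (cached under
159 | # ~/.cache/whisper afterwards), so it needs a network connection once:
160 | if __name__ == "__main__":
161 | print(available_models())
162 | model = load_model("tiny")  # smallest official checkpoint (~75 MB)
163 | print(model.device, model.is_multilingual)
164 |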
--------------------------------------------------------------------------------
/src/ui/ctk_tooltip.py:
--------------------------------------------------------------------------------
1 | """
2 | CTkToolTip Widget
3 | version: 0.8
4 | """
5 |
6 | import sys
7 | import time
8 | from tkinter import Toplevel, Frame
9 |
10 | import customtkinter
11 |
12 |
13 | class CTkToolTip(Toplevel):
14 | """
15 | Creates a ToolTip (pop-up) widget for customtkinter.
16 | """
17 |
18 | def __init__(
19 | self,
20 | widget: any = None,
21 | message: str = None,
22 | delay: float = 0.2,
23 | follow: bool = True,
24 | x_offset: int = +20,
25 | y_offset: int = +10,
26 | bg_color: str = None,
27 | corner_radius: int = 10,
28 | border_width: int = 0,
29 | border_color: str = None,
30 | alpha: float = 0.95,
31 | padding: tuple = (10, 2),
32 | **message_kwargs):
33 |
34 | super().__init__()
35 |
36 | self.widget = widget
37 |
38 | self.withdraw()
39 |
40 | # Disable ToolTip's title bar
41 | self.overrideredirect(True)
42 |
43 | if sys.platform.startswith("win"):
44 | self.transparent_color = self.widget._apply_appearance_mode(
45 | customtkinter.ThemeManager.theme["CTkToplevel"]["fg_color"])
46 | self.attributes("-transparentcolor", self.transparent_color)
47 | self.transient()
48 | elif sys.platform.startswith("darwin"):
49 | self.transparent_color = 'systemTransparent'
50 | self.attributes("-transparent", True)
51 | self.transient(self.master)
52 | else:
53 | self.transparent_color = '#000001'
54 | corner_radius = 0
55 | self.transient()
56 |
57 | self.resizable(width=True, height=True)
58 |
59 | # Make the background transparent
60 | self.config(background=self.transparent_color)
61 |
62 | # StringVar instance for msg string
63 | self.messageVar = customtkinter.StringVar()
64 | self.message = message
65 | self.messageVar.set(self.message)
66 |
67 | self.delay = delay
68 | self.follow = follow
69 | self.x_offset = x_offset
70 | self.y_offset = y_offset
71 | self.corner_radius = corner_radius
72 | self.alpha = alpha
73 | self.border_width = border_width
74 | self.padding = padding
75 | self.bg_color = customtkinter.ThemeManager.theme["CTkFrame"]["fg_color"] if bg_color is None else bg_color
76 | self.border_color = border_color
77 | self.disable = False
78 |
79 | # visibility status of the ToolTip inside|outside|visible
80 | self.status = "outside"
81 | self.last_moved = 0
82 | self.attributes('-alpha', self.alpha)
83 |
84 | if sys.platform.startswith("win"):
85 | if self.widget._apply_appearance_mode(self.bg_color) == self.transparent_color:
86 | self.transparent_color = "#000001"
87 | self.config(background=self.transparent_color)
88 | self.attributes("-transparentcolor", self.transparent_color)
89 |
90 | # Add the message widget inside the tooltip
91 | self.transparent_frame = Frame(self, bg=self.transparent_color)
92 | self.transparent_frame.pack(padx=0, pady=0, fill="both", expand=True)
93 |
94 | self.frame = customtkinter.CTkFrame(self.transparent_frame, bg_color=self.transparent_color,
95 | corner_radius=self.corner_radius,
96 | border_width=self.border_width, fg_color=self.bg_color,
97 | border_color=self.border_color)
98 | self.frame.pack(padx=0, pady=0, fill="both", expand=True)
99 |
100 | self.message_label = customtkinter.CTkLabel(self.frame, textvariable=self.messageVar, **message_kwargs)
101 | self.message_label.pack(fill="both", padx=self.padding[0] + self.border_width,
102 | pady=self.padding[1] + self.border_width, expand=True)
103 |
104 | if self.widget.winfo_name() != "tk":
105 | if self.frame.cget("fg_color") == self.widget.cget("bg_color"):
106 | if not bg_color:
107 | self._top_fg_color = self.frame._apply_appearance_mode(
108 | customtkinter.ThemeManager.theme["CTkFrame"]["top_fg_color"])
109 | if self._top_fg_color != self.transparent_color:
110 | self.frame.configure(fg_color=self._top_fg_color)
111 |
112 | # Add bindings to the widget without overriding the existing ones
113 | self.widget.bind("<Enter>", self.on_enter, add="+")
114 | self.widget.bind("<Leave>", self.on_leave, add="+")
115 | self.widget.bind("<Motion>", self.on_enter, add="+")
116 | self.widget.bind("<B1-Motion>", self.on_enter, add="+")
117 | self.widget.bind("<Destroy>", lambda _: self.hide(), add="+")
118 |
119 | def show(self) -> None:
120 | """
121 | Enable the widget.
122 | """
123 | self.disable = False
124 |
125 | def on_enter(self, event) -> None:
126 | """
127 | Processes motion within the widget including entering and moving.
128 | """
129 |
130 | if self.disable:
131 | return
132 | self.last_moved = time.time()
133 |
134 | # Set the status as inside for the very first time
135 | if self.status == "outside":
136 | self.status = "inside"
137 |
138 | # If the follow flag is not set, motion within the widget will make the ToolTip disappear
139 | if not self.follow:
140 | self.status = "inside"
141 | self.withdraw()
142 |
143 | # Calculate available space on the right side of the widget relative to the screen
144 | root_width = self.winfo_screenwidth()
145 | widget_x = event.x_root
146 | space_on_right = root_width - widget_x
147 |
148 | # Calculate the width of the tooltip's text based on the length of the message string
149 | text_width = self.message_label.winfo_reqwidth()
150 |
151 | # Calculate the offset based on available space and text width to avoid going off-screen on the right side
152 | offset_x = self.x_offset
153 | if space_on_right < text_width + 20: # Adjust the threshold as needed
154 | offset_x = -text_width - 20 # Negative offset when space is limited on the right side
155 |
156 | # Offsets the ToolTip using the coordinates of an event as an origin
157 | self.geometry(f"+{event.x_root + offset_x}+{event.y_root + self.y_offset}")
158 |
159 | # Time is in integer: milliseconds
160 | self.after(int(self.delay * 1000), self._show)
161 |
162 | def on_leave(self, event=None) -> None:
163 | """
164 | Hides the ToolTip temporarily.
165 | """
166 |
167 | if self.disable: return
168 | self.status = "outside"
169 | self.withdraw()
170 |
171 | def _show(self) -> None:
172 | """
173 | Displays the ToolTip.
174 | """
175 |
176 | if not self.widget.winfo_exists():
177 | self.hide()
178 | self.destroy()
179 |
180 | if self.status == "inside" and time.time() - self.last_moved >= self.delay:
181 | self.status = "visible"
182 | self.deiconify()
183 |
184 | def hide(self) -> None:
185 | """
186 | Disable the widget from appearing.
187 | """
188 | if not self.winfo_exists():
189 | return
190 | self.withdraw()
191 | self.disable = True
192 |
193 | def is_disabled(self) -> bool:
194 | """
195 | Return the window state
196 | """
197 | return self.disable
198 |
199 | def get(self) -> str:
200 | """
201 | Returns the text on the tooltip.
202 | """
203 | return self.messageVar.get()
204 |
205 | def configure(self, message: str = None, delay: float = None, bg_color: str = None, **kwargs):
206 | """
207 | Set new message or configure the label parameters.
208 | """
209 | if delay: self.delay = delay
210 | if bg_color: self.frame.configure(fg_color=bg_color)
211 |
212 | self.messageVar.set(message)
213 | self.message_label.configure(**kwargs)
214 |
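215 | # Minimal usage sketch; the window and button below are illustrative, not part
216 | # of the widget itself:
217 | if __name__ == "__main__":
218 | app = customtkinter.CTk()
219 | button = customtkinter.CTkButton(app, text="Hover me")
220 | button.pack(padx=40, pady=40)
221 | CTkToolTip(widget=button, message="Tooltips follow the pointer by default")
222 | app.mainloop()
223 |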
--------------------------------------------------------------------------------
/src/whisper/model.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import gzip
3 | from dataclasses import dataclass
4 | from typing import Dict, Iterable, Optional
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import Tensor, nn
10 |
11 | from .decoding import decode as decode_function
12 | from .decoding import detect_language as detect_language_function
13 | from .transcribe import transcribe as transcribe_function
14 |
15 |
16 | @dataclass
17 | class ModelDimensions:
18 | n_mels: int
19 | n_audio_ctx: int
20 | n_audio_state: int
21 | n_audio_head: int
22 | n_audio_layer: int
23 | n_vocab: int
24 | n_text_ctx: int
25 | n_text_state: int
26 | n_text_head: int
27 | n_text_layer: int
28 |
29 |
30 | class LayerNorm(nn.LayerNorm):
31 | def forward(self, x: Tensor) -> Tensor:
32 | return super().forward(x.float()).type(x.dtype)
33 |
34 |
35 | class Linear(nn.Linear):
36 | def forward(self, x: Tensor) -> Tensor:
37 | return F.linear(
38 | x,
39 | self.weight.to(x.dtype),
40 | None if self.bias is None else self.bias.to(x.dtype),
41 | )
42 |
43 |
44 | class Conv1d(nn.Conv1d):
45 | def _conv_forward(
46 | self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
47 | ) -> Tensor:
48 | return super()._conv_forward(
49 | x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
50 | )
51 |
52 |
53 | def sinusoids(length, channels, max_timescale=10000):
54 | """Returns sinusoids for positional embedding"""
55 | assert channels % 2 == 0
56 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
57 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
58 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
59 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
60 |
61 |
62 | class MultiHeadAttention(nn.Module):
63 | def __init__(self, n_state: int, n_head: int):
64 | super().__init__()
65 | self.n_head = n_head
66 | self.query = Linear(n_state, n_state)
67 | self.key = Linear(n_state, n_state, bias=False)
68 | self.value = Linear(n_state, n_state)
69 | self.out = Linear(n_state, n_state)
70 |
71 | def forward(
72 | self,
73 | x: Tensor,
74 | xa: Optional[Tensor] = None,
75 | mask: Optional[Tensor] = None,
76 | kv_cache: Optional[dict] = None,
77 | ):
78 | q = self.query(x)
79 |
80 | if kv_cache is None or xa is None or self.key not in kv_cache:
81 | # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
82 | # otherwise, perform key/value projections for self- or cross-attention as usual.
83 | k = self.key(x if xa is None else xa)
84 | v = self.value(x if xa is None else xa)
85 | else:
86 | # for cross-attention, calculate keys and values once and reuse in subsequent calls.
87 | k = kv_cache[self.key]
88 | v = kv_cache[self.value]
89 |
90 | wv, qk = self.qkv_attention(q, k, v, mask)
91 | return self.out(wv), qk
92 |
93 | def qkv_attention(
94 | self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
95 | ):
96 | n_batch, n_ctx, n_state = q.shape
97 | scale = (n_state // self.n_head) ** -0.25
98 | q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
99 | k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
100 | v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
101 |
102 | qk = q @ k
103 | if mask is not None:
104 | qk = qk + mask[:n_ctx, :n_ctx]
105 | qk = qk.float()
106 |
107 | w = F.softmax(qk, dim=-1).to(q.dtype)
108 | return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
109 |
110 |
111 | class ResidualAttentionBlock(nn.Module):
112 | def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
113 | super().__init__()
114 |
115 | self.attn = MultiHeadAttention(n_state, n_head)
116 | self.attn_ln = LayerNorm(n_state)
117 |
118 | self.cross_attn = (
119 | MultiHeadAttention(n_state, n_head) if cross_attention else None
120 | )
121 | self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
122 |
123 | n_mlp = n_state * 4
124 | self.mlp = nn.Sequential(
125 | Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
126 | )
127 | self.mlp_ln = LayerNorm(n_state)
128 |
129 | def forward(
130 | self,
131 | x: Tensor,
132 | xa: Optional[Tensor] = None,
133 | mask: Optional[Tensor] = None,
134 | kv_cache: Optional[dict] = None,
135 | ):
136 | x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
137 | if self.cross_attn:
138 | x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
139 | x = x + self.mlp(self.mlp_ln(x))
140 | return x
141 |
142 |
143 | class AudioEncoder(nn.Module):
144 | def __init__(
145 | self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
146 | ):
147 | super().__init__()
148 | self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
149 | self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
150 | self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
151 |
152 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
153 | [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
154 | )
155 | self.ln_post = LayerNorm(n_state)
156 |
157 | def forward(self, x: Tensor):
158 | """
159 | x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
160 | the mel spectrogram of the audio
161 | """
162 | x = F.gelu(self.conv1(x))
163 | x = F.gelu(self.conv2(x))
164 | x = x.permute(0, 2, 1)
165 |
166 | assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
167 | x = (x + self.positional_embedding).to(x.dtype)
168 |
169 | for block in self.blocks:
170 | x = block(x)
171 |
172 | x = self.ln_post(x)
173 | return x
174 |
175 |
176 | class TextDecoder(nn.Module):
177 | def __init__(
178 | self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
179 | ):
180 | super().__init__()
181 |
182 | self.token_embedding = nn.Embedding(n_vocab, n_state)
183 | self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
184 |
185 | self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
186 | [
187 | ResidualAttentionBlock(n_state, n_head, cross_attention=True)
188 | for _ in range(n_layer)
189 | ]
190 | )
191 | self.ln = LayerNorm(n_state)
192 |
193 | mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
194 | self.register_buffer("mask", mask, persistent=False)
195 |
196 | def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
197 | """
198 | x : torch.LongTensor, shape = (batch_size, <= n_ctx)
199 | the text tokens
200 | xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
201 | the encoded audio features to be attended on
202 | """
203 | offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
204 | x = (
205 | self.token_embedding(x)
206 | + self.positional_embedding[offset : offset + x.shape[-1]]
207 | )
208 | x = x.to(xa.dtype)
209 |
210 | for block in self.blocks:
211 | x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
212 |
213 | x = self.ln(x)
214 | logits = (
215 | x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
216 | ).float()
217 |
218 | return logits
219 |
220 |
221 | class Whisper(nn.Module):
222 | def __init__(self, dims: ModelDimensions):
223 | super().__init__()
224 | self.dims = dims
225 | self.encoder = AudioEncoder(
226 | self.dims.n_mels,
227 | self.dims.n_audio_ctx,
228 | self.dims.n_audio_state,
229 | self.dims.n_audio_head,
230 | self.dims.n_audio_layer,
231 | )
232 | self.decoder = TextDecoder(
233 | self.dims.n_vocab,
234 | self.dims.n_text_ctx,
235 | self.dims.n_text_state,
236 | self.dims.n_text_head,
237 | self.dims.n_text_layer,
238 | )
239 | # use the last half among the decoder layers for time alignment by default;
240 | # to use a specific set of heads, see `set_alignment_heads()` below.
241 | all_heads = torch.zeros(
242 | self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
243 | )
244 | all_heads[self.dims.n_text_layer // 2 :] = True
245 | self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
246 |
247 | def set_alignment_heads(self, dump: bytes):
248 | array = np.frombuffer(
249 | gzip.decompress(base64.b85decode(dump)), dtype=bool
250 | ).copy()
251 | mask = torch.from_numpy(array).reshape(
252 | self.dims.n_text_layer, self.dims.n_text_head
253 | )
254 | self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
255 |
256 | def embed_audio(self, mel: torch.Tensor):
257 | return self.encoder(mel)
258 |
259 | def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
260 | return self.decoder(tokens, audio_features)
261 |
262 | def forward(
263 | self, mel: torch.Tensor, tokens: torch.Tensor
264 | ) -> Dict[str, torch.Tensor]:
265 | return self.decoder(tokens, self.encoder(mel))
266 |
267 | @property
268 | def device(self):
269 | return next(self.parameters()).device
270 |
271 | @property
272 | def is_multilingual(self):
273 | return self.dims.n_vocab >= 51865
274 |
275 | @property
276 | def num_languages(self):
277 | return self.dims.n_vocab - 51765 - int(self.is_multilingual)
278 |
279 | def install_kv_cache_hooks(self, cache: Optional[dict] = None):
280 | """
281 | The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
282 | tensors calculated for the previous positions. This method returns a dictionary that stores
283 | all caches, and the necessary hooks for the key and value projection modules that save the
284 | intermediate tensors to be reused during later calculations.
285 |
286 | Returns
287 | -------
288 | cache : Dict[nn.Module, torch.Tensor]
288 | A dictionary object mapping the key/value projection modules to their cache
289 | hooks : List[RemovableHandle]
290 | List of PyTorch RemovableHandle objects for removing the hooks when they are no longer needed
292 | """
293 | cache = {**cache} if cache is not None else {}
294 | hooks = []
295 |
296 | def save_to_cache(module, _, output):
297 | if module not in cache or output.shape[1] > self.dims.n_text_ctx:
298 | # save as-is, for the first token or cross attention
299 | cache[module] = output
300 | else:
301 | cache[module] = torch.cat([cache[module], output], dim=1).detach()
302 | return cache[module]
303 |
304 | def install_hooks(layer: nn.Module):
305 | if isinstance(layer, MultiHeadAttention):
306 | hooks.append(layer.key.register_forward_hook(save_to_cache))
307 | hooks.append(layer.value.register_forward_hook(save_to_cache))
308 |
309 | self.decoder.apply(install_hooks)
310 | return cache, hooks
311 |
312 | detect_language = detect_language_function
313 | transcribe = transcribe_function
314 | decode = decode_function
315 |
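316 | # Minimal shape-check sketch using "tiny"-sized dimensions; the weights are
317 | # untrained here, so only the tensor shapes are meaningful:
318 | if __name__ == "__main__":
319 | dims = ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=384,
320 | n_audio_head=6, n_audio_layer=4, n_vocab=51865,
321 | n_text_ctx=448, n_text_state=384, n_text_head=6,
322 | n_text_layer=4)
323 | model = Whisper(dims)
324 | mel = torch.randn(1, 80, 3000)  # the stride-2 conv halves 3000 frames to n_audio_ctx
325 | tokens = torch.randint(0, dims.n_vocab, (1, 10))
326 | print(model(mel, tokens).shape)  # torch.Size([1, 10, 51865])
327 |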
--------------------------------------------------------------------------------
/src/whisper/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import sys
5 | import zlib
6 | from typing import Callable, Optional, TextIO
7 |
8 | system_encoding = sys.getdefaultencoding()
9 |
10 | if system_encoding != "utf-8":
11 |
12 | def make_safe(string):
13 | # replaces any character not representable using the system default encoding with an '?',
14 | # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
15 | return string.encode(system_encoding, errors="replace").decode(system_encoding)
16 |
17 | else:
18 |
19 | def make_safe(string):
20 | # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
21 | return string
22 |
23 |
24 | def exact_div(x, y):
25 | assert x % y == 0
26 | return x // y
27 |
28 |
29 | def str2bool(string):
30 | str2val = {"True": True, "False": False}
31 | if string in str2val:
32 | return str2val[string]
33 | else:
34 | raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
35 |
36 |
37 | def optional_int(string):
38 | return None if string == "None" else int(string)
39 |
40 |
41 | def optional_float(string):
42 | return None if string == "None" else float(string)
43 |
44 |
45 | def compression_ratio(text) -> float:
46 | text_bytes = text.encode("utf-8")
47 | return len(text_bytes) / len(zlib.compress(text_bytes))
48 |
49 |
50 | def format_timestamp(
51 | seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
52 | ):
53 | assert seconds >= 0, "non-negative timestamp expected"
54 | milliseconds = round(seconds * 1000.0)
55 |
56 | hours = milliseconds // 3_600_000
57 | milliseconds -= hours * 3_600_000
58 |
59 | minutes = milliseconds // 60_000
60 | milliseconds -= minutes * 60_000
61 |
62 | seconds = milliseconds // 1_000
63 | milliseconds -= seconds * 1_000
64 |
65 | hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
66 | return (
67 | f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
68 | )
69 |
70 |
71 | class ResultWriter:
72 | extension: str
73 |
74 | def __init__(self, output_dir: str):
75 | self.output_dir = output_dir
76 |
77 | def __call__(
78 | self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
79 | ):
80 | audio_basename = os.path.basename(audio_path)
81 | audio_basename = os.path.splitext(audio_basename)[0]
82 | output_path = os.path.join(
83 | self.output_dir, audio_basename + "." + self.extension
84 | )
85 |
86 | with open(output_path, "w", encoding="utf-8") as f:
87 | self.write_result(result, file=f, options=options, **kwargs)
88 |
89 | def write_result(
90 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
91 | ):
92 | raise NotImplementedError
93 |
94 |
95 | class WriteTXT(ResultWriter):
96 | extension: str = "txt"
97 |
98 | def write_result(
99 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
100 | ):
101 | for segment in result["segments"]:
102 | print(segment["text"].strip(), file=file, flush=True)
103 |
104 |
105 | class SubtitlesWriter(ResultWriter):
106 | always_include_hours: bool
107 | decimal_marker: str
108 |
109 | def iterate_result(
110 | self,
111 | result: dict,
112 | options: Optional[dict] = None,
113 | *,
114 | max_line_width: Optional[int] = None,
115 | max_line_count: Optional[int] = None,
116 | highlight_words: bool = False,
117 | max_words_per_line: Optional[int] = None,
118 | ):
119 | options = options or {}
120 | max_line_width = max_line_width or options.get("max_line_width")
121 | max_line_count = max_line_count or options.get("max_line_count")
122 | highlight_words = highlight_words or options.get("highlight_words", False)
123 | max_words_per_line = max_words_per_line or options.get("max_words_per_line")
124 | preserve_segments = max_line_count is None or max_line_width is None
125 | max_line_width = max_line_width or 1000
126 | max_words_per_line = max_words_per_line or 1000
127 |
128 | def iterate_subtitles():
129 | line_len = 0
130 | line_count = 1
131 | # the next subtitle to yield (a list of word timings with whitespace)
132 | subtitle: list[dict] = []
133 | last = result["segments"][0]["words"][0]["start"]
134 | for segment in result["segments"]:
135 | chunk_index = 0
136 | words_count = max_words_per_line
137 | while chunk_index < len(segment["words"]):
138 | remaining_words = len(segment["words"]) - chunk_index
139 | if max_words_per_line > len(segment["words"]) - chunk_index:
140 | words_count = remaining_words
141 | for i, original_timing in enumerate(
142 | segment["words"][chunk_index : chunk_index + words_count]
143 | ):
144 | timing = original_timing.copy()
145 | long_pause = (
146 | not preserve_segments and timing["start"] - last > 3.0
147 | )
148 | has_room = line_len + len(timing["word"]) <= max_line_width
149 | seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
150 | if (
151 | line_len > 0
152 | and has_room
153 | and not long_pause
154 | and not seg_break
155 | ):
156 | # line continuation
157 | line_len += len(timing["word"])
158 | else:
159 | # new line
160 | timing["word"] = timing["word"].strip()
161 | if (
162 | len(subtitle) > 0
163 | and max_line_count is not None
164 | and (long_pause or line_count >= max_line_count)
165 | or seg_break
166 | ):
167 | # subtitle break
168 | yield subtitle
169 | subtitle = []
170 | line_count = 1
171 | elif line_len > 0:
172 | # line break
173 | line_count += 1
174 | timing["word"] = "\n" + timing["word"]
175 | line_len = len(timing["word"].strip())
176 | subtitle.append(timing)
177 | last = timing["start"]
178 | chunk_index += max_words_per_line
179 | if len(subtitle) > 0:
180 | yield subtitle
181 |
182 | if len(result["segments"]) > 0 and "words" in result["segments"][0]:
183 | for subtitle in iterate_subtitles():
184 | subtitle_start = self.format_timestamp(subtitle[0]["start"])
185 | subtitle_end = self.format_timestamp(subtitle[-1]["end"])
186 | subtitle_text = "".join([word["word"] for word in subtitle])
187 | if highlight_words:
188 | last = subtitle_start
189 | all_words = [timing["word"] for timing in subtitle]
190 | for i, this_word in enumerate(subtitle):
191 | start = self.format_timestamp(this_word["start"])
192 | end = self.format_timestamp(this_word["end"])
193 | if last != start:
194 | yield last, start, subtitle_text
195 |
196 | yield start, end, "".join(
197 | [
198 | re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
199 | if j == i
200 | else word
201 | for j, word in enumerate(all_words)
202 | ]
203 | )
204 | last = end
205 | else:
206 | yield subtitle_start, subtitle_end, subtitle_text
207 | else:
208 | for segment in result["segments"]:
209 | segment_start = self.format_timestamp(segment["start"])
210 | segment_end = self.format_timestamp(segment["end"])
211 | segment_text = segment["text"].strip().replace("-->", "->")
212 | yield segment_start, segment_end, segment_text
213 |
214 | def format_timestamp(self, seconds: float):
215 | return format_timestamp(
216 | seconds=seconds,
217 | always_include_hours=self.always_include_hours,
218 | decimal_marker=self.decimal_marker,
219 | )
220 |
221 |
222 | class WriteVTT(SubtitlesWriter):
223 | extension: str = "vtt"
224 | always_include_hours: bool = False
225 | decimal_marker: str = "."
226 |
227 | def write_result(
228 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
229 | ):
230 | print("WEBVTT\n", file=file)
231 | for start, end, text in self.iterate_result(result, options, **kwargs):
232 | print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
233 |
234 |
235 | class WriteSRT(SubtitlesWriter):
236 | extension: str = "srt"
237 | always_include_hours: bool = True
238 | decimal_marker: str = ","
239 |
240 | def write_result(
241 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
242 | ):
243 | for i, (start, end, text) in enumerate(
244 | self.iterate_result(result, options, **kwargs), start=1
245 | ):
246 | print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
247 |
248 |
249 | class WriteTSV(ResultWriter):
250 | """
251 | Write a transcript to a file in TSV (tab-separated values) format containing lines like:
252 | <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
253 |
254 | Using integer milliseconds as start and end times means there's no chance of interference from
255 | an environment setting a language encoding that causes the decimal in a floating point number
256 | to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
257 | """
258 |
259 | extension: str = "tsv"
260 |
261 | def write_result(
262 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
263 | ):
264 | print("start", "end", "text", sep="\t", file=file)
265 | for segment in result["segments"]:
266 | print(round(1000 * segment["start"]), file=file, end="\t")
267 | print(round(1000 * segment["end"]), file=file, end="\t")
268 | print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
269 |
270 |
271 | class WriteJSON(ResultWriter):
272 | extension: str = "json"
273 |
274 | def write_result(
275 | self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
276 | ):
277 | json.dump(result, file)
278 |
279 |
280 | def get_writer(
281 | output_format: str, output_dir: str
282 | ) -> Callable[[dict, TextIO, dict], None]:
283 | writers = {
284 | "txt": WriteTXT,
285 | "vtt": WriteVTT,
286 | "srt": WriteSRT,
287 | "tsv": WriteTSV,
288 | "json": WriteJSON,
289 | }
290 |
291 | if output_format == "all":
292 | all_writers = [writer(output_dir) for writer in writers.values()]
293 |
294 | def write_all(
295 | result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
296 | ):
297 | for writer in all_writers:
298 | writer(result, file, options, **kwargs)
299 |
300 | return write_all
301 |
302 | return writers[output_format](output_dir)
303 |
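304 | # Minimal usage sketch with hand-made segment data (illustrative values only):
305 | # writes "demo.srt" into the current directory.
306 | if __name__ == "__main__":
307 | result = {"segments": [{"start": 0.0, "end": 2.5, "text": " Hello there."},
308 | {"start": 2.5, "end": 4.0, "text": " General Kenobi."}]}
309 | get_writer("srt", ".")(result, "demo.wav", {})
310 | print(format_timestamp(3725.5, always_include_hours=True))  # 01:02:05.500
311 |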
--------------------------------------------------------------------------------
/src/ui/add_subtitles.py:
--------------------------------------------------------------------------------
1 | import os
2 | import queue
3 | import threading
4 | import time
5 | import subprocess
6 |
7 | import customtkinter as ctk
8 | from customtkinter import filedialog as fd
9 | from mutagen import File, MutagenError
10 | from pydub import AudioSegment, exceptions
11 | from tkinterdnd2 import TkinterDnD, DND_ALL
12 | from ..whisper.utils import get_writer
13 |
14 | from .ctkAlert import CTkAlert
15 | from .ctkLoader import CTkLoader
16 | from .icons import icons
17 | from .style import FONTS
18 | from ..logic import Transcriber
19 |
20 |
21 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper):
22 | def __init__(self, *args, **kwargs):
23 | super().__init__(*args, **kwargs)
24 | self.TkdndVersion = TkinterDnD._require(self)
25 |
26 |
27 | class AddSubtitlesUI(CTk):
28 | def __init__(self, parent):
29 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0)
30 | self.grid_propagate(False)
31 | self.grid_columnconfigure(0, weight=1)
32 |
33 | title = ctk.CTkLabel(self, text=" Add Subtitles", font=FONTS["title"], image=icons["subtitle"],
34 | compound="left")
35 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w")
36 |
37 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30,
38 | height=30, command=self.hide_transcribe_ui)
39 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e")
40 |
41 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent")
42 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2)
43 | self.main_frame.grid_columnconfigure(0, weight=1)
44 |
45 | self.master = parent
46 | self.drag_drop = None
47 | self.audio_path = None
48 | self.cancel_signal = False
49 | self.loader = None
50 | self.option_menu = None
51 | self.process = None
52 | self.queue = queue.Queue()
53 |
54 | self.default_widget()
55 |
56 | self.grid(row=0, column=0, sticky="nsew")
57 |
58 | def default_widget(self):
59 | label = ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"])
60 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
61 |
62 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250,
63 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent",
64 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"),
65 | font=FONTS["normal"])
66 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew")
67 |
68 | self.drag_drop.drop_target_register(DND_ALL)
69 | self.drag_drop.dnd_bind('<<Drop>>', self.drop)
70 |
71 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14))
72 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew")
73 |
74 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40,
75 | command=self.select_file_callback, font=FONTS["normal"])
76 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew")
77 |
78 | label_1 = ctk.CTkLabel(self.main_frame, text="Supported Formats: MP4, MOV, WMV, AVI",
79 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50,
80 | font=FONTS["small"])
81 | label_1.grid(row=4, column=0, padx=0, pady=20, sticky="sew")
82 |
83 | def task_widget(self):
84 | file_name, duration, file_size = self.get_audio_info(self.audio_path)
85 |
86 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"])
87 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
88 |
89 | frame = ctk.CTkFrame(self.main_frame)
90 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew")
91 | frame.grid_columnconfigure(0, weight=1)
92 |
93 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"])
94 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w")
95 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"])
96 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e")
97 |
98 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"])
99 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w")
100 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"])
101 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e")
102 |
103 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"])
104 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w")
105 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"])
106 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e")
107 |
108 | start_btn = ctk.CTkButton(self.main_frame, text="Start Transcribing", height=40, command=self.start_callback,
109 | font=FONTS["normal"])
110 | start_btn.grid(row=2, column=0, padx=200, pady=20, sticky="nsew")
111 |
112 | def result_widget(self):
113 | result = self.queue.get()
114 | text = str(result["text"]).strip()
115 |
116 | result_label = ctk.CTkLabel(self.main_frame, text="Transcribed Text", font=FONTS["subtitle_bold"])
117 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w")
118 |
119 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"])
120 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2)
121 | textbox.insert("0.0", text=text)
122 |
123 | download_label = ctk.CTkLabel(self.main_frame, text="Download Video with Subtitles", font=FONTS["subtitle_bold"])
124 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), sticky="w")
125 |
126 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.add_subtitle(result),
127 | font=FONTS["normal"], height=35)
128 | download_btn.grid(row=3, column=0, padx=10, pady=20, sticky="nsw")
129 |
130 | def add_subtitle(self, result):
131 | file_name = os.path.splitext(os.path.basename(self.audio_path))[0]
132 |
133 | output_video = fd.asksaveasfilename(
134 | parent=self,
135 | initialfile=f"{file_name}-subtitle",
136 | title="Export video with subtitle",
137 | defaultextension=".mp4",
138 | filetypes=[("Video files", "*.mp4 *.mov *.wmv *.avi")]
139 | )
140 | if output_video:
141 | self.close_btn.configure(state="disabled")
142 | widgets = self.main_frame.winfo_children()
143 | for widget in widgets:
144 | widget.destroy()
145 |
146 | thread = threading.Thread(target=self.subtitle_handler, args=(output_video, result))
147 | thread.start()
148 |
149 | def subtitle_handler(self, output_video, result):
150 | self.loader = CTkLoader(parent=self.master, title="Adding Subtitles", msg="Please wait...",
151 | cancel_func=self.kill_process)
152 | file_name = os.path.splitext(os.path.basename(self.audio_path))[0]
153 | input_video = self.audio_path
154 | writer = get_writer("srt", ".")
155 | subtitle_file_path = f"{file_name}.srt"
156 | writer(result, input_video, {"highlight_words": True, "max_line_count": 50, "max_line_width": 3})
157 |
158 | ffmpeg_command = [
159 | "ffmpeg",
160 | "-y",
161 | "-i", input_video,
162 | "-vf", f"subtitles={subtitle_file_path}",
163 | output_video
164 | ]
165 |
166 | self.process = subprocess.Popen(ffmpeg_command, stdout=subprocess.PIPE, universal_newlines=True)  # pass the argv list directly; shell=True with a list is unreliable across platforms
167 |
168 | self.process.wait()
169 |
170 | self.loader.hide_loader()
171 |
172 | if self.process.returncode == 0:
173 | CTkAlert(parent=self.master, status="success", title="Success",
174 | msg="Subtitle added successfully.")
175 | self.close_btn.configure(state="normal")
176 | self.after(1000, self.default_widget)
177 | else:
178 | CTkAlert(parent=self.master, status="error", title="Error",
179 | msg="Error adding subtitles to the video.")
180 | self.close_btn.configure(state="normal")
181 | self.result_widget()
182 |
183 | os.remove(subtitle_file_path)
184 |
185 | def kill_process(self):
186 | if self.process: self.process.kill()  # cancel callback for the loader: stop ffmpeg if it is still running
187 |
188 | def start_callback(self):
189 | self.close_btn.configure(state="disabled")
190 | widgets = self.main_frame.winfo_children()
191 | for widget in widgets:
192 | widget.destroy()
193 |
194 | self.loader = CTkLoader(parent=self.master, title="Transcribing...", msg="Please wait...",
195 | cancel_func=self.set_signal)
196 | thread = threading.Thread(target=self.start_transcribing, args=(self.audio_path, self.check_signal))
197 | thread.start()
198 |
199 | def start_transcribing(self, audio_path, check_signal):
200 | transcriber = Transcriber(audio=audio_path)
201 | result = transcriber.audio_recognition(cancel_func=check_signal)
202 | self.loader.hide_loader()
203 | if result:
204 | self.queue.put(result)
205 | self.close_btn.configure(state="normal")
206 | self.result_widget()
207 |
208 | def set_signal(self):
209 | self.cancel_signal = True
210 |
211 | def check_signal(self):
212 | original_value = self.cancel_signal
213 |
214 | if self.cancel_signal:
215 | self.cancel_signal = False
216 | self.close_btn.configure(state="normal")
217 | self.after(1000, self.default_widget)
218 |
219 | return original_value
220 |
221 | def select_file_callback(self):
222 | file_path = fd.askopenfilename(
223 | filetypes=[("Video files", "*.mp4 *.mov *.wmv *.avi")])
224 | if file_path:
225 | audio_path = os.path.abspath(file_path)
226 | if self.is_streamable_audio(audio_path):
227 | self.audio_path = audio_path
228 | widgets = self.main_frame.winfo_children()
229 |
230 | for widget in widgets:
231 | widget.destroy()
232 | self.after(1000, self.task_widget)
233 | else:
234 | CTkAlert(parent=self.master, status="error", title="Error",
235 | msg="The chosen audio file is not valid or streamable.")
236 |
237 | @staticmethod
238 | def is_streamable_audio(audio_path):
239 | if not os.path.isfile(audio_path):
240 | return False
241 |
242 | try:
243 | audio = AudioSegment.from_file(audio_path)
244 | audio_info = File(audio_path)
245 | return len(audio) > 0 and audio_info.info.length > 0
246 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError):
247 | return False
248 |
249 | @staticmethod
250 | def get_audio_info(file_path):
251 | try:
252 | if not os.path.isfile(file_path):
253 | return None, None, None
254 |
255 | file_name = os.path.basename(file_path)
256 |
257 | audio_info = File(file_path)
258 | if audio_info is None:
259 | return None, None, None
260 |
261 | duration_seconds = audio_info.info.length
262 |
263 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds))
264 |
265 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
266 |
267 | return file_name, duration_formatted, file_size_mb
268 | except MutagenError:
269 | return None, None, None
270 |
271 | def drop(self, event):
272 | dropped_file = event.data.replace("{", "").replace("}", "")
273 | audio_path = os.path.abspath(dropped_file)
274 | if self.is_streamable_audio(audio_path):
275 | self.audio_path = audio_path
276 | widgets = self.main_frame.winfo_children()
277 |
278 | for widget in widgets:
279 | widget.destroy()
280 |
281 | self.after(1000, self.task_widget)
282 | else:
283 | CTkAlert(parent=self.master, status="error", title="Error",
284 | msg="The chosen audio file is not valid or streamable.")
285 |
286 | def hide_transcribe_ui(self):
287 | self.destroy()
288 |
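289 | # Standalone sketch of the cooperative-cancel pattern used by set_signal /
290 | # check_signal above: the worker polls a callable between units of work and
291 | # stops once it returns True (names below are illustrative):
292 | if __name__ == "__main__":
293 | cancelled = {"flag": False}
294 | chunk = 0
295 | while not cancelled["flag"]:
296 | chunk += 1
297 | if chunk == 3:
298 | cancelled["flag"] = True  # stand-in for the loader's Cancel button firing
299 | print(f"stopped after {chunk} chunks")
300 |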
--------------------------------------------------------------------------------
/src/ui/translate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import queue
3 | import threading
4 | import time
5 |
6 | import customtkinter as ctk
7 | from customtkinter import filedialog as fd
8 | from mutagen import File, MutagenError
9 | from pydub import AudioSegment, exceptions
10 | from tkinterdnd2 import TkinterDnD, DND_ALL
11 |
12 | from .ctkAlert import CTkAlert
13 | from .ctkLoader import CTkLoader
14 | from .ctkdropdown import CTkScrollableDropdownFrame
15 | from .icons import icons
16 | from .style import FONTS, DROPDOWN
17 | from ..logic import Transcriber
18 |
19 |
20 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper):
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.TkdndVersion = TkinterDnD._require(self)
24 |
25 |
26 | languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese', 'aymara', 'azerbaijani',
27 | 'bambara', 'basque', 'belarusian', 'bengali', 'bhojpuri', 'bosnian', 'bulgarian', 'catalan',
28 | 'cebuano', 'chichewa', 'chinese (simplified)', 'chinese (traditional)', 'corsican', 'croatian',
29 | 'czech', 'danish', 'dhivehi', 'dogri', 'dutch', 'english', 'esperanto', 'estonian', 'ewe',
30 | 'filipino', 'finnish', 'french', 'frisian', 'galician', 'georgian', 'german', 'greek', 'guarani',
31 | 'gujarati', 'haitian creole', 'hausa', 'hawaiian', 'hebrew', 'hindi', 'hmong', 'hungarian',
32 | 'icelandic', 'igbo', 'ilocano', 'indonesian', 'irish', 'italian', 'japanese', 'javanese',
33 | 'kannada', 'kazakh', 'khmer', 'kinyarwanda', 'konkani', 'korean', 'krio', 'kurdish (kurmanji)',
34 | 'kurdish (sorani)', 'kyrgyz', 'lao', 'latin', 'latvian', 'lingala', 'lithuanian', 'luganda',
35 | 'luxembourgish', 'macedonian', 'maithili', 'malagasy', 'malay', 'malayalam', 'maltese', 'maori',
36 | 'marathi', 'meiteilon (manipuri)', 'mizo', 'mongolian', 'myanmar', 'nepali', 'norwegian',
37 | 'odia (oriya)', 'oromo', 'pashto', 'persian', 'polish', 'portuguese', 'punjabi', 'quechua',
38 | 'romanian', 'russian', 'samoan', 'sanskrit', 'scots gaelic', 'sepedi', 'serbian', 'sesotho',
39 | 'shona', 'sindhi', 'sinhala', 'slovak', 'slovenian', 'somali', 'spanish', 'sundanese', 'swahili',
40 | 'swedish', 'tajik', 'tamil', 'tatar', 'telugu', 'thai', 'tigrinya', 'tsonga', 'turkish', 'turkmen',
41 | 'twi', 'ukrainian', 'urdu', 'uyghur', 'uzbek', 'vietnamese', 'welsh', 'xhosa', 'yiddish', 'yoruba',
42 | 'zulu']
43 |
44 |
45 | class TranslateUI(CTk):
46 | def __init__(self, parent):
47 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0)
48 | self.grid_propagate(False)
49 | self.grid_columnconfigure(0, weight=1)
50 |
51 | title = ctk.CTkLabel(self, text=" Translate Audio", font=FONTS["title"], image=icons["translate"],
52 | compound="left")
53 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w")
54 |
55 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30,
56 | height=30, command=self.hide_transcribe_ui)
57 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e")
58 |
59 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent")
60 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2)
61 | self.main_frame.grid_columnconfigure(0, weight=1)
62 |
63 | self.master = parent
64 | self.drag_drop = None
65 | self.audio_path = None
66 | self.cancel_signal = False
67 | self.loader = None
68 | self.lang_dropdown = None
69 | self.queue = queue.Queue()
70 |
71 | self.default_widget()
72 |
73 | self.grid(row=0, column=0, sticky="nsew")
74 |
75 | def default_widget(self):
76 | label = ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"])
77 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
78 |
79 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250,
80 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent",
81 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"),
82 | font=FONTS["normal"])
83 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew")
84 |
85 | self.drag_drop.drop_target_register(DND_ALL)
86 | self.drag_drop.dnd_bind('<<Drop>>', self.drop)
87 |
88 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14))
89 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew")
90 |
91 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40,
92 | command=self.select_file_callback, font=FONTS["normal"])
93 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew")
94 |
95 | label_1 = ctk.CTkLabel(self.main_frame, text="Supported Formats: WAV, MP3, OGG, FLAC, MP4, MOV, WMV, AVI",
96 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50,
97 | font=FONTS["small"])
98 | label_1.grid(row=4, column=0, padx=0, pady=20, sticky="sew")
99 |
100 | def task_widget(self):
101 | file_name, duration, file_size = self.get_audio_info(self.audio_path)
102 |
103 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"])
104 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
105 |
106 | frame = ctk.CTkFrame(self.main_frame)
107 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
108 | frame.grid_columnconfigure(0, weight=1)
109 |
110 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"])
111 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w")
112 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"])
113 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e")
114 |
115 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"])
116 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w")
117 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"])
118 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e")
119 |
120 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"])
121 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w")
122 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"])
123 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e")
124 |
125 | translate_to = ctk.CTkLabel(self.main_frame, text="Translate To", font=FONTS["normal"])
126 | translate_to.grid(row=2, column=0, padx=40, pady=20, sticky="w")
127 | self.lang_dropdown = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"])
128 | self.lang_dropdown.grid(row=2, column=1, padx=40, pady=20, sticky="e")
129 | CTkScrollableDropdownFrame(self.lang_dropdown, values=languages, **DROPDOWN)
130 | self.lang_dropdown.set(languages[0])
131 |
132 | start_btn = ctk.CTkButton(self.main_frame, text="Start Translating", height=40, command=self.start_callback,
133 | font=FONTS["normal"])
134 | start_btn.grid(row=3, column=0, padx=200, pady=40, sticky="nsew", columnspan=2)
135 |
136 | def result_widget(self):
137 | result = self.queue.get()
138 |
139 | result_label = ctk.CTkLabel(self.main_frame, text="Translated Text", font=FONTS["subtitle_bold"])
140 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w")
141 |
142 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"])
143 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2)
144 | textbox.insert("0.0", text=result)
145 |
146 | download_label = ctk.CTkLabel(self.main_frame, text="Download Text", font=FONTS["subtitle_bold"])
147 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), sticky="w")
148 |
149 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.save_text(result),
150 | font=FONTS["normal"], height=35)
151 | download_btn.grid(row=3, column=0, padx=10, pady=20, sticky="nsw")
152 |
153 | def save_text(self, result):
154 | file_name = os.path.basename(self.audio_path)
155 | sep = "."
156 | file_name = file_name.split(sep, 1)[0]
157 |
158 | file_path = fd.asksaveasfilename(
159 | parent=self,
160 | initialfile=file_name,
161 | title="Export text",
162 | defaultextension=".txt",
163 | filetypes=[("Text Files", "*.txt")]
164 | )
165 |
166 | if file_path:
167 | with open(file_path, 'w', encoding="utf-8") as file:
168 | file.write(result)
169 |
170 | def start_callback(self):
171 | self.close_btn.configure(state="disabled")
172 | widgets = self.main_frame.winfo_children()
173 | for widget in widgets:
174 | widget.destroy()
175 |
176 |         self.loader = CTkLoader(parent=self.master, title="Translating...", msg="Please wait...",
177 |                                 cancel_func=self.set_signal)
178 |         thread = threading.Thread(target=self.start_translating, args=(self.audio_path, self.check_signal))
179 |         thread.start()
180 | 
181 |     def start_translating(self, audio_path, check_signal):
182 |         language = self.lang_dropdown.get()  # CTkOptionMenu.get() returns the stored value, so it still works after the widgets were destroyed above
183 |         transcriber = Transcriber(audio=audio_path)
184 |         result = transcriber.translate_audio(cancel_func=check_signal, to_language=language)
185 | self.loader.hide_loader()
186 | if result:
187 | self.queue.put(result)
188 | self.close_btn.configure(state="normal")
189 | self.result_widget()
190 |
191 | def set_signal(self):
192 | self.cancel_signal = True
193 |
194 | def check_signal(self):
195 | original_value = self.cancel_signal
196 |
197 | if self.cancel_signal:
198 | self.cancel_signal = False
199 | self.close_btn.configure(state="normal")
200 | self.after(1000, self.default_widget)
201 |
202 | return original_value
203 |
204 | def select_file_callback(self):
205 | file_path = fd.askopenfilename(
206 | filetypes=[("Audio files", "*.mp3 *.wav *.ogg *.flac"), ("Video files", "*.mp4 *.mov *.wmv *.avi")])
207 | if file_path:
208 | audio_path = os.path.abspath(file_path)
209 | if self.is_streamable_audio(audio_path):
210 | self.audio_path = audio_path
211 | widgets = self.main_frame.winfo_children()
212 |
213 | for widget in widgets:
214 | widget.destroy()
215 | self.after(1000, self.task_widget)
216 | else:
217 | CTkAlert(parent=self.master, status="error", title="Error",
218 | msg="The chosen audio file is not valid or streamable.")
219 |
220 | @staticmethod
221 | def is_streamable_audio(audio_path):
222 | if not os.path.isfile(audio_path):
223 | return False
224 |
225 | try:
226 | audio = AudioSegment.from_file(audio_path)
227 | audio_info = File(audio_path)
228 |             return audio_info is not None and len(audio) > 0 and audio_info.info.length > 0  # mutagen.File() returns None for unrecognized formats
229 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError):
230 | return False
231 |
232 | @staticmethod
233 | def get_audio_info(file_path):
234 | try:
235 | if not os.path.isfile(file_path):
236 |                 return None, None, None  # this function's callers unpack three values
237 |
238 | file_name = os.path.basename(file_path)
239 |
240 | audio_info = File(file_path)
241 | if audio_info is None:
242 |                 return None, None, None
243 |
244 | duration_seconds = audio_info.info.length
245 |
246 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds))
247 |
248 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
249 |
250 | return file_name, duration_formatted, file_size_mb
251 | except MutagenError:
252 |             return None, None, None
253 |
254 | def drop(self, event):
255 | dropped_file = event.data.replace("{", "").replace("}", "")
256 | audio_path = os.path.abspath(dropped_file)
257 | if self.is_streamable_audio(audio_path):
258 | self.audio_path = audio_path
259 | widgets = self.main_frame.winfo_children()
260 |
261 | for widget in widgets:
262 | widget.destroy()
263 |
264 | self.after(1000, self.task_widget)
265 | else:
266 | CTkAlert(parent=self.master, status="error", title="Error",
267 | msg="The chosen audio file is not valid or streamable.")
268 |
269 | def hide_transcribe_ui(self):
270 | self.destroy()
271 |
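272 | 
273 | # A minimal sketch of the cancel handshake used above: the loader's cancel
274 | # button sets a one-shot flag (set_signal), and the worker polls it between
275 | # chunks (check_signal). Illustrative only, with hypothetical names:
276 | #
277 | #     import threading, time
278 | #
279 | #     class CancelFlag:
280 | #         def __init__(self):
281 | #             self._cancelled = False
282 | #
283 | #         def set(self):           # UI thread: cancel button pressed
284 | #             self._cancelled = True
285 | #
286 | #         def check(self):         # worker thread: returns True once, then resets
287 | #             was_cancelled, self._cancelled = self._cancelled, False
288 | #             return was_cancelled
289 | #
290 | #     flag = CancelFlag()
291 | #
292 | #     def worker():
293 | #         for _ in range(100):     # stand-in for per-segment translation work
294 | #             if flag.check():     # mirrors translate_audio(cancel_func=...)
295 | #                 return
296 | #             time.sleep(0.05)
297 | #
298 | #     threading.Thread(target=worker).start()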
--------------------------------------------------------------------------------
/src/ui/live_transcribe.py:
--------------------------------------------------------------------------------
1 | import random
2 | import threading
3 | import tkinter as tk
4 | from datetime import datetime, timedelta
5 | from queue import Queue
6 | from time import sleep
7 |
8 | import customtkinter as ctk
9 | import numpy as np
10 | import speech_recognition as sr
11 | import torch
12 | from .. import whisper
13 |
14 | from .ctk_tooltip import CTkToolTip
15 | from .ctkdropdown import CTkScrollableDropdownFrame
16 | from .icons import icons
17 | from .style import FONTS, DROPDOWN
18 |
19 |
20 | class TkAudioVisualizer(tk.Frame):
21 | def __init__(self,
22 | master: any,
23 | gradient=None,
24 | bar_color: str = "white",
25 | bar_width: int = 7,
26 | **kwargs):
27 | tk.Frame.__init__(self, master)
28 | if gradient is None:
29 | gradient = ["cyan", "blue"]
30 | self.viz = DrawBars(self, gradient[0], gradient[1], bar_width, bar_color, relief="sunken", **kwargs)
31 | self.viz.pack(fill="both", expand=True)
32 |
33 | def start(self):
34 |         """ start the visualizer """
35 | if not self.viz._running:
36 | self.viz._running = True
37 | self.viz.update()
38 |
39 | def stop(self):
40 | """ stop the visualizer """
41 | self.viz._running = False
42 |
43 |
44 | class DrawBars(tk.Canvas):
45 |     '''A canvas that draws an animated gradient bar spectrum as the background'''
46 |
47 | def __init__(self, parent, color1, color2, bar_width, bar_color, **kwargs):
48 | tk.Canvas.__init__(self, parent, bg=bar_color, bd=0, highlightthickness=0, **kwargs)
49 | self._color1 = color1
50 | self._color2 = color2
51 | self._bar_width = bar_width
52 | self._running = False
53 | self.after(100, lambda: self._draw_gradient())
54 |         self.bind("<Configure>", lambda e: self._draw_gradient() if not self._running else None)  # redraw the static bars when the canvas is resized
55 |
56 | def _draw_gradient(self, event=None):
57 | '''Draw the gradient spectrum '''
58 | self.delete("gradient")
59 | width = self.winfo_width()
60 | height = self.winfo_height()
61 | limit = width + 10
62 |
63 | (r1, g1, b1) = self.winfo_rgb(self._color1)
64 | (r2, g2, b2) = self.winfo_rgb(self._color2)
65 | r_ratio = float(r2 - r1) / limit
66 | g_ratio = float(g2 - g1) / limit
67 | b_ratio = float(b2 - b1) / limit
68 |
69 | for i in range(0, limit, self._bar_width):
70 | bar_height = random.randint(int(limit / 8), int(limit / 2.5))
71 | if not self._running:
72 | bar_height = height
73 | nr = int(r1 + (r_ratio * i))
74 | ng = int(g1 + (g_ratio * i))
75 | nb = int(b1 + (b_ratio * i))
76 |
77 | color = "#%4.4x%4.4x%4.4x" % (nr, ng, nb)
78 | self.create_line(i, 0, i, bar_height, tags=("gradient",), width=self._bar_width, fill=color)
79 |
80 | self.lower("gradient")
81 |
82 | if self._running:
83 | self.after(150, self._draw_gradient)
84 |
85 | def update(self):
86 | self._draw_gradient()
87 |
88 |
89 | languages = [
90 | "English",
91 | "Chinese",
92 | "German",
93 | "Spanish",
94 | "Russian",
95 | "Korean",
96 | "French",
97 | "Japanese",
98 | "Portuguese",
99 | "Turkish",
100 | "Polish",
101 | "Catalan",
102 | "Dutch",
103 | "Arabic",
104 | "Swedish",
105 | "Italian",
106 | "Indonesian",
107 | "Hindi",
108 | "Finnish",
109 | "Vietnamese",
110 | "Hebrew",
111 | "Ukrainian",
112 | "Greek",
113 | "Malay",
114 | "Czech",
115 | "Romanian",
116 | "Danish",
117 | "Hungarian",
118 | "Tamil",
119 | "Norwegian",
120 | "Thai",
121 | "Urdu",
122 | "Croatian",
123 | "Bulgarian",
124 | "Lithuanian",
125 | "Latin",
126 | "Maori",
127 | "Malayalam",
128 | "Welsh",
129 | "Slovak",
130 | "Telugu",
131 | "Persian",
132 | "Latvian",
133 | "Bengali",
134 | "Serbian",
135 | "Azerbaijani",
136 | "Slovenian",
137 | "Kannada",
138 | "Estonian",
139 | "Macedonian",
140 | "Breton",
141 | "Basque",
142 | "Icelandic",
143 | "Armenian",
144 | "Nepali",
145 | "Mongolian",
146 | "Bosnian",
147 | "Kazakh",
148 | "Albanian",
149 | "Swahili",
150 | "Galician",
151 | "Marathi",
152 | "Punjabi",
153 | "Sinhala",
154 | "Khmer",
155 | "Shona",
156 | "Yoruba",
157 | "Somali",
158 | "Afrikaans",
159 | "Occitan",
160 | "Georgian",
161 | "Belarusian",
162 | "Tajik",
163 | "Sindhi",
164 | "Gujarati",
165 | "Amharic",
166 | "Yiddish",
167 | "Lao",
168 | "Uzbek",
169 | "Faroese",
170 | "Haitian creole",
171 | "Pashto",
172 | "Turkmen",
173 | "Nynorsk",
174 | "Maltese",
175 | "Sanskrit",
176 | "Luxembourgish",
177 | "Myanmar",
178 | "Tibetan",
179 | "Tagalog",
180 | "Malagasy",
181 | "Assamese",
182 | "Tatar",
183 | "Hawaiian",
184 | "Lingala",
185 | "Hausa",
186 | "Bashkir",
187 | "Javanese",
188 | "Sundanese",
189 | ]
190 |
191 |
192 | class LiveTranscribeUI(ctk.CTkFrame):
193 | def __init__(self, parent):
194 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0)
195 | self.energy_threshold_value = None
196 | self.language_value = None
197 | self.model_size_value = None
198 | self.tooltip = None
199 | self.language_dropdown = None
200 | self.model_dropdown = None
201 | self.mic_dropdown = None
202 | self.mic_btn = None
203 | self.recording = False
204 | self.close_btn = None
205 | self.main_frame = None
206 | self.input_value = None
207 | self.text_box = None
208 | self.record_button = None
209 | self.animation_frame = None
210 | self.transcribe_thread = None
211 | self.record_thread = None
212 | self.paa_thread = None
213 |
214 | self.selected_model = "base"
215 | self.selected_language = "english"
216 |
217 | self.audio_model = whisper.load_model(self.selected_model)
218 | self.phrase_time = None
219 | self.data_queue = Queue()
220 | self.recorder = sr.Recognizer()
221 | self.transcription = ['']
222 |
223 | self.energy_threshold = 500
224 | self.recorder.energy_threshold = self.energy_threshold
225 | self.recorder.dynamic_energy_threshold = False
226 |
227 | self.source = sr.Microphone(sample_rate=16000)
228 |
229 | self.record_timeout = 2
230 | self.phrase_timeout = 3
231 |
232 | self.grid_propagate(False)
233 | self.grid_columnconfigure(0, weight=1)
234 | self.ui()
235 | self.grid(row=0, column=0, sticky="nsew")
236 |
237 | def ui(self):
238 | title = ctk.CTkLabel(self, text=" Live Transcribe", font=FONTS["title"], image=icons["microphone"],
239 | compound="left")
240 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w")
241 |
242 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False,
243 | width=30,
244 | height=30, command=self.hide_live_transcribe_ui)
245 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e")
246 |
247 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent")
248 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2)
249 | self.main_frame.grid_columnconfigure(0, weight=1)
250 |
251 | models = ["Tiny", "Base", "Small", "Medium"]
252 | model_size_label = ctk.CTkLabel(self.main_frame, text="Model Size", font=FONTS["normal"])
253 | model_size_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
254 | self.model_size_value = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"])
255 | self.model_size_value.grid(row=0, column=1, padx=20, pady=10, sticky="e")
256 | self.model_dropdown = CTkScrollableDropdownFrame(self.model_size_value, values=models, **DROPDOWN,
257 | command=self.model_option_callback)
258 |         self.model_size_value.set(models[1])  # show "Base" so the menu matches the preloaded base model
259 |
260 | language_label = ctk.CTkLabel(self.main_frame, text="Language", font=FONTS["normal"])
261 | language_label.grid(row=1, column=0, padx=20, pady=10, sticky="w")
262 | self.language_value = ctk.CTkOptionMenu(self.main_frame, font=FONTS["small"])
263 | self.language_value.grid(row=1, column=1, padx=20, pady=10, sticky="e")
264 | self.language_dropdown = CTkScrollableDropdownFrame(self.language_value, values=languages, **DROPDOWN,
265 | command=self.language_option_callback)
266 | self.language_value.set(languages[0])
267 |
268 | energy_threshold_label = ctk.CTkLabel(self.main_frame, text="Energy Threshold", font=FONTS["normal"])
269 | energy_threshold_label.grid(row=2, column=0, padx=20, pady=10, sticky="w")
270 | self.energy_threshold_value = ctk.CTkSlider(self.main_frame, from_=100, to=1000, number_of_steps=50,
271 | command=self.slider_event)
272 | self.energy_threshold_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")
273 | self.energy_threshold_value.set(500)
274 | self.tooltip = CTkToolTip(widget=self.energy_threshold_value, message="500")
275 |
276 | self.record_button = ctk.CTkButton(self.main_frame, text="Start Transcribing",
277 | height=35, command=self.start_callback)
278 | self.record_button.grid(row=3, column=0, padx=100, pady=20, columnspan=2, sticky="nsew")
279 |
280 | ctk.CTkLabel(self.main_frame, text="Transcribed text", font=FONTS["subtitle"]).grid(row=4, column=0, sticky="w",
281 | padx=0, pady=(20, 5))
282 |
283 | self.text_box = ctk.CTkTextbox(self.main_frame, height=260, font=FONTS["normal"])
284 | self.text_box.grid(row=5, column=0, sticky="nsew", pady=(5, 20), columnspan=2)
285 |
286 |     def start_callback(self):
287 |         self.pause_play()
288 |         if not self.recording:
289 |             return  # "Stop Transcribing" was pressed; don't spawn another pair of worker threads
290 |         self.record_thread = threading.Thread(target=self.start_recording, daemon=True)
291 |         self.record_thread.start()
292 |         self.paa_thread = threading.Thread(target=self.process_audio_data, daemon=True)
293 |         self.paa_thread.start()
294 |
295 | def process_audio_data(self):
296 |         while self.recording:  # exits once pause_play() flips the flag off
297 | now = datetime.utcnow()
298 |
299 | if not self.data_queue.empty():
300 | phrase_complete = False
301 |
302 | if self.phrase_time and now - self.phrase_time > timedelta(seconds=self.phrase_timeout):
303 | phrase_complete = True
304 |
305 | self.phrase_time = now
306 |
307 | audio_data = b''.join(self.data_queue.queue)
308 | self.data_queue.queue.clear()
309 |
310 | audio_np = np.frombuffer(audio_data, dtype=np.int16).astype(np.float32) / 32768.0
311 |
312 | result = self.audio_model.transcribe(audio_np, fp16=torch.cuda.is_available(),
313 | cancel_func=self.cancel_callback, language=self.selected_language)
314 | text = result['text'].strip()
315 |
316 | if phrase_complete:
317 | self.transcription.append(text)
318 | else:
319 | self.transcription[-1] = text
320 |
321 | self.update_text()
322 |
323 | sleep(0.25)
324 |
325 | def record_callback(self, _, audio: sr.AudioData) -> None:
326 | data = audio.get_raw_data()
327 | self.data_queue.put(data)
328 |
329 | def start_recording(self):
330 |         self.recorder.listen_in_background(self.source, self.record_callback, phrase_time_limit=self.record_timeout)  # NOTE: the stop-function returned here is discarded, so the background listener keeps capturing once started
331 |
332 | def update_text(self):
333 | text_to_display = ' '.join(self.transcription)
334 | self.text_box.delete("0.0", "end")
335 | self.text_box.insert("0.0", text_to_display)
336 |
337 | def pause_play(self):
338 | if not self.recording:
339 | self.recording = True
340 | self.record_button.configure(text="Stop Transcribing")
341 | else:
342 | self.recording = False
343 | self.record_button.configure(text="Start Transcribing")
344 |
345 | def slider_event(self, value):
346 | value = int(value)
347 | self.energy_threshold = value
348 | self.recorder.energy_threshold = value
349 | self.tooltip.configure(message=value)
350 |
351 | def model_option_callback(self, value):
352 | value1 = str(value).lower()
353 | self.selected_model = value1
354 | self.audio_model = whisper.load_model(self.selected_model)
355 | self.model_size_value.set(value)
356 |
357 | def language_option_callback(self, value):
358 | value1 = str(value).lower()
359 | self.selected_language = value1
360 | self.language_value.set(value)
361 |
362 | def cancel_callback(self):
363 | pass
364 |
365 | def hide_live_transcribe_ui(self):
366 | self.destroy()
367 |
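368 | 
369 | # The pipeline above is: microphone -> SpeechRecognition background listener
370 | # -> byte queue -> int16-to-float32 numpy conversion -> Whisper. A condensed,
371 | # UI-free sketch of the same loop (assumes a working default microphone;
372 | # illustrative only):
373 | #
374 | #     recorder = sr.Recognizer()
375 | #     recorder.energy_threshold = 500
376 | #     recorder.dynamic_energy_threshold = False
377 | #     source = sr.Microphone(sample_rate=16000)
378 | #     data_queue = Queue()
379 | #     model = whisper.load_model("base")
380 | #
381 | #     recorder.listen_in_background(
382 | #         source, lambda _, audio: data_queue.put(audio.get_raw_data()),
383 | #         phrase_time_limit=2)
384 | #
385 | #     while True:
386 | #         if not data_queue.empty():
387 | #             raw = b''.join(data_queue.queue); data_queue.queue.clear()
388 | #             audio_np = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768.0
389 | #             print(model.transcribe(audio_np, fp16=torch.cuda.is_available())["text"])
390 | #         sleep(0.25)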
--------------------------------------------------------------------------------
/src/ui/transcribe.py:
--------------------------------------------------------------------------------
1 | import os
2 | import queue
3 | import threading
4 | import time
5 |
6 | import customtkinter as ctk
7 | from customtkinter import filedialog as fd
8 | from mutagen import File, MutagenError
9 | from pydub import AudioSegment, exceptions
10 | from tkinterdnd2 import TkinterDnD, DND_ALL
11 | from ..whisper.utils import get_writer
12 | from .ctkAlert import CTkAlert
13 | from .ctkLoader import CTkLoader
14 | from .ctkdropdown import CTkScrollableDropdownFrame
15 | from .icons import icons
16 | from .style import FONTS, DROPDOWN
17 | from ..logic import Transcriber
18 |
19 |
20 | class CTk(ctk.CTkFrame, TkinterDnD.DnDWrapper):
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.TkdndVersion = TkinterDnD._require(self)
24 |
25 |
26 | class TranscribeUI(CTk):
27 | def __init__(self, parent):
28 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0)
29 | self.grid_propagate(False)
30 | self.grid_columnconfigure(0, weight=1)
31 |
32 | title = ctk.CTkLabel(self, text=" Transcribe Audio", font=FONTS["title"], image=icons["audio_file"],
33 | compound="left")
34 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w")
35 |
36 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False, width=30,
37 | height=30, command=self.hide_transcribe_ui)
38 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e")
39 |
40 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent")
41 | self.main_frame.grid(row=1, column=0, padx=20, pady=20, sticky="nsew", columnspan=2)
42 | self.main_frame.grid_columnconfigure(0, weight=1)
43 |
44 | self.master = parent
45 | self.drag_drop = None
46 | self.audio_path = None
47 | self.cancel_signal = False
48 | self.loader = None
49 | self.option_menu = None
50 | self.queue = queue.Queue()
51 |
52 | self.default_widget()
53 |
54 | self.grid(row=0, column=0, sticky="nsew")
55 |
56 | def default_widget(self):
57 | label = ctk.CTkLabel(self.main_frame, text="No File Selected", font=FONTS["subtitle_bold"])
58 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
59 |
60 | self.drag_drop = ctk.CTkButton(self.main_frame, text="➕ \nDrag & Drop Here", width=500, height=250,
61 | text_color=("#000000", "#DFE1E5"), hover=False, fg_color="transparent",
62 | border_width=2, corner_radius=5, border_color=("#D3D5DB", "#2B2D30"),
63 | font=FONTS["normal"])
64 | self.drag_drop.grid(row=1, column=0, padx=0, pady=10, sticky="nsew")
65 |
66 | self.drag_drop.drop_target_register(DND_ALL)
67 |         self.drag_drop.dnd_bind('<<Drop>>', self.drop)
68 |
69 | label_or = ctk.CTkLabel(self.main_frame, text="Or", font=("", 14))
70 | label_or.grid(row=2, column=0, padx=0, pady=5, sticky="nsew")
71 |
72 | select_btn = ctk.CTkButton(self.main_frame, text="Browse Files", width=150, height=40,
73 | command=self.select_file_callback, font=FONTS["normal"])
74 | select_btn.grid(row=3, column=0, padx=200, pady=10, sticky="nsew")
75 |
76 |         label_1 = ctk.CTkLabel(self.main_frame, text="Supported Formats: WAV, MP3, OGG, FLAC, MP4, MOV, WMV, AVI",
77 | fg_color=("#D3D5DB", "#2B2D30"), corner_radius=5, width=400, height=50,
78 | font=FONTS["small"])
79 | label_1.grid(row=4, column=0, padx=0, pady=20, sticky="sew")
80 |
81 | def task_widget(self):
82 | file_name, duration, file_size = self.get_audio_info(self.audio_path)
83 |
84 | label = ctk.CTkLabel(self.main_frame, text="Selected File", font=FONTS["subtitle_bold"])
85 | label.grid(row=0, column=0, padx=0, pady=(20, 5), sticky="w")
86 |
87 | frame = ctk.CTkFrame(self.main_frame)
88 | frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew")
89 | frame.grid_columnconfigure(0, weight=1)
90 |
91 | label_1 = ctk.CTkLabel(frame, text="File Name", font=FONTS["normal"])
92 | label_1.grid(row=0, column=0, padx=20, pady=(20, 5), sticky="w")
93 | label_1_value = ctk.CTkLabel(frame, text=file_name, font=FONTS["small"])
94 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 5), sticky="e")
95 |
96 | label_2 = ctk.CTkLabel(frame, text="Duration", font=FONTS["normal"])
97 | label_2.grid(row=1, column=0, padx=20, pady=5, sticky="w")
98 | label_2_value = ctk.CTkLabel(frame, text=duration, font=FONTS["small"])
99 | label_2_value.grid(row=1, column=1, padx=20, pady=5, sticky="e")
100 |
101 | label_3 = ctk.CTkLabel(frame, text="Size", font=FONTS["normal"])
102 | label_3.grid(row=2, column=0, padx=20, pady=(5, 20), sticky="w")
103 | label_3_value = ctk.CTkLabel(frame, text=f"{file_size:.2f} MB", font=FONTS["small"])
104 | label_3_value.grid(row=2, column=1, padx=20, pady=(5, 20), sticky="e")
105 |
106 | start_btn = ctk.CTkButton(self.main_frame, text="Start Transcribing", height=40, command=self.start_callback,
107 | font=FONTS["normal"])
108 | start_btn.grid(row=2, column=0, padx=200, pady=20, sticky="nsew")
109 |
110 | def result_widget(self):
111 | result = self.queue.get()
112 | text = str(result["text"]).strip()
113 |
114 | result_label = ctk.CTkLabel(self.main_frame, text="Transcribed Text", font=FONTS["subtitle_bold"])
115 | result_label.grid(row=0, column=0, padx=10, pady=(20, 5), sticky="w")
116 |
117 | textbox = ctk.CTkTextbox(self.main_frame, width=580, height=200, border_width=2, font=FONTS["normal"])
118 | textbox.grid(row=1, column=0, padx=10, pady=(5, 20), sticky="nsew", columnspan=2)
119 | textbox.insert("0.0", text=text)
120 |
121 | download_label = ctk.CTkLabel(self.main_frame, text="Download Text and Subtitles", font=FONTS["subtitle_bold"])
122 | download_label.grid(row=2, column=0, padx=10, pady=(20, 5), sticky="w")
123 |
124 | self.option_menu = ctk.CTkOptionMenu(self.main_frame, width=200)
125 | self.option_menu.grid(row=3, column=0, padx=10, pady=(5, 20), sticky="w")
126 | format_values = ["Text File (txt)",
127 | "Subtitles (SRT)",
128 | "WebVTT (VTT)",
129 | "Tab-Separated Values (TSV)",
130 | "JSON File (json)",
131 | "Save as all extensions"]
132 | CTkScrollableDropdownFrame(self.option_menu, values=format_values, **DROPDOWN)
133 | self.option_menu.set(format_values[0])
134 |
135 | download_btn = ctk.CTkButton(self.main_frame, text="Download", command=lambda: self.save_text(result),
136 | font=FONTS["normal"], height=35)
137 | download_btn.grid(row=4, column=0, padx=10, pady=20, sticky="nsw")
138 |
139 | def save_text(self, result):
140 | file_name = os.path.basename(self.audio_path)
141 | sep = "."
142 | file_name = file_name.split(sep, 1)[0]
143 |
144 | selected_extension = self.option_menu.get()
145 |
146 | file_extension_map = {
147 | "Text File (txt)": ".txt",
148 | "Subtitles (SRT)": ".srt",
149 | "WebVTT (VTT)": ".vtt",
150 | "Tab-Separated Values (TSV)": ".tsv",
151 | "JSON File (json)": ".json",
152 | "Save as all extensions": ".all",
153 | }
154 |
155 | file_extension = file_extension_map[selected_extension]
156 |
157 | file_path = fd.asksaveasfilename(
158 | parent=self,
159 | initialfile=file_name,
160 | title="Export subtitle",
161 | defaultextension=file_extension,
162 | filetypes=[(f"{selected_extension} Files", "*" + file_extension)]
163 | )
164 |
165 | if file_path:
166 | dir_name, get_file_name = os.path.split(file_path)
167 |
168 | if file_extension == ".srt":
169 | writer = get_writer("srt", dir_name)
170 | writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50, "max_line_width": 3})
171 | elif file_extension == ".txt":
172 | txt_writer = get_writer("txt", dir_name)
173 | txt_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50,
174 | "max_line_width": 3})
175 | elif file_extension == ".vtt":
176 | vtt_writer = get_writer("vtt", dir_name)
177 | vtt_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50,
178 | "max_line_width": 3})
179 | elif file_extension == ".tsv":
180 | tsv_writer = get_writer("tsv", dir_name)
181 | tsv_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50,
182 | "max_line_width": 3})
183 | elif file_extension == ".json":
184 | json_writer = get_writer("json", dir_name)
185 | json_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50,
186 | "max_line_width": 3})
187 | elif file_extension == ".all":
188 | all_writer = get_writer("all", dir_name)
189 | all_writer(result, self.audio_path, {"highlight_words": True, "max_line_count": 50,
190 | "max_line_width": 3})
191 |
192 | def start_callback(self):
193 | self.close_btn.configure(state="disabled")
194 | widgets = self.main_frame.winfo_children()
195 | for widget in widgets:
196 | widget.destroy()
197 |
198 | self.loader = CTkLoader(parent=self.master, title="Transcribing...", msg="Please wait...",
199 | cancel_func=self.set_signal)
200 | thread = threading.Thread(target=self.start_transcribing, args=(self.audio_path, self.check_signal))
201 | thread.start()
202 |
203 | def start_transcribing(self, audio_path, check_signal):
204 | transcriber = Transcriber(audio=audio_path)
205 | result = transcriber.audio_recognition(cancel_func=check_signal)
206 | self.loader.hide_loader()
207 | if result:
208 | self.queue.put(result)
209 | self.close_btn.configure(state="normal")
210 | self.result_widget()
211 |
212 | def set_signal(self):
213 | self.cancel_signal = True
214 |
215 | def check_signal(self):
216 | original_value = self.cancel_signal
217 |
218 | if self.cancel_signal:
219 | self.cancel_signal = False
220 | self.close_btn.configure(state="normal")
221 | self.after(1000, self.default_widget)
222 |
223 | return original_value
224 |
225 | def select_file_callback(self):
226 | file_path = fd.askopenfilename(
227 | filetypes=[("Audio files", "*.mp3 *.wav *.ogg *.flac"), ("Video files", "*.mp4 *.mov *.wmv *.avi")])
228 | if file_path:
229 | audio_path = os.path.abspath(file_path)
230 | if self.is_streamable_audio(audio_path):
231 | self.audio_path = audio_path
232 | widgets = self.main_frame.winfo_children()
233 |
234 | for widget in widgets:
235 | widget.destroy()
236 | self.after(1000, self.task_widget)
237 | else:
238 | CTkAlert(parent=self.master, status="error", title="Error",
239 | msg="The chosen audio file is not valid or streamable.")
240 |
241 | @staticmethod
242 | def is_streamable_audio(audio_path):
243 | if not os.path.isfile(audio_path):
244 | return False
245 |
246 | try:
247 | audio = AudioSegment.from_file(audio_path)
248 | audio_info = File(audio_path)
249 |             return audio_info is not None and len(audio) > 0 and audio_info.info.length > 0
250 | except (FileNotFoundError, exceptions.CouldntDecodeError, MutagenError):
251 | return False
252 |
253 | @staticmethod
254 | def get_audio_info(file_path):
255 | try:
256 | if not os.path.isfile(file_path):
257 |                 return None, None, None
258 |
259 | file_name = os.path.basename(file_path)
260 |
261 | audio_info = File(file_path)
262 | if audio_info is None:
263 |                 return None, None, None
264 |
265 | duration_seconds = audio_info.info.length
266 |
267 | duration_formatted = time.strftime("%H:%M:%S", time.gmtime(duration_seconds))
268 |
269 | file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
270 |
271 | return file_name, duration_formatted, file_size_mb
272 | except MutagenError:
273 |             return None, None, None
274 |
275 | def drop(self, event):
276 | dropped_file = event.data.replace("{", "").replace("}", "")
277 | audio_path = os.path.abspath(dropped_file)
278 | if self.is_streamable_audio(audio_path):
279 | self.audio_path = audio_path
280 | widgets = self.main_frame.winfo_children()
281 |
282 | for widget in widgets:
283 | widget.destroy()
284 |
285 | self.after(1000, self.task_widget)
286 | else:
287 | CTkAlert(parent=self.master, status="error", title="Error",
288 | msg="The chosen audio file is not valid or streamable.")
289 |
290 | def hide_transcribe_ui(self):
291 | self.destroy()
292 |
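293 | 
294 | # save_text() above maps the chosen menu entry to one of whisper's output
295 | # writers. The writer API on its own, with a hypothetical output directory and
296 | # input file (illustrative only; `result` is the dict the transcription returns):
297 | #
298 | #     writer = get_writer("srt", "/tmp/output")   # also: "txt", "vtt", "tsv", "json", "all"
299 | #     writer(result, "input.mp3",                 # -> writes /tmp/output/input.srt
300 | #            {"highlight_words": True, "max_line_count": 50, "max_line_width": 3})
301 | #     # "all" writes every supported format for the input file at once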
--------------------------------------------------------------------------------
/src/whisper/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | import string
4 | from dataclasses import dataclass, field
5 | from functools import cached_property, lru_cache
6 | from typing import Dict, List, Optional, Tuple
7 |
8 | import tiktoken
9 |
10 | LANGUAGES = {
11 | "en": "english",
12 | "zh": "chinese",
13 | "de": "german",
14 | "es": "spanish",
15 | "ru": "russian",
16 | "ko": "korean",
17 | "fr": "french",
18 | "ja": "japanese",
19 | "pt": "portuguese",
20 | "tr": "turkish",
21 | "pl": "polish",
22 | "ca": "catalan",
23 | "nl": "dutch",
24 | "ar": "arabic",
25 | "sv": "swedish",
26 | "it": "italian",
27 | "id": "indonesian",
28 | "hi": "hindi",
29 | "fi": "finnish",
30 | "vi": "vietnamese",
31 | "he": "hebrew",
32 | "uk": "ukrainian",
33 | "el": "greek",
34 | "ms": "malay",
35 | "cs": "czech",
36 | "ro": "romanian",
37 | "da": "danish",
38 | "hu": "hungarian",
39 | "ta": "tamil",
40 | "no": "norwegian",
41 | "th": "thai",
42 | "ur": "urdu",
43 | "hr": "croatian",
44 | "bg": "bulgarian",
45 | "lt": "lithuanian",
46 | "la": "latin",
47 | "mi": "maori",
48 | "ml": "malayalam",
49 | "cy": "welsh",
50 | "sk": "slovak",
51 | "te": "telugu",
52 | "fa": "persian",
53 | "lv": "latvian",
54 | "bn": "bengali",
55 | "sr": "serbian",
56 | "az": "azerbaijani",
57 | "sl": "slovenian",
58 | "kn": "kannada",
59 | "et": "estonian",
60 | "mk": "macedonian",
61 | "br": "breton",
62 | "eu": "basque",
63 | "is": "icelandic",
64 | "hy": "armenian",
65 | "ne": "nepali",
66 | "mn": "mongolian",
67 | "bs": "bosnian",
68 | "kk": "kazakh",
69 | "sq": "albanian",
70 | "sw": "swahili",
71 | "gl": "galician",
72 | "mr": "marathi",
73 | "pa": "punjabi",
74 | "si": "sinhala",
75 | "km": "khmer",
76 | "sn": "shona",
77 | "yo": "yoruba",
78 | "so": "somali",
79 | "af": "afrikaans",
80 | "oc": "occitan",
81 | "ka": "georgian",
82 | "be": "belarusian",
83 | "tg": "tajik",
84 | "sd": "sindhi",
85 | "gu": "gujarati",
86 | "am": "amharic",
87 | "yi": "yiddish",
88 | "lo": "lao",
89 | "uz": "uzbek",
90 | "fo": "faroese",
91 | "ht": "haitian creole",
92 | "ps": "pashto",
93 | "tk": "turkmen",
94 | "nn": "nynorsk",
95 | "mt": "maltese",
96 | "sa": "sanskrit",
97 | "lb": "luxembourgish",
98 | "my": "myanmar",
99 | "bo": "tibetan",
100 | "tl": "tagalog",
101 | "mg": "malagasy",
102 | "as": "assamese",
103 | "tt": "tatar",
104 | "haw": "hawaiian",
105 | "ln": "lingala",
106 | "ha": "hausa",
107 | "ba": "bashkir",
108 | "jw": "javanese",
109 | "su": "sundanese",
110 | "yue": "cantonese",
111 | }
112 |
113 | # language code lookup by name, with a few language aliases
114 | TO_LANGUAGE_CODE = {
115 | **{language: code for code, language in LANGUAGES.items()},
116 | "burmese": "my",
117 | "valencian": "ca",
118 | "flemish": "nl",
119 | "haitian": "ht",
120 | "letzeburgesch": "lb",
121 | "pushto": "ps",
122 | "panjabi": "pa",
123 | "moldavian": "ro",
124 | "moldovan": "ro",
125 | "sinhalese": "si",
126 | "castilian": "es",
127 | "mandarin": "zh",
128 | }
129 |
130 |
131 | @dataclass
132 | class Tokenizer:
133 | """A thin wrapper around `tiktoken` providing quick access to special tokens"""
134 |
135 | encoding: tiktoken.Encoding
136 | num_languages: int
137 | language: Optional[str] = None
138 | task: Optional[str] = None
139 | sot_sequence: Tuple[int] = ()
140 | special_tokens: Dict[str, int] = field(default_factory=dict)
141 |
142 | def __post_init__(self):
143 | for special in self.encoding.special_tokens_set:
144 | special_token = self.encoding.encode_single_token(special)
145 | self.special_tokens[special] = special_token
146 |
147 | sot: int = self.special_tokens["<|startoftranscript|>"]
148 | translate: int = self.special_tokens["<|translate|>"]
149 | transcribe: int = self.special_tokens["<|transcribe|>"]
150 |
151 | langs = tuple(LANGUAGES.keys())[: self.num_languages]
152 | sot_sequence = [sot]
153 | if self.language is not None:
154 | sot_sequence.append(sot + 1 + langs.index(self.language))
155 | if self.task is not None:
156 | task_token: int = transcribe if self.task == "transcribe" else translate
157 | sot_sequence.append(task_token)
158 |
159 | self.sot_sequence = tuple(sot_sequence)
160 |
161 | def encode(self, text, **kwargs):
162 | return self.encoding.encode(text, **kwargs)
163 |
164 | def decode(self, token_ids: List[int], **kwargs) -> str:
165 | token_ids = [t for t in token_ids if t < self.timestamp_begin]
166 | return self.encoding.decode(token_ids, **kwargs)
167 |
168 | def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
169 | """
170 | Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
171 | This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
172 | """
173 | return self.encoding.decode(token_ids, **kwargs)
174 |
175 | @cached_property
176 | def eot(self) -> int:
177 | return self.encoding.eot_token
178 |
179 | @cached_property
180 | def transcribe(self) -> int:
181 | return self.special_tokens["<|transcribe|>"]
182 |
183 | @cached_property
184 | def translate(self) -> int:
185 | return self.special_tokens["<|translate|>"]
186 |
187 | @cached_property
188 | def sot(self) -> int:
189 | return self.special_tokens["<|startoftranscript|>"]
190 |
191 | @cached_property
192 | def sot_lm(self) -> int:
193 | return self.special_tokens["<|startoflm|>"]
194 |
195 | @cached_property
196 | def sot_prev(self) -> int:
197 | return self.special_tokens["<|startofprev|>"]
198 |
199 | @cached_property
200 | def no_speech(self) -> int:
201 | return self.special_tokens["<|nospeech|>"]
202 |
203 | @cached_property
204 | def no_timestamps(self) -> int:
205 | return self.special_tokens["<|notimestamps|>"]
206 |
207 | @cached_property
208 | def timestamp_begin(self) -> int:
209 | return self.special_tokens["<|0.00|>"]
210 |
211 | @cached_property
212 | def language_token(self) -> int:
213 | """Returns the token id corresponding to the value of the `language` field"""
214 | if self.language is None:
215 | raise ValueError("This tokenizer does not have language token configured")
216 |
217 | return self.to_language_token(self.language)
218 |
219 | def to_language_token(self, language):
220 | if token := self.special_tokens.get(f"<|{language}|>", None):
221 | return token
222 |
223 | raise KeyError(f"Language {language} not found in tokenizer.")
224 |
225 | @cached_property
226 | def all_language_tokens(self) -> Tuple[int]:
227 | result = []
228 | for token, token_id in self.special_tokens.items():
229 | if token.strip("<|>") in LANGUAGES:
230 | result.append(token_id)
231 | return tuple(result)[: self.num_languages]
232 |
233 | @cached_property
234 | def all_language_codes(self) -> Tuple[str]:
235 | return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
236 |
237 | @cached_property
238 | def sot_sequence_including_notimestamps(self) -> Tuple[int]:
239 | return tuple(list(self.sot_sequence) + [self.no_timestamps])
240 |
241 | @cached_property
242 | def non_speech_tokens(self) -> Tuple[int]:
243 | """
244 | Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
245 | annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
246 |
247 | - ♪♪♪
248 | - ( SPEAKING FOREIGN LANGUAGE )
249 | - [DAVID] Hey there,
250 |
251 | keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
252 | """
253 | symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
254 | symbols += (
255 | "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
256 | )
257 |
258 | # symbols that may be a single token or multiple tokens depending on the tokenizer.
259 | # In case they're multiple tokens, suppress the first token, which is safe because:
260 | # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
261 | # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
262 | miscellaneous = set("♩♪♫♬♭♮♯")
263 | assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
264 |
265 | # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
266 | result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
267 | for symbol in symbols + list(miscellaneous):
268 | for tokens in [
269 | self.encoding.encode(symbol),
270 | self.encoding.encode(" " + symbol),
271 | ]:
272 | if len(tokens) == 1 or symbol in miscellaneous:
273 | result.add(tokens[0])
274 |
275 | return tuple(sorted(result))
276 |
277 | def split_to_word_tokens(self, tokens: List[int]):
278 | if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
279 | # These languages don't typically use spaces, so it is difficult to split words
280 | # without morpheme analysis. Here, we instead split words at any
281 | # position where the tokens are decoded as valid unicode points
282 | return self.split_tokens_on_unicode(tokens)
283 |
284 | return self.split_tokens_on_spaces(tokens)
285 |
286 | def split_tokens_on_unicode(self, tokens: List[int]):
287 | decoded_full = self.decode_with_timestamps(tokens)
288 | replacement_char = "\ufffd"
289 |
290 | words = []
291 | word_tokens = []
292 | current_tokens = []
293 | unicode_offset = 0
294 |
295 | for token in tokens:
296 | current_tokens.append(token)
297 | decoded = self.decode_with_timestamps(current_tokens)
298 |
299 | if (
300 | replacement_char not in decoded
301 | or decoded_full[unicode_offset + decoded.index(replacement_char)]
302 | == replacement_char
303 | ):
304 | words.append(decoded)
305 | word_tokens.append(current_tokens)
306 | current_tokens = []
307 | unicode_offset += len(decoded)
308 |
309 | return words, word_tokens
310 |
311 | def split_tokens_on_spaces(self, tokens: List[int]):
312 | subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
313 | words = []
314 | word_tokens = []
315 |
316 | for subword, subword_tokens in zip(subwords, subword_tokens_list):
317 | special = subword_tokens[0] >= self.eot
318 | with_space = subword.startswith(" ")
319 | punctuation = subword.strip() in string.punctuation
320 | if special or with_space or punctuation or len(words) == 0:
321 | words.append(subword)
322 | word_tokens.append(subword_tokens)
323 | else:
324 | words[-1] = words[-1] + subword
325 | word_tokens[-1].extend(subword_tokens)
326 |
327 | return words, word_tokens
328 |
329 |
330 | @lru_cache(maxsize=None)
331 | def get_encoding(name: str = "gpt2", num_languages: int = 99):
332 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
333 | ranks = {
334 | base64.b64decode(token): int(rank)
335 | for token, rank in (line.split() for line in open(vocab_path) if line)
336 | }
337 | n_vocab = len(ranks)
338 | special_tokens = {}
339 |
340 | specials = [
341 | "<|endoftext|>",
342 | "<|startoftranscript|>",
343 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
344 | "<|translate|>",
345 | "<|transcribe|>",
346 | "<|startoflm|>",
347 | "<|startofprev|>",
348 | "<|nospeech|>",
349 | "<|notimestamps|>",
350 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
351 | ]
352 |
353 | for token in specials:
354 | special_tokens[token] = n_vocab
355 | n_vocab += 1
356 |
357 | return tiktoken.Encoding(
358 | name=os.path.basename(vocab_path),
359 | explicit_n_vocab=n_vocab,
360 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
361 | mergeable_ranks=ranks,
362 | special_tokens=special_tokens,
363 | )
364 |
365 |
366 | @lru_cache(maxsize=None)
367 | def get_tokenizer(
368 | multilingual: bool,
369 | *,
370 | num_languages: int = 99,
371 | language: Optional[str] = None,
372 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
373 | ) -> Tokenizer:
374 | if language is not None:
375 | language = language.lower()
376 | if language not in LANGUAGES:
377 | if language in TO_LANGUAGE_CODE:
378 | language = TO_LANGUAGE_CODE[language]
379 | else:
380 | raise ValueError(f"Unsupported language: {language}")
381 |
382 | if multilingual:
383 | encoding_name = "multilingual"
384 | language = language or "en"
385 | task = task or "transcribe"
386 | else:
387 | encoding_name = "gpt2"
388 | language = None
389 | task = None
390 |
391 | encoding = get_encoding(name=encoding_name, num_languages=num_languages)
392 |
393 | return Tokenizer(
394 | encoding=encoding, num_languages=num_languages, language=language, task=task
395 | )
396 |
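397 | 
398 | # A minimal round-trip through the tokenizer (illustrative only; concrete
399 | # token ids depend on the vocabulary files shipped in assets/):
400 | #
401 | #     tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
402 | #     tokens = tokenizer.encode(" Hello world")
403 | #     assert tokenizer.decode(tokens) == " Hello world"
404 | #     # sot_sequence encodes (<|startoftranscript|>, <|en|>, <|transcribe|>)
405 | #     print(tokenizer.sot_sequence, tokenizer.eot, tokenizer.timestamp_begin)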
--------------------------------------------------------------------------------
/src/whisper/timing.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import subprocess
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import TYPE_CHECKING, List
6 |
7 | import numba
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 |
12 | from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
13 | from .tokenizer import Tokenizer
14 |
15 | if TYPE_CHECKING:
16 | from .model import Whisper
17 |
18 |
19 | def median_filter(x: torch.Tensor, filter_width: int):
20 | """Apply a median filter of width `filter_width` along the last dimension of `x`"""
21 | pad_width = filter_width // 2
22 | if x.shape[-1] <= pad_width:
23 | # F.pad requires the padding width to be smaller than the input dimension
24 | return x
25 |
26 | if (ndim := x.ndim) <= 2:
27 | # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
28 | x = x[None, None, :]
29 |
30 | assert (
31 | filter_width > 0 and filter_width % 2 == 1
32 | ), "`filter_width` should be an odd number"
33 |
34 | result = None
35 | x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
36 | if x.is_cuda:
37 | try:
38 | from .triton_ops import median_filter_cuda
39 |
40 | result = median_filter_cuda(x, filter_width)
41 | except (RuntimeError, subprocess.CalledProcessError):
42 | warnings.warn(
43 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
44 | "falling back to a slower median kernel implementation..."
45 | )
46 |
47 | if result is None:
48 | # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
49 | result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
50 |
51 | if ndim <= 2:
52 | result = result[0, 0]
53 |
54 | return result
55 |
56 |
57 | @numba.jit(nopython=True)
58 | def backtrace(trace: np.ndarray):
59 | i = trace.shape[0] - 1
60 | j = trace.shape[1] - 1
61 | trace[0, :] = 2
62 | trace[:, 0] = 1
63 |
64 | result = []
65 | while i > 0 or j > 0:
66 | result.append((i - 1, j - 1))
67 |
68 | if trace[i, j] == 0:
69 | i -= 1
70 | j -= 1
71 | elif trace[i, j] == 1:
72 | i -= 1
73 | elif trace[i, j] == 2:
74 | j -= 1
75 | else:
76 | raise ValueError("Unexpected trace[i, j]")
77 |
78 | result = np.array(result)
79 | return result[::-1, :].T
80 |
81 |
82 | @numba.jit(nopython=True, parallel=True)
83 | def dtw_cpu(x: np.ndarray):
84 | N, M = x.shape
85 | cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
86 | trace = -np.ones((N + 1, M + 1), dtype=np.float32)
87 |
88 | cost[0, 0] = 0
89 | for j in range(1, M + 1):
90 | for i in range(1, N + 1):
91 | c0 = cost[i - 1, j - 1]
92 | c1 = cost[i - 1, j]
93 | c2 = cost[i, j - 1]
94 |
95 | if c0 < c1 and c0 < c2:
96 | c, t = c0, 0
97 | elif c1 < c0 and c1 < c2:
98 | c, t = c1, 1
99 | else:
100 | c, t = c2, 2
101 |
102 | cost[i, j] = x[i - 1, j - 1] + c
103 | trace[i, j] = t
104 |
105 | return backtrace(trace)
106 |
107 |
108 | def dtw_cuda(x, BLOCK_SIZE=1024):
109 | from .triton_ops import dtw_kernel
110 |
111 | M, N = x.shape
112 | assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
113 |
114 | x_skew = (
115 | F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
116 | )
117 | x_skew = x_skew.T.contiguous()
118 | cost = torch.ones(N + M + 2, M + 2) * np.inf
119 | cost[0, 0] = 0
120 | cost = cost.cuda()
121 | trace = torch.zeros_like(cost, dtype=torch.int32)
122 |
123 | dtw_kernel[(1,)](
124 | cost,
125 | trace,
126 | x_skew,
127 | x_skew.stride(0),
128 | cost.stride(0),
129 | trace.stride(0),
130 | N,
131 | M,
132 | BLOCK_SIZE=BLOCK_SIZE,
133 | )
134 |
135 | trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
136 | :, : N + 1
137 | ]
138 | return backtrace(trace.cpu().numpy())
139 |
140 |
141 | def dtw(x: torch.Tensor) -> np.ndarray:
142 | if x.is_cuda:
143 | try:
144 | return dtw_cuda(x)
145 | except (RuntimeError, subprocess.CalledProcessError):
146 | warnings.warn(
147 | "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
148 | "falling back to a slower DTW implementation..."
149 | )
150 |
151 | return dtw_cpu(x.double().cpu().numpy())
152 |
153 |
154 | @dataclass
155 | class WordTiming:
156 | word: str
157 | tokens: List[int]
158 | start: float
159 | end: float
160 | probability: float
161 |
162 |
163 | def find_alignment(
164 | model: "Whisper",
165 | tokenizer: Tokenizer,
166 | text_tokens: List[int],
167 | mel: torch.Tensor,
168 | num_frames: int,
169 | *,
170 | medfilt_width: int = 7,
171 | qk_scale: float = 1.0,
172 | ) -> List[WordTiming]:
173 | if len(text_tokens) == 0:
174 | return []
175 |
176 | tokens = torch.tensor(
177 | [
178 | *tokenizer.sot_sequence,
179 | tokenizer.no_timestamps,
180 | *text_tokens,
181 | tokenizer.eot,
182 | ]
183 | ).to(model.device)
184 |
185 | # install hooks on the cross attention layers to retrieve the attention weights
186 | QKs = [None] * model.dims.n_text_layer
187 | hooks = [
188 | block.cross_attn.register_forward_hook(
189 | lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
190 | )
191 | for i, block in enumerate(model.decoder.blocks)
192 | ]
193 |
194 | with torch.no_grad():
195 | logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
196 | sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
197 | token_probs = sampled_logits.softmax(dim=-1)
198 | text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
199 | text_token_probs = text_token_probs.tolist()
200 |
201 | for hook in hooks:
202 | hook.remove()
203 |
204 | # heads * tokens * frames
205 | weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
206 | weights = weights[:, :, : num_frames // 2]
207 | weights = (weights * qk_scale).softmax(dim=-1)
208 | std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
209 | weights = (weights - mean) / std
210 | weights = median_filter(weights, medfilt_width)
211 |
212 | matrix = weights.mean(axis=0)
213 | matrix = matrix[len(tokenizer.sot_sequence) : -1]
214 | text_indices, time_indices = dtw(-matrix)
215 |
216 | words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
217 | if len(word_tokens) <= 1:
218 | # return on eot only
219 | # >>> np.pad([], (1, 0))
220 | # array([0.])
221 | # This results in crashes when we lookup jump_times with float, like
222 | # IndexError: arrays used as indices must be of integer (or boolean) type
223 | return []
224 | word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
225 |
226 | jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
227 | jump_times = time_indices[jumps] / TOKENS_PER_SECOND
228 | start_times = jump_times[word_boundaries[:-1]]
229 | end_times = jump_times[word_boundaries[1:]]
230 | word_probabilities = [
231 | np.mean(text_token_probs[i:j])
232 | for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
233 | ]
234 |
235 | return [
236 | WordTiming(word, tokens, start, end, probability)
237 | for word, tokens, start, end, probability in zip(
238 | words, word_tokens, start_times, end_times, word_probabilities
239 | )
240 | ]
241 |
242 |
243 | def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
244 | # merge prepended punctuations
245 | i = len(alignment) - 2
246 | j = len(alignment) - 1
247 | while i >= 0:
248 | previous = alignment[i]
249 | following = alignment[j]
250 | if previous.word.startswith(" ") and previous.word.strip() in prepended:
251 | # prepend it to the following word
252 | following.word = previous.word + following.word
253 | following.tokens = previous.tokens + following.tokens
254 | previous.word = ""
255 | previous.tokens = []
256 | else:
257 | j = i
258 | i -= 1
259 |
260 | # merge appended punctuations
261 | i = 0
262 | j = 1
263 | while j < len(alignment):
264 | previous = alignment[i]
265 | following = alignment[j]
266 | if not previous.word.endswith(" ") and following.word in appended:
267 | # append it to the previous word
268 | previous.word = previous.word + following.word
269 | previous.tokens = previous.tokens + following.tokens
270 | following.word = ""
271 | following.tokens = []
272 | else:
273 | i = j
274 | j += 1
275 |
276 |
277 | def add_word_timestamps(
278 | *,
279 | segments: List[dict],
280 | model: "Whisper",
281 | tokenizer: Tokenizer,
282 | mel: torch.Tensor,
283 | num_frames: int,
284 | prepend_punctuations: str = "\"'“¿([{-",
285 | append_punctuations: str = "\"'.。,,!!??::”)]}、",
286 | last_speech_timestamp: float,
287 | **kwargs,
288 | ):
289 | if len(segments) == 0:
290 | return
291 |
292 | text_tokens_per_segment = [
293 | [token for token in segment["tokens"] if token < tokenizer.eot]
294 | for segment in segments
295 | ]
296 |
297 | text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
298 | alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
299 | word_durations = np.array([t.end - t.start for t in alignment])
300 | word_durations = word_durations[word_durations.nonzero()]
301 | median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
302 | max_duration = median_duration * 2
303 |
304 | # hack: truncate long words at sentence boundaries.
305 | # a better segmentation algorithm based on VAD should be able to replace this.
306 | if len(word_durations) > 0:
307 | sentence_end_marks = ".。!!??"
308 | # ensure words at sentence boundaries are not longer than twice the median word duration.
309 | for i in range(1, len(alignment)):
310 | if alignment[i].end - alignment[i].start > max_duration:
311 | if alignment[i].word in sentence_end_marks:
312 | alignment[i].end = alignment[i].start + max_duration
313 | elif alignment[i - 1].word in sentence_end_marks:
314 | alignment[i].start = alignment[i].end - max_duration
315 |
316 | merge_punctuations(alignment, prepend_punctuations, append_punctuations)
317 |
318 | time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
319 | word_index = 0
320 |
321 | for segment, text_tokens in zip(segments, text_tokens_per_segment):
322 | saved_tokens = 0
323 | words = []
324 |
325 | while word_index < len(alignment) and saved_tokens < len(text_tokens):
326 | timing = alignment[word_index]
327 |
328 | if timing.word:
329 | words.append(
330 | dict(
331 | word=timing.word,
332 | start=round(time_offset + timing.start, 2),
333 | end=round(time_offset + timing.end, 2),
334 | probability=timing.probability,
335 | )
336 | )
337 |
338 | saved_tokens += len(timing.tokens)
339 | word_index += 1
340 |
341 | # hack: truncate long words at segment boundaries.
342 | # a better segmentation algorithm based on VAD should be able to replace this.
343 | if len(words) > 0:
344 | # ensure the first and second word after a pause is not longer than
345 | # twice the median word duration.
346 | if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
347 | words[0]["end"] - words[0]["start"] > max_duration
348 | or (
349 | len(words) > 1
350 | and words[1]["end"] - words[0]["start"] > max_duration * 2
351 | )
352 | ):
353 | if (
354 | len(words) > 1
355 | and words[1]["end"] - words[1]["start"] > max_duration
356 | ):
357 | boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
358 | words[0]["end"] = words[1]["start"] = boundary
359 | words[0]["start"] = max(0, words[0]["end"] - max_duration)
360 |
361 | # prefer the segment-level start timestamp if the first word is too long.
362 | if (
363 | segment["start"] < words[0]["end"]
364 | and segment["start"] - 0.5 > words[0]["start"]
365 | ):
366 | words[0]["start"] = max(
367 | 0, min(words[0]["end"] - median_duration, segment["start"])
368 | )
369 | else:
370 | segment["start"] = words[0]["start"]
371 |
372 | # prefer the segment-level end timestamp if the last word is too long.
373 | if (
374 | segment["end"] > words[-1]["start"]
375 | and segment["end"] + 0.5 < words[-1]["end"]
376 | ):
377 | words[-1]["end"] = max(
378 | words[-1]["start"] + median_duration, segment["end"]
379 | )
380 | else:
381 | segment["end"] = words[-1]["end"]
382 |
383 | last_speech_timestamp = segment["end"]
384 |
385 | segment["words"] = words
386 |
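387 | 
388 | # dtw() aligns text tokens to audio frames by tracing the cheapest monotonic
389 | # path through a (tokens x frames) cost matrix. On a toy matrix with zeros on
390 | # the diagonal, the path is the diagonal itself (illustrative only):
391 | #
392 | #     x = torch.tensor([[0.0, 1.0, 1.0],
393 | #                       [1.0, 0.0, 1.0],
394 | #                       [1.0, 1.0, 0.0]])
395 | #     text_indices, time_indices = dtw(x)
396 | #     # -> text_indices == [0, 1, 2], time_indices == [0, 1, 2]
397 | #     # (find_alignment() passes -matrix because its weights are similarities, not costs)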
--------------------------------------------------------------------------------
/src/ui/ctkdropdown.py:
--------------------------------------------------------------------------------
1 | import difflib
2 | import sys
3 | import time
4 |
5 | import customtkinter as ctk
6 |
7 |
8 | class CTkScrollableDropdownFrame(ctk.CTkToplevel):
9 |
10 | def __init__(self, attach, x=None, y=None, button_color=None, height: int = 200, width: int = None,
11 | fg_color=None, button_height: int = 20, justify="center", scrollbar_button_color=None,
12 | scrollbar=True, scrollbar_button_hover_color=None, frame_border_width=2, values=[],
13 | command=None, image_values=[], alpha: float = 0.97, frame_corner_radius=20, double_click=False,
14 | resize=True, frame_border_color=None, text_color=None, autocomplete=False, **button_kwargs):
15 |
16 | super().__init__(takefocus=1)
17 |
18 | self.focus()
19 | self.lift()
20 | self.alpha = alpha
21 | self.attach = attach
22 | self.corner = frame_corner_radius
23 | self.padding = 0
24 | self.focus_something = False
25 | self.disable = True
26 | self.update()
27 |
28 | if sys.platform.startswith("win"):
29 | self.after(100, lambda: self.overrideredirect(True))
30 | self.transparent_color = self._apply_appearance_mode(self._fg_color)
31 | self.attributes("-transparentcolor", self.transparent_color)
32 | elif sys.platform.startswith("darwin"):
33 | self.overrideredirect(True)
34 | self.transparent_color = 'systemTransparent'
35 | self.attributes("-transparent", True)
36 | self.focus_something = True
37 | else:
38 | self.overrideredirect(True)
39 | self.transparent_color = '#000001'
40 | self.corner = 0
41 | self.padding = 18
42 | self.withdraw()
43 |
44 | self.hide = True
45 |         self.attach.bind('<Configure>', lambda e: self._withdraw() if not self.disable else None, add="+")
46 |         self.attach.winfo_toplevel().bind('<Configure>', lambda e: self._withdraw() if not self.disable else None,
47 | add="+")
48 |         self.attach.winfo_toplevel().bind("<ButtonPress>", lambda e: self._withdraw() if not self.disable else None,
49 | add="+")
50 |
51 | self.attributes('-alpha', 0)
52 | self.disable = False
53 | self.fg_color = ctk.ThemeManager.theme["CTkFrame"]["fg_color"] if fg_color is None else fg_color
54 | self.scroll_button_color = ctk.ThemeManager.theme["CTkScrollbar"][
55 | "button_color"] if scrollbar_button_color is None else scrollbar_button_color
56 | self.scroll_hover_color = ctk.ThemeManager.theme["CTkScrollbar"][
57 | "button_hover_color"] if scrollbar_button_hover_color is None else scrollbar_button_hover_color
58 | self.frame_border_color = ctk.ThemeManager.theme["CTkFrame"][
59 | "border_color"] if frame_border_color is None else frame_border_color
60 | self.button_color = ctk.ThemeManager.theme["CTkFrame"][
61 | "top_fg_color"] if button_color is None else button_color
62 | self.text_color = ctk.ThemeManager.theme["CTkLabel"][
63 | "text_color"] if text_color is None else text_color
64 |
65 | if scrollbar is False:
66 | self.scroll_button_color = self.fg_color
67 | self.scroll_hover_color = self.fg_color
68 |
69 | self.frame = ctk.CTkScrollableFrame(self, bg_color=self.transparent_color, fg_color=self.fg_color,
70 | scrollbar_button_hover_color=self.scroll_hover_color,
71 | corner_radius=self.corner, border_width=frame_border_width,
72 | scrollbar_button_color=self.scroll_button_color,
73 | border_color=self.frame_border_color)
74 | self.frame._scrollbar.grid_configure(padx=3)
75 | self.frame.pack(expand=True, fill="both")
76 | self.dummy_entry = ctk.CTkEntry(self.frame, fg_color="transparent", border_width=0, height=1, width=1)
77 | self.no_match = ctk.CTkLabel(self.frame, text="No Match")
78 | self.height = height
79 | self.height_new = height
80 | self.width = width
81 | self.command = command
82 | self.fade = False
83 | self.resize = resize
84 | self.autocomplete = autocomplete
85 | self.var_update = ctk.StringVar()
86 | self.appear = False
87 |
88 | if justify.lower() == "left":
89 | self.justify = "w"
90 | elif justify.lower() == "right":
91 | self.justify = "e"
92 | else:
93 | self.justify = "c"
94 |
95 | self.button_height = button_height
96 |         self.values = list(values)  # copy, so the shared default list is never mutated by insert()
97 | self.button_num = len(self.values)
98 | self.image_values = None if len(image_values) != len(self.values) else image_values
99 |
100 | self.resizable(width=False, height=False)
101 | self.transient(self.master)
102 | self._init_buttons(**button_kwargs)
103 |
104 |         # Add open/close bindings for the different CTk widget types
105 | if double_click or self.attach.winfo_name().startswith("!ctkentry") or self.attach.winfo_name().startswith(
106 | "!ctkcombobox"):
107 |             self.attach.bind('<Double-Button-1>', lambda e: self._iconify(), add="+")
108 |         else:
109 |             self.attach.bind('<Button-1>', lambda e: self._iconify(), add="+")
110 |
111 | if self.attach.winfo_name().startswith("!ctkcombobox"):
112 |             self.attach._canvas.tag_bind("right_parts", "<Button-1>", lambda e: self._iconify())
113 |             self.attach._canvas.tag_bind("dropdown_arrow", "<Button-1>", lambda e: self._iconify())
114 | if self.command is None:
115 | self.command = self.attach.set
116 |
117 | if self.attach.winfo_name().startswith("!ctkoptionmenu"):
118 |             self.attach._canvas.bind("<Button-1>", lambda e: self._iconify())
119 |             self.attach._text_label.bind("<Button-1>", lambda e: self._iconify())
120 | if self.command is None:
121 | self.command = self.attach.set
122 |
123 |         self.attach.bind("<Destroy>", lambda _: self._destroy(), add="+")
124 |
125 | self.update_idletasks()
126 | self.x = x
127 | self.y = y
128 |
129 | if self.autocomplete:
130 | self.bind_autocomplete()
131 |
132 | self.deiconify()
133 | self.withdraw()
134 |
135 | self.attributes("-alpha", self.alpha)
136 |
137 | def _destroy(self):
138 | self.after(500, self.destroy_popup)
139 |
140 | def _withdraw(self):
141 | if self.winfo_viewable() and self.hide:
142 | self.withdraw()
143 |
144 |         self.event_generate("<<Closed>>")
145 | self.hide = True
146 |
147 | def _update(self, a, b, c):
148 | self.live_update(self.attach._entry.get())
149 |
150 |     def bind_autocomplete(self):
151 | def appear(x):
152 | self.appear = True
153 |
154 | if self.attach.winfo_name().startswith("!ctkcombobox"):
155 | self.attach._entry.configure(textvariable=self.var_update)
156 |             self.attach._entry.bind("<Key>", appear)
157 | self.attach.set(self.values[0])
158 | self.var_update.trace_add('write', self._update)
159 |
160 | if self.attach.winfo_name().startswith("!ctkentry"):
161 | self.attach.configure(textvariable=self.var_update)
162 |             self.attach.bind("<Key>", appear)
163 | self.var_update.trace_add('write', self._update)
164 |
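   |     # Blocking fade animations: step the window alpha in 10% increments with a
   |     # short sleep per step, bailing out early if the Toplevel has been destroyed.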
165 | def fade_out(self):
166 | for i in range(100, 0, -10):
167 | if not self.winfo_exists():
168 | break
169 | self.attributes("-alpha", i / 100)
170 | self.update()
171 | time.sleep(1 / 100)
172 |
173 | def fade_in(self):
174 | for i in range(0, 100, 10):
175 | if not self.winfo_exists():
176 | break
177 | self.attributes("-alpha", i / 100)
178 | self.update()
179 | time.sleep(1 / 100)
180 |
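   |     # Build one CTkButton per entry in self.values inside the scrollable frame;
   |     # each button forwards its own text to _attach_key_press, which fires the
   |     # command and closes the popup.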
181 | def _init_buttons(self, **button_kwargs):
182 | self.i = 0
183 | self.widgets = {}
184 | for row in self.values:
185 | self.widgets[self.i] = ctk.CTkButton(self.frame,
186 | text=row,
187 | height=self.button_height,
188 | fg_color=self.button_color,
189 | text_color=self.text_color,
190 | image=self.image_values[
191 | self.i] if self.image_values is not None else None,
192 | anchor=self.justify,
193 | command=lambda k=row: self._attach_key_press(k),
194 | **button_kwargs)
195 | self.widgets[self.i].pack(fill="x", pady=2, padx=(self.padding, 0))
196 | self.i += 1
197 |
198 | self.hide = False
199 |
200 | def destroy_popup(self):
201 | self.disable = True
202 | self.destroy()
203 |
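   |     # Position the popup just below the attached widget (or at the given x/y
   |     # offsets), matching its width; with resize enabled the height grows with
   |     # the number of buttons, capped at the configured maximum height.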
204 | def place_dropdown(self):
205 | self.x_pos = self.attach.winfo_rootx() if self.x is None else self.x + self.attach.winfo_rootx()
206 | self.y_pos = self.attach.winfo_rooty() + self.attach.winfo_reqheight() + 5 if self.y is None else self.y + self.attach.winfo_rooty()
207 | self.width_new = self.attach.winfo_width() if self.width is None else self.width
208 |
209 | if self.resize:
210 | if self.button_num <= 5:
211 | self.height_new = self.button_height * self.button_num + 55
212 | else:
213 | self.height_new = self.button_height * self.button_num + 35
214 | if self.height_new > self.height:
215 | self.height_new = self.height
216 |
217 | self.geometry('{}x{}+{}+{}'.format(self.width_new, self.height_new,
218 | self.x_pos, self.y_pos))
219 | self.fade_in()
220 | self.attributes('-alpha', self.alpha)
221 | self.attach.focus()
222 |
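   |     # Click handler for the attached widget: show the popup when hidden,
   |     # withdraw it when visible. On macOS the dummy entry briefly takes focus
   |     # so the borderless Toplevel can receive keyboard input.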
223 | def _iconify(self):
224 | if self.disable: return
225 | if self.hide:
226 |             self.event_generate("<<Opened>>")
227 | self._deiconify()
228 | self.focus()
229 | self.hide = False
230 | self.place_dropdown()
231 | if self.focus_something:
232 | self.dummy_entry.pack()
233 | self.dummy_entry.focus_set()
234 | self.after(100, self.dummy_entry.pack_forget)
235 | else:
236 | self.withdraw()
237 | self.hide = True
238 |
239 | def _attach_key_press(self, k):
240 |         self.event_generate("<<Selected>>")
241 | self.fade = True
242 | if self.command:
243 | self.command(k)
244 | self.fade = False
245 | self.fade_out()
246 | self.withdraw()
247 | self.hide = True
248 |
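   |     # Autocomplete filter: a button stays visible if its text starts with the
   |     # typed string or is at least 75% similar to it (difflib ratio). The
   |     # "No Match" label appears when everything is filtered out, and the full
   |     # list is rebuilt when the entry is cleared.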
249 | def live_update(self, string=None):
250 | if not self.appear: return
251 | if self.disable: return
252 | if self.fade: return
253 | if string:
254 | string = string.lower()
255 | self._deiconify()
256 | i = 1
257 | for key in self.widgets.keys():
258 | s = self.widgets[key].cget("text").lower()
259 | text_similarity = difflib.SequenceMatcher(None, s[0:len(string)], string).ratio()
260 | similar = s.startswith(string) or text_similarity > 0.75
261 | if not similar:
262 | self.widgets[key].pack_forget()
263 | else:
264 | self.widgets[key].pack(fill="x", pady=2, padx=(self.padding, 0))
265 | i += 1
266 |
267 | if i == 1:
268 | self.no_match.pack(fill="x", pady=2, padx=(self.padding, 0))
269 | else:
270 | self.no_match.pack_forget()
271 | self.button_num = i
272 | self.place_dropdown()
273 |
274 | else:
275 | self.no_match.pack_forget()
276 | self.button_num = len(self.values)
277 | for key in self.widgets.keys():
278 | self.widgets[key].destroy()
279 | self._init_buttons()
280 | self.place_dropdown()
281 |
282 | self.frame._parent_canvas.yview_moveto(0.0)
283 | self.appear = False
284 |
285 | def insert(self, value, **kwargs):
286 | self.widgets[self.i] = ctk.CTkButton(self.frame,
287 | text=value,
288 | height=self.button_height,
289 | fg_color=self.button_color,
290 | text_color=self.text_color,
291 | anchor=self.justify,
292 | command=lambda k=value: self._attach_key_press(k), **kwargs)
293 | self.widgets[self.i].pack(fill="x", pady=2, padx=(self.padding, 0))
294 | self.i += 1
295 | self.values.append(value)
296 |
297 | def _deiconify(self):
298 | if len(self.values) > 0:
299 | self.deiconify()
300 |
301 | def popup(self, x=None, y=None):
302 | self.x = x
303 | self.y = y
304 | self.hide = True
305 | self._iconify()
306 |
307 | def configure(self, **kwargs):
308 | if "height" in kwargs:
309 | self.height = kwargs.pop("height")
310 | self.height_new = self.height
311 |
312 | if "alpha" in kwargs:
313 | self.alpha = kwargs.pop("alpha")
314 |
315 | if "width" in kwargs:
316 | self.width = kwargs.pop("width")
317 |
318 | if "fg_color" in kwargs:
319 | self.frame.configure(fg_color=kwargs.pop("fg_color"))
320 |
321 | if "values" in kwargs:
322 | self.values = kwargs.pop("values")
323 | self.image_values = None
324 | self.button_num = len(self.values)
325 | for key in self.widgets.keys():
326 | self.widgets[key].destroy()
327 | self._init_buttons()
328 |
329 | if "image_values" in kwargs:
330 | self.image_values = kwargs.pop("image_values")
331 | self.image_values = None if len(self.image_values) != len(self.values) else self.image_values
332 | if self.image_values is not None:
333 | i = 0
334 | for key in self.widgets.keys():
335 | self.widgets[key].configure(image=self.image_values[i])
336 | i += 1
337 |
338 | if "button_color" in kwargs:
339 | for key in self.widgets.keys():
340 | self.widgets[key].configure(fg_color=kwargs.pop("button_color"))
341 |
342 | for key in self.widgets.keys():
343 | self.widgets[key].configure(**kwargs)
344 |
--------------------------------------------------------------------------------
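Usage note — a minimal sketch, not part of the repository: how this dropdown is
attached to a widget, based on the constructor above and the way settings.py uses
it. The class name CTkScrollableDropdownFrame is taken from the import in
settings.py, and the import path assumes the project root is on sys.path.

    import customtkinter as ctk
    from src.ui.ctkdropdown import CTkScrollableDropdownFrame

    app = ctk.CTk()
    option = ctk.CTkOptionMenu(app)
    option.pack(padx=20, pady=20)

    # The dropdown binds its own click handlers to `option`; for option menus
    # the command defaults to option.set, so selections update the widget.
    CTkScrollableDropdownFrame(option, values=["tiny", "base", "small"],
                               justify="left", height=180)

    app.mainloop()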
/src/ui/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import sys
4 | import threading
5 | import webbrowser
6 |
7 | import customtkinter as ctk
8 |
9 | from .ctkAlert import CTkAlert
10 | from .ctkLoader import CTkLoader
11 | from .ctkdropdown import CTkScrollableDropdownFrame
12 | from .icons import icons
13 | from .style import FONTS, DROPDOWN
14 | from ..logic import GPUInfo, SettingsHandler, ModelRequirements
15 |
16 | current_path = os.path.dirname(os.path.realpath(__file__))
17 |
18 | languages = [
19 | "Auto Detection",
20 | "English",
21 | "Chinese",
22 | "German",
23 | "Spanish",
24 | "Russian",
25 | "Korean",
26 | "French",
27 | "Japanese",
28 | "Portuguese",
29 | "Turkish",
30 | "Polish",
31 | "Catalan",
32 | "Dutch",
33 | "Arabic",
34 | "Swedish",
35 | "Italian",
36 | "Indonesian",
37 | "Hindi",
38 | "Finnish",
39 | "Vietnamese",
40 | "Hebrew",
41 | "Ukrainian",
42 | "Greek",
43 | "Malay",
44 | "Czech",
45 | "Romanian",
46 | "Danish",
47 | "Hungarian",
48 | "Tamil",
49 | "Norwegian",
50 | "Thai",
51 | "Urdu",
52 | "Croatian",
53 | "Bulgarian",
54 | "Lithuanian",
55 | "Latin",
56 | "Maori",
57 | "Malayalam",
58 | "Welsh",
59 | "Slovak",
60 | "Telugu",
61 | "Persian",
62 | "Latvian",
63 | "Bengali",
64 | "Serbian",
65 | "Azerbaijani",
66 | "Slovenian",
67 | "Kannada",
68 | "Estonian",
69 | "Macedonian",
70 | "Breton",
71 | "Basque",
72 | "Icelandic",
73 | "Armenian",
74 | "Nepali",
75 | "Mongolian",
76 | "Bosnian",
77 | "Kazakh",
78 | "Albanian",
79 | "Swahili",
80 | "Galician",
81 | "Marathi",
82 | "Punjabi",
83 | "Sinhala",
84 | "Khmer",
85 | "Shona",
86 | "Yoruba",
87 | "Somali",
88 | "Afrikaans",
89 | "Occitan",
90 | "Georgian",
91 | "Belarusian",
92 | "Tajik",
93 | "Sindhi",
94 | "Gujarati",
95 | "Amharic",
96 | "Yiddish",
97 | "Lao",
98 | "Uzbek",
99 | "Faroese",
100 | "Haitian creole",
101 | "Pashto",
102 | "Turkmen",
103 | "Nynorsk",
104 | "Maltese",
105 | "Sanskrit",
106 | "Luxembourgish",
107 | "Myanmar",
108 | "Tibetan",
109 | "Tagalog",
110 | "Malagasy",
111 | "Assamese",
112 | "Tatar",
113 | "Hawaiian",
114 | "Lingala",
115 | "Hausa",
116 | "Bashkir",
117 | "Javanese",
118 | "Sundanese",
119 | ]
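   | # Display names for the languages Whisper can transcribe (plus "Auto Detection");
   | # model_widget() below feeds this list to the Language dropdown.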
120 |
121 |
122 | def help_link():
123 | webbrowser.open("https://github.com/rudymohammadbali/Whisper-Transcriber/discussions/categories/q-a")
124 |
125 |
126 | class SettingsUI(ctk.CTkFrame):
127 | def __init__(self, parent):
128 | super().__init__(master=parent, width=620, height=720, fg_color=("#F2F0EE", "#1E1F22"), border_width=0)
129 | self.grid_propagate(False)
130 | self.grid_columnconfigure(0, weight=1)
131 |
132 | self.master = parent
133 | self.settings_handler = SettingsHandler()
134 |
135 | self.theme_btn = None
136 | self.color_theme_btn = None
137 | self.device_dropdown = None
138 | self.language_dropdown = None
139 | self.model_dropdown = None
140 | self.theme_dropdown = None
141 | self.color_theme_dropdown = None
142 | self.general_frame = None
143 | self.model_frame = None
144 | self.gpu_frame = None
145 | self.reset_btn = None
146 | self.mic_dropdown = None
147 | self.mic_btn = None
148 |
149 | self.loader = None
150 | self.process = None
151 | self.thread = None
152 |
153 | title = ctk.CTkLabel(self, text=" Settings", font=FONTS["title"], image=icons["settings"], compound="left")
154 | title.grid(row=0, column=0, padx=20, pady=20, sticky="w")
155 |
156 | self.close_btn = ctk.CTkButton(self, text="", image=icons["close"], fg_color="transparent", hover=False,
157 | width=30,
158 | height=30, command=self.hide_settings_ui)
159 | self.close_btn.grid(row=0, column=1, padx=20, pady=20, sticky="e")
160 |
161 | self.main_frame = ctk.CTkFrame(self, fg_color="transparent")
162 | self.main_frame.grid(row=1, column=0, padx=20, pady=0, sticky="nsew", columnspan=2)
163 | self.main_frame.grid_columnconfigure(0, weight=1)
164 |
165 | self.default_widget()
166 |
167 | self.reset_btn = ctk.CTkButton(self, text="Reset", height=35, command=self.reset_callback, font=FONTS["normal"])
168 | self.reset_btn.grid(row=3, column=0, padx=20, pady=0, sticky="w")
169 |
170 | self.grid(row=0, column=0, sticky="nsew")
171 |
172 | def default_widget(self):
173 | segmented_button = ctk.CTkSegmentedButton(self.main_frame, values=["General", "Model Settings", "GPU Info"],
174 | width=150,
175 | height=35, command=self.update_view, border_width=0,
176 | font=FONTS["normal"])
177 | segmented_button.grid(row=0, column=0, padx=0, pady=20, sticky="nsew", columnspan=3)
178 | segmented_button.set("General")
179 |
180 | self.general_frame = ctk.CTkFrame(self.main_frame)
181 | self.general_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
182 | self.general_frame.grid_columnconfigure(0, weight=1)
183 | self.general_widget()
184 |
185 | def update_view(self, frame):
186 | try:
187 | self.general_frame.grid_forget()
188 | self.model_frame.grid_forget()
189 | self.gpu_frame.grid_forget()
190 | except AttributeError:
191 | pass
192 |
193 | if frame == "General":
194 | self.general_frame = ctk.CTkFrame(self.main_frame)
195 | self.general_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
196 | self.general_frame.grid_columnconfigure(0, weight=1)
197 | self.general_widget()
198 |
199 | elif frame == "Model Settings":
200 | self.model_frame = ctk.CTkFrame(self.main_frame)
201 | self.model_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
202 | self.model_frame.grid_columnconfigure(0, weight=1)
203 | self.model_frame.grid_columnconfigure(1, weight=1)
204 | self.model_widget()
205 |
206 | elif frame == "GPU Info":
207 | self.gpu_frame = ctk.CTkFrame(self.main_frame)
208 | self.gpu_frame.grid(row=1, column=0, padx=0, pady=20, sticky="nsew", columnspan=2)
209 | self.gpu_frame.grid_columnconfigure(0, weight=1)
210 | self.gpu_widget()
211 |
212 | def gpu_widget(self):
213 | get_info = GPUInfo()
214 | info = get_info.load_gpu_info()
215 |         cuda_available = info.get("cuda_available", False)  # default False so a missing key is not reported as available
216 | if cuda_available:
217 | cuda_available = "True"
218 | else:
219 | cuda_available = "False"
220 |
221 | label_1 = ctk.CTkLabel(self.gpu_frame, text="CUDA Available", font=FONTS["normal"])
222 | label_1.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
223 | label_1_value = ctk.CTkLabel(self.gpu_frame, text=cuda_available, font=FONTS["small"])
224 | label_1_value.grid(row=0, column=1, padx=20, pady=(20, 10), sticky="e")
225 |
226 | label_2 = ctk.CTkLabel(self.gpu_frame, text="GPU Count", font=FONTS["normal"])
227 | label_2.grid(row=1, column=0, padx=20, pady=10, sticky="w")
228 | label_2_value = ctk.CTkLabel(self.gpu_frame, text=info.get("gpu_count", "N/A"), font=FONTS["small"])
229 | label_2_value.grid(row=1, column=1, padx=20, pady=10, sticky="e")
230 |
231 | label_3 = ctk.CTkLabel(self.gpu_frame, text="Current GPU", font=FONTS["normal"])
232 | label_3.grid(row=2, column=0, padx=20, pady=10, sticky="w")
233 | label_3_value = ctk.CTkLabel(self.gpu_frame, text=info.get("current_gpu", "N/A"), font=FONTS["small"])
234 | label_3_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")
235 |
236 | label_4 = ctk.CTkLabel(self.gpu_frame, text="GPU Name", font=FONTS["normal"])
237 | label_4.grid(row=3, column=0, padx=20, pady=10, sticky="w")
238 | label_4_value = ctk.CTkLabel(self.gpu_frame, text=info.get("gpu_name", "N/A"), font=FONTS["small"])
239 | label_4_value.grid(row=3, column=1, padx=20, pady=10, sticky="e")
240 |
241 | label_5 = ctk.CTkLabel(self.gpu_frame, text="Total Memory", font=FONTS["normal"])
242 | label_5.grid(row=4, column=0, padx=20, pady=(10, 20), sticky="w")
243 | label_5_value = ctk.CTkLabel(self.gpu_frame, text=f"{info.get('total_memory', 'N/A')} GB", font=FONTS["small"])
244 | label_5_value.grid(row=4, column=1, padx=20, pady=(10, 20), sticky="e")
245 |
246 | def model_widget(self):
247 | settings = self.settings_handler.load_settings()
248 |
249 | model_size = settings.get("model_size", "Base").capitalize()
250 |         language = settings.get("language", "Auto Detection").title()
251 | device = settings.get("device", "CPU").upper()
252 |
253 | model_requirements = ModelRequirements()
254 | model_values, device_values = model_requirements.update_model_requirements()
255 | model_size_label = ctk.CTkLabel(self.model_frame, text="Model Size", font=FONTS["normal"])
256 | model_size_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
257 | model_size_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
258 | model_size_value.grid(row=0, column=1, padx=20, pady=10, sticky="e")
259 | self.model_dropdown = CTkScrollableDropdownFrame(model_size_value, values=model_values, **DROPDOWN,
260 | command=lambda
261 | new_value=model_size_value: self.change_settings_value(
262 | key_name="model_size", new_value=f"{new_value}"))
263 | model_size_value.set(model_size)
264 |
265 | language_label = ctk.CTkLabel(self.model_frame, text="Language", font=FONTS["normal"])
266 | language_label.grid(row=1, column=0, padx=20, pady=10, sticky="w")
267 | language_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
268 | language_value.grid(row=1, column=1, padx=20, pady=10, sticky="e")
269 | self.language_dropdown = CTkScrollableDropdownFrame(language_value, values=languages, **DROPDOWN,
270 | command=lambda
271 | new_value=language_value: self.change_settings_value(
272 | key_name="language", new_value=f"{new_value}"))
273 | language_value.set(language)
274 |
275 | device_label = ctk.CTkLabel(self.model_frame, text="Device", font=FONTS["normal"])
276 | device_label.grid(row=2, column=0, padx=20, pady=10, sticky="w")
277 | device_value = ctk.CTkOptionMenu(self.model_frame, font=FONTS["small"])
278 | device_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")
279 | self.device_dropdown = CTkScrollableDropdownFrame(device_value, values=device_values, **DROPDOWN,
280 | command=lambda
281 | new_value=device_value: self.change_settings_value(
282 | key_name="device", new_value=f"{new_value}"))
283 | device_value.set(device)
284 |
285 | download_btn = ctk.CTkButton(self.model_frame, text="Download All Models", command=self.download_callback,
286 | height=50, font=FONTS["small"])
287 | download_btn.grid(row=3, column=1, padx=20, pady=(0, 20), sticky="e")
288 |
289 | btns_frame = ctk.CTkScrollableFrame(self.model_frame, fg_color="transparent", label_text="Installed Models",
290 | label_font=FONTS["normal"])
291 | btns_frame.grid(row=3, column=0, padx=20, pady=(0, 20), sticky="w")
292 |
293 | installed_models = []
294 |
295 | cache_folder = os.path.join(os.path.expanduser('~'), f'.cache{os.path.sep}whisper{os.path.sep}')
296 |         files = os.listdir(cache_folder) if os.path.isdir(cache_folder) else []
297 | for file in files:
298 |             name, _ = os.path.splitext(file)
299 |
300 | installed_models.append(name)
301 |
302 | if installed_models:
303 | for index, model in enumerate(installed_models):
304 | model_btn = ctk.CTkButton(btns_frame, text=str(model).capitalize(), height=25, hover=False,
305 | corner_radius=2, font=FONTS["small"])
306 | model_btn.grid(row=index, column=0, padx=20, pady=5, sticky="nsew")
307 | else:
308 | warning_btn = ctk.CTkButton(btns_frame, text="No models are installed", height=25, hover=False,
309 | corner_radius=2, font=FONTS["small"])
310 | warning_btn.grid(row=0, column=0, padx=20, pady=5, sticky="nsew")
311 |
312 | def general_widget(self):
313 | settings = self.settings_handler.load_settings()
314 | theme = settings.get("theme", "System").capitalize()
315 | color_theme = settings.get("color_theme", "Blue").capitalize()
316 |
317 | theme_label = ctk.CTkLabel(self.general_frame, text="Theme", font=FONTS["normal"])
318 | theme_label.grid(row=0, column=0, padx=20, pady=(20, 10), sticky="w")
319 | self.theme_btn = ctk.CTkOptionMenu(self.general_frame, font=FONTS["small"])
320 | values = ["System", "Light", "Dark"]
321 | self.theme_btn.grid(row=0, column=1, padx=20, pady=(20, 10), sticky="e")
322 | self.theme_dropdown = CTkScrollableDropdownFrame(self.theme_btn, values=values, **DROPDOWN,
323 | command=self.change_theme)
324 | self.theme_btn.set(theme)
325 |
326 | color_theme_label = ctk.CTkLabel(self.general_frame, text="Color Theme", font=FONTS["normal"])
327 | color_theme_label.grid(row=1, column=0, padx=20, pady=(20, 10), sticky="w")
328 | self.color_theme_btn = ctk.CTkOptionMenu(self.general_frame, font=FONTS["small"])
329 | color_values = ["Blue", "Dark-Blue", "Green"]
330 | self.color_theme_btn.grid(row=1, column=1, padx=20, pady=(20, 10), sticky="e")
331 | self.color_theme_dropdown = CTkScrollableDropdownFrame(self.color_theme_btn, values=color_values,
332 | **DROPDOWN,
333 | command=self.change_color_theme)
334 | self.color_theme_btn.set(color_theme)
335 |
336 | developer_label = ctk.CTkLabel(self.general_frame, text="Developer", font=FONTS["normal"])
337 | developer_label.grid(row=2, column=0, padx=20, pady=10, sticky="w")
338 | developer_value = ctk.CTkLabel(self.general_frame, text="@iamironman", font=FONTS["small"])
339 | developer_value.grid(row=2, column=1, padx=20, pady=10, sticky="e")
340 |
341 | released_label = ctk.CTkLabel(self.general_frame, text="Released", font=FONTS["normal"])
342 | released_label.grid(row=3, column=0, padx=20, pady=10, sticky="w")
343 | released_value = ctk.CTkLabel(self.general_frame, text="12/2/2023", font=FONTS["small"])
344 | released_value.grid(row=3, column=1, padx=20, pady=10, sticky="e")
345 |
346 | help_label = ctk.CTkLabel(self.general_frame, text="FAQ or Help Center", font=FONTS["normal"])
347 | help_label.grid(row=4, column=0, padx=20, pady=10, sticky="w")
348 | help_value = ctk.CTkButton(self.general_frame, text="Get help", font=FONTS["small"], command=help_link)
349 | help_value.grid(row=4, column=1, padx=20, pady=10, sticky="e")
350 |
351 | def download_callback(self):
352 | self.close_btn.configure(state="disabled")
353 | self.reset_btn.destroy()
354 | widgets = self.main_frame.winfo_children()
355 |
356 | for widget in widgets:
357 | widget.grid_forget()
358 |
359 | self.loader = CTkLoader(parent=self.master, title="Downloading Models", msg="Please wait...",
360 | cancel_func=self.cancel_downloads)
361 |
362 | self.thread = threading.Thread(target=self.download_thread)
363 | self.thread.start()
364 |
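   |     # Worker thread: launch download_models.py with the current Python
   |     # interpreter as a child process, wait for it to exit, and report the
   |     # result with an alert before restoring the settings view.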
365 | def download_thread(self):
366 | interpreter_path = sys.executable
367 | script_path = f"{current_path}{os.path.sep}download_models.py"
368 |
369 | command = [interpreter_path, script_path]
370 | self.process = subprocess.Popen(command)
371 |
372 | return_code = self.process.wait()
373 | if return_code == 0:
374 | CTkAlert(parent=self.master, status="success", title="Download Complete",
375 | msg="All available models are installed successfully.")
376 | self.cancel_downloads()
377 | else:
378 | CTkAlert(parent=self.master, status="error", title="Download Failed",
379 | msg="Failed to download models. Please check your internet connection and try again.")
380 |
381 | self.cancel_downloads()
382 |
383 | def cancel_downloads(self):
384 | self.loader.hide_loader()
385 |         if self.process is not None: self.process.terminate()
386 | self.close_btn.configure(state="normal")
387 | self.main_frame.grid(row=1, column=0, padx=20, pady=0, sticky="nsew", columnspan=2)
388 | self.after(1000, self.default_widget)
389 |
390 | def reset_callback(self):
391 | message = self.settings_handler.reset_settings()
392 | for widget in self.main_frame.winfo_children():
393 | widget.grid_forget()
394 |
395 | self.default_widget()
396 |
397 | CTkAlert(parent=self.master, status="success", title="Success", msg=message)
398 |
399 | def change_settings_value(self, key_name: str, new_value: str):
400 |         self.settings_handler.save_settings(**{key_name: new_value.lower()})
401 | for widget in self.model_frame.winfo_children():
402 | widget.grid_forget()
403 |
404 | self.update_view(frame="Model Settings")
405 |
406 | def change_theme(self, theme):
407 | new_theme = str(theme).lower()
408 | ctk.set_appearance_mode(new_theme)
409 | for widget in self.main_frame.winfo_children():
410 | widget.grid_forget()
411 |
412 |         self.settings_handler.save_settings(theme=new_theme)
413 |
414 | self.default_widget()
415 |
416 | def change_color_theme(self, color_theme):
417 | new_color_theme = str(color_theme).lower()
418 | ctk.set_default_color_theme(new_color_theme)
419 |
420 | for widget in self.main_frame.winfo_children():
421 | widget.grid_forget()
422 |
423 |         self.settings_handler.save_settings(color_theme=new_color_theme)
424 |
425 | self.default_widget()
426 |
427 | def hide_settings_ui(self):
428 | self.destroy()
429 |
--------------------------------------------------------------------------------
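Usage note — a minimal sketch, not part of the repository: embedding SettingsUI in a
CustomTkinter root window. The frame grids itself into row 0, column 0 of its parent
and its close button destroys the frame; the import path assumes the project root is
on sys.path and the logic package (SettingsHandler etc.) is available.

    import customtkinter as ctk
    from src.ui.settings import SettingsUI

    root = ctk.CTk()
    root.geometry("620x720")
    root.grid_rowconfigure(0, weight=1)
    root.grid_columnconfigure(0, weight=1)

    SettingsUI(root)  # places itself via self.grid(row=0, column=0, sticky="nsew")
    root.mainloop()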