├── .gitignore ├── LICENSE ├── README.md ├── bootstrap.sh ├── cog.yaml ├── lib │   ├── __init__.py │   ├── audio.py │   └── diarization.py └── predict.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Installer logs 7 | pip-log.txt 8 | pip-delete-this-directory.txt 9 | 10 | # Jupyter Notebook 11 | .ipynb_checkpoints 12 | 13 | # IPython 14 | profile_default/ 15 | ipython_config.py 16 | 17 | # Environments 18 | .env 19 | .venv 20 | env/ 21 | venv/ 22 | ENV/ 23 | env.bak/ 24 | venv.bak/ 25 | 26 | # VSCode 27 | .vscode/ 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CNRS 4 | Copyright (c) 2022 OpenAI 5 | Copyright (c) 2023 Lucian Boca 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # speaker-transcription 2 | 3 | This repository contains the Cog definition files for the associated speaker transcription model [deployed on Replicate](https://replicate.com/meronym/speaker-transcription). 4 | 5 | The pipeline transcribes the speech segments of an audio file, identifies the individual speakers and annotates the transcript with timestamps and speaker labels. An optional `prompt` string can guide the transcription by providing additional context. The pipeline also outputs global information about the number of detected speakers, along with an embedding vector for each speaker that characterizes their voice.
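A minimal sketch of calling the deployed model from Python with the `replicate` client. The input names `audio` and `prompt` match `predict.py`; the file name, prompt text, and version handling are illustrative assumptions:

```python
import replicate

# run the hosted pipeline on a local audio file; depending on the client
# version you may need to pin an explicit model version id
output = replicate.run(
    "meronym/speaker-transcription",
    input={
        "audio": open("interview.mp3", "rb"),  # any ffmpeg-decodable format
        "prompt": "Podcast interview about regenerative medicine.",  # optional
    },
)

# `output` references the generated output.json (see "Output format" below)
print(output)
```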
6 | 7 | ## Model description 8 | 9 | There are two main components involved in this process: 10 | 11 | - a pre-trained speaker diarization pipeline from the [`pyannote.audio`](pyannote.github.io) package (also available as a [stand-alone diarization model](https://replicate.com/meronym/speaker-diarization) without transcription): 12 | 13 | - `pyannote/segmentation` for permutation-invariant speaker segmentation on temporal slices 14 | - `speechbrain/spkrec-ecapa-voxceleb` for generating speaker embeddings 15 | - `AgglomerativeClustering` for matching embeddings across temporal slices 16 | 17 | - OpenAI's `whisper` model for general-purpose English speech transcription (the `medium.en` model size is used for a good balance between accuracy and performance). 18 | 19 | The audio data is first passed in to the speaker diarization pipeline, which computes a list of timestamped segments and associates each segment with a speaker. The segments are then transcribed with `whisper`. 20 | 21 | ## Input format 22 | 23 | The pipeline uses `ffmpeg` to decode the input audio, so it supports a wide variety of input formats - including, but not limited to `mp3`, `aac`, `flac`, `ogg`, `opus`, `wav`. 24 | 25 | The `prompt` string gets injected as (off-screen) additional context at the beginning of the first Whisper transcription window for each segment. It won't be part of the final output, but it can be used for guiding/conditioning the transcription towards a specific domain. 26 | 27 | ## Output format 28 | 29 | The pipeline outputs a single `output.json` file with the following structure: 30 | 31 | ```json 32 | { 33 | "segments": [ 34 | { 35 | "speaker": "A", 36 | "start": "0:00:00.497812", 37 | "stop": "0:00:09.762188", 38 | "transcript": [ 39 | { 40 | "start": "0:00:00.497812", 41 | "text": " What are some cool synthetic organisms that you think about, you dream about?" 42 | }, 43 | { 44 | "start": "0:00:04.357812", 45 | "text": " When you think about embodied mind, what do you imagine?" 46 | }, 47 | { 48 | "start": "0:00:08.017812", 49 | "text": " What do you hope to build?" 50 | } 51 | ] 52 | }, 53 | { 54 | "speaker": "B", 55 | "start": "0:00:09.863438", 56 | "stop": "0:03:34.962188", 57 | "transcript": [ 58 | { 59 | "start": "0:00:09.863438", 60 | "text": " Yeah, on a practical level, what I really hope to do is to gain enough of an understanding of the embodied intelligence of the organs and tissues, such that we can achieve a radically different regenerative medicine, so that we can say, basically, and I think about it as, you know, in terms of like, okay, can you what's the what's the what's the goal, kind of end game for this whole thing? To me, the end game is something that you would call an" 61 | }, 62 | { 63 | "start": "0:00:39.463438", 64 | "text": " anatomical compiler. So the idea is you would sit down in front of the computer and you would draw the body or the organ that you wanted. Not molecular details, but like, here, this is what I want. I want a six legged, you know, frog with a propeller on top, or I want I want a heart that looks like this, or I want a leg that looks like this. And what it would do if we knew what we were doing is put out, convert that anatomical description into a set of stimuli that would have to be given to cells to convince them to build exactly that thing." 65 | }, 66 | { 67 | "start": "0:01:08.503438", 68 | "text": " Right? I probably won't live to see it. But I think it's achievable. 
And I think what that if, if we can have that, then that is basically the solution to all of medicine, except for infectious disease. So birth defects, right, traumatic injury, cancer, aging, degenerative disease, if we knew how to tell cells what to build, all of those things go away. So those things go away, and the positive feedback spiral of economic costs, where all of the advances are increasingly more" 69 | } 70 | ] 71 | } 72 | ], 73 | "speakers": { 74 | "count": 2, 75 | "labels": [ 76 | "A", 77 | "B" 78 | ], 79 | "embeddings": { 80 | "A": [], 81 | "B": [] 82 | } 83 | } 84 | } 85 | ``` 86 | 87 | ## Performance 88 | 89 | The current T4 deployment has an average processing speed factor of 4x (relative to the length of the audio input) - e.g. it will take the model approx. 1 minute of computation to process 4 minutes of audio. 90 | 91 | ## Intended use 92 | 93 | Data augmentation and segmentation for a variety of transcription and captioning tasks (e.g. interviews, podcasts, meeting recordings, etc.). Speaker recognition can be implemented by matching the speaker embeddings against a database of known speakers. 94 | 95 | ## Ethical considerations 96 | 97 | This model may have biases based on the data it has been trained on. It is important to use the model in a responsible manner and adhere to ethical and legal standards. 98 | 99 | ## Citations 100 | 101 | For `pyannote.audio`: 102 | 103 | ```bibtex 104 | @inproceedings{Bredin2020, 105 | Title = {{pyannote.audio: neural building blocks for speaker diarization}}, 106 | Author = {{Bredin}, Herv{\'e} and {Yin}, Ruiqing and {Coria}, Juan Manuel and {Gelly}, Gregory and {Korshunov}, Pavel and {Lavechin}, Marvin and {Fustes}, Diego and {Titeux}, Hadrien and {Bouaziz}, Wassim and {Gill}, Marie-Philippe}, 107 | Booktitle = {ICASSP 2020, IEEE International Conference on Acoustics, Speech, and Signal Processing}, 108 | Year = {2020}, 109 | } 110 | ``` 111 | 112 | ```bibtex 113 | @inproceedings{Bredin2021, 114 | Title = {{End-to-end speaker segmentation for overlap-aware resegmentation}}, 115 | Author = {{Bredin}, Herv{\'e} and {Laurent}, Antoine}, 116 | Booktitle = {Proc.
Interspeech 2021}, 117 | Year = {2021}, 118 | } 119 | ``` 120 | 121 | For OpenAI `whisper`: 122 | 123 | ```bibtex 124 | @misc{https://doi.org/10.48550/arxiv.2212.04356, 125 | doi = {10.48550/ARXIV.2212.04356}, 126 | url = {https://arxiv.org/abs/2212.04356}, 127 | author = {Radford, Alec and Kim, Jong Wook and Xu, Tao and Brockman, Greg and McLeavey, Christine and Sutskever, Ilya}, 128 | keywords = {Audio and Speech Processing (eess.AS), Computation and Language (cs.CL), Machine Learning (cs.LG), Sound (cs.SD), FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Computer and information sciences, FOS: Computer and information sciences}, 129 | title = {Robust Speech Recognition via Large-Scale Weak Supervision}, 130 | publisher = {arXiv}, 131 | year = {2022}, 132 | copyright = {arXiv.org perpetual, non-exclusive license} 133 | } 134 | ``` 135 | -------------------------------------------------------------------------------- /bootstrap.sh: -------------------------------------------------------------------------------- 1 | # install cog 2 | sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m` 3 | sudo chmod +x /usr/local/bin/cog 4 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | 4 | system_packages: 5 | - 'ffmpeg' 6 | - 'libsndfile1' 7 | 8 | python_version: '3.8' 9 | 10 | python_packages: 11 | - 'ffmpeg-python==0.2.0' 12 | - 'torch==1.11.0' 13 | - 'torchvision==0.12.0' 14 | - 'torchaudio==0.11.0' 15 | - 'pyannote.audio==2.1.1' 16 | - 'openai-whisper==20230314' 17 | 18 | run: 19 | - 'wget -O - https://pyannote-speaker-diarization.s3.eu-west-2.amazonaws.com/data-2023-03-25-02.tar.gz | tar xz -C /' 20 | - 'mkdir /data/whisper' 21 | - 'wget -P /data/whisper https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt' 22 | 23 | image: 'r8.im/meronym/speaker-transcription' 24 | 25 | predict: 'predict.py:Predictor' 26 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/meronym/speaker-transcription/3cdcf316f1a62d260d09475025c83d8df0ffea7f/lib/__init__.py -------------------------------------------------------------------------------- /lib/audio.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import os 3 | import tempfile 4 | 5 | import ffmpeg 6 | 7 | 8 | class AudioPreProcessor: 9 | def __init__(self): 10 | self.tmpdir = None 11 | self.output_path = None 12 | self.error = None 13 | 14 | def process(self, audio_file): 15 | # create a new temp dir for every run 16 | self.tmpdir = pathlib.Path(tempfile.mkdtemp()) 17 | self.output_path = str(self.tmpdir / 'audio.wav') 18 | self.error = None 19 | 20 | # converts audio file to 16kHz 16bit mono wav... 
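        # ffmpeg flags: vn drops any video stream, hide_banner silences the startup banner;
        # the output options below (pcm_s16le, ac=1, ar='16k') yield 16-bit mono PCM at 16 kHz,
        # the sample rate that both the pyannote pipeline and Whisper operate on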
21 | print('pre-processing audio file...') 22 | stream = ffmpeg.input(audio_file, vn=None, hide_banner=None) 23 | stream = stream.output(self.output_path, format='wav', 24 | acodec='pcm_s16le', ac=1, ar='16k').overwrite_output() 25 | try: 26 | ffmpeg.run(stream, capture_stdout=True, capture_stderr=True) 27 | except ffmpeg.Error as e: 28 | self.error = e.stderr.decode('utf8') 29 | 30 | def cleanup(self): 31 | if os.path.exists(self.output_path): 32 | os.remove(self.output_path) 33 | if self.tmpdir: 34 | self.tmpdir.rmdir() 35 | -------------------------------------------------------------------------------- /lib/diarization.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import datetime 3 | 4 | import numpy as np 5 | 6 | 7 | def format_ts(ts): 8 | return str(datetime.timedelta(seconds=ts)) 9 | 10 | 11 | class SpeakerLabelGenerator: 12 | def __init__(self): 13 | self.speakers = {} 14 | self.labels = [] 15 | self.next_speaker = ord('A') 16 | self.count = 0 17 | 18 | def get(self, name): 19 | if name not in self.speakers: 20 | current = chr(self.next_speaker) 21 | self.speakers[name] = current 22 | self.labels.append(current) 23 | self.next_speaker += 1 24 | self.count += 1 25 | return self.speakers[name] 26 | 27 | def get_all(self): 28 | return self.labels 29 | 30 | 31 | class DiarizationPostProcessor: 32 | def __init__(self): 33 | self.MIN_SEGMENT_DURATION = 1.0 34 | self.labels = None 35 | 36 | def process(self, diarization, embeddings): 37 | print('post-processing diarization...') 38 | # create a new label generator 39 | self.labels = SpeakerLabelGenerator() 40 | 41 | # process the diarization 42 | clean_segments = self.clean_segments(diarization) 43 | merged_segments = self.merge_segments(clean_segments) 44 | emb_segments = self.segment_embeddings(merged_segments, embeddings) 45 | 46 | # create the speaker embeddings 47 | speaker_embeddings = self.create_speaker_embeddings(emb_segments) 48 | speaker_count = self.labels.count 49 | speaker_labels = self.labels.get_all() 50 | speaker_emb_map = {} 51 | for label in speaker_labels: 52 | speaker_emb_map[label] = speaker_embeddings[label].tolist() 53 | 54 | # create the final output 55 | # segments = self.format_segments(emb_segments) 56 | # segments = self.format_segments_extra(emb_segments, speaker_embeddings) 57 | return { 58 | "segments": emb_segments, 59 | "speakers": { 60 | "count": speaker_count, 61 | "labels": speaker_labels, 62 | "embeddings": speaker_emb_map, 63 | }, 64 | } 65 | 66 | def empty_result(self): 67 | return { 68 | "segments": [], 69 | "speakers": { 70 | "count": 0, 71 | "labels": [], 72 | "embeddings": {}, 73 | }, 74 | } 75 | 76 | def clean_segments(self, diarization): 77 | speaker_time = collections.defaultdict(float) 78 | total_time = 0.0 79 | for segment, _, speaker in diarization.itertracks(yield_label=True): 80 | # filter out segments that are too short 81 | if segment.duration < self.MIN_SEGMENT_DURATION: 82 | continue 83 | speaker_time[speaker] += segment.duration 84 | total_time += segment.duration 85 | 86 | # filter out speakers that have spoken too little 87 | # (these are likely overlaps misclassified as separate speakers) 88 | speakers = set([ 89 | speaker 90 | for speaker, time in speaker_time.items() 91 | if time > total_time * 0.01 92 | ]) 93 | 94 | segments = [] 95 | for segment, _, speaker in diarization.itertracks(yield_label=True): 96 | if (speaker not in speakers) or segment.duration < self.MIN_SEGMENT_DURATION: 97 | continue 98 | 
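            # keep the segment: map pyannote's internal speaker id (e.g. "SPEAKER_00") to a
            # letter label and attach an empty (0, 192) array that segment_embeddings() will
            # later fill with this segment's ECAPA speaker embeddings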
segments.append({ 99 | "speaker": self.labels.get(speaker), 100 | "start": segment.start, 101 | "stop": segment.end, 102 | "embeddings": np.empty((0, 192)), 103 | }) 104 | return segments 105 | 106 | def merge_segments(self, clean_segments): 107 | # merge adjacent segments if they have the same speaker and are close enough 108 | merged = [] 109 | for segment in clean_segments: 110 | if not merged: 111 | merged.append(segment) 112 | continue 113 | if merged[-1]["speaker"] == segment["speaker"]: 114 | if segment["start"] - merged[-1]["stop"] < 2.0 * self.MIN_SEGMENT_DURATION: 115 | merged[-1]["stop"] = segment["stop"] 116 | continue 117 | merged.append(segment) 118 | return merged 119 | 120 | def segment_embeddings(self, merged_segments, embeddings): 121 | # process the embeddings 122 | for i, chunk in enumerate(embeddings['data']): 123 | # chunk shape: (local_num_speakers, dimension) 124 | speakers = [] 125 | for speaker_embedding in chunk: 126 | if not np.all(np.isnan(speaker_embedding)): 127 | speakers.append(speaker_embedding) 128 | if len(speakers) != 1: 129 | # ignore this chunk 130 | continue 131 | # now we have a single speaker for this chunk 132 | speaker = speakers[0] 133 | 134 | # find the segment that this chunk belongs to 135 | chunk_start = i * embeddings['chunk_offset'] 136 | chunk_end = chunk_start + embeddings['chunk_duration'] 137 | 138 | for segment in merged_segments: 139 | if (segment['start'] <= chunk_start) and (chunk_end <= segment['stop']): 140 | # this is the segment we're looking for 141 | segment['embeddings'] = np.append( 142 | segment['embeddings'], 143 | [speaker], 144 | axis=0, 145 | ) 146 | break 147 | return merged_segments 148 | 149 | def create_speaker_embeddings(self, emb_segments): 150 | speaker_embeddings = collections.defaultdict( 151 | lambda: np.empty((0, 192))) 152 | 153 | for segment in emb_segments: 154 | if segment["embeddings"].size == 0: 155 | continue 156 | speaker_embeddings[segment["speaker"]] = np.vstack([ 157 | speaker_embeddings[segment["speaker"]], 158 | segment["embeddings"], 159 | ]) 160 | for speaker in speaker_embeddings: 161 | speaker_embeddings[speaker] = speaker_embeddings[speaker].mean( 162 | axis=0) 163 | return speaker_embeddings 164 | 165 | def format_segments(self, emb_segments): 166 | segments = [] 167 | for segment in emb_segments: 168 | new = segment.copy() 169 | new['start'] = format_ts(new['start']) 170 | new['stop'] = format_ts(new['stop']) 171 | del new['embeddings'] 172 | segments.append(new) 173 | return segments 174 | 175 | def format_segments_extra(self, emb_segments, speaker_embeddings): 176 | from sklearn.metrics.pairwise import cosine_distances 177 | 178 | def format_ts(ts): 179 | return str(datetime.timedelta(seconds=ts)) 180 | 181 | def get_mean(embeddings): 182 | if len(embeddings) == 0: 183 | return None 184 | return embeddings.mean(axis=0) 185 | 186 | def dist(embedding, label): 187 | if embedding is None: 188 | return None 189 | ref = speaker_embeddings[label].reshape(1, -1) 190 | current = embedding.reshape(1, -1) 191 | return cosine_distances(ref, current)[0][0] 192 | 193 | segments = [] 194 | for segment in emb_segments: 195 | embedding = get_mean(segment["embeddings"]) 196 | segments.append({ 197 | "speaker": segment["speaker"], 198 | "start": format_ts(segment["start"]), 199 | "stop": format_ts(segment["stop"]), 200 | "edist": dict((label, dist(embedding, label)) for label in self.labels.get_all()), 201 | }) 202 | return segments 203 | 
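The `speakers.embeddings` map produced by `DiarizationPostProcessor.process` supports the speaker-recognition use case mentioned in the README ("matching the speaker embeddings against a database of known speakers"). A minimal sketch of that matching, using the same cosine-distance measure as `format_segments_extra`; `identify_speakers`, `known_speakers`, and the 0.5 threshold are illustrative assumptions, not part of the pipeline:

```python
import json

import numpy as np


def identify_speakers(output_json_path, known_speakers, max_distance=0.5):
    """Assign each diarized label (A, B, ...) to the closest known speaker, or None."""
    with open(output_json_path) as f:
        result = json.load(f)

    matches = {}
    for label, emb in result["speakers"]["embeddings"].items():
        emb = np.asarray(emb)
        best_name, best_dist = None, np.inf
        for name, ref in known_speakers.items():
            ref = np.asarray(ref)
            # cosine distance = 1 - cosine similarity
            dist = 1.0 - np.dot(emb, ref) / (np.linalg.norm(emb) * np.linalg.norm(ref))
            if dist < best_dist:
                best_name, best_dist = name, dist
        matches[label] = best_name if best_dist <= max_distance else None
    return matches


# usage: reference vectors are 192-dim embeddings collected from earlier runs
# identify_speakers("output.json", {"Alice": alice_vec, "Bob": bob_vec})
```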
-------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | download model weights to /data 3 | wget -O - https://pyannote-speaker-diarization.s3.eu-west-2.amazonaws.com/data-2023-03-25-02.tar.gz | tar xz -C / 4 | wget -P /data/whisper https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt 5 | # wget -P /data/whisper https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt 6 | """ 7 | 8 | import json 9 | import tempfile 10 | 11 | import numpy as np 12 | import torch 13 | 14 | from cog import BasePredictor, Input, Path 15 | from pyannote.audio import Audio 16 | from pyannote.audio.pipelines import SpeakerDiarization 17 | from pyannote.core import Segment 18 | from whisper.model import Whisper, ModelDimensions 19 | 20 | from lib.diarization import DiarizationPostProcessor, format_ts 21 | from lib.audio import AudioPreProcessor 22 | 23 | 24 | class Predictor(BasePredictor): 25 | def setup(self): 26 | """Load the model into memory to make running multiple predictions efficient""" 27 | self.audio_pre = AudioPreProcessor() 28 | 29 | self.diarization = SpeakerDiarization( 30 | segmentation="/data/pyannote/segmentation/pytorch_model.bin", 31 | embedding="/data/speechbrain/spkrec-ecapa-voxceleb", 32 | clustering="AgglomerativeClustering", 33 | segmentation_batch_size=32, 34 | embedding_batch_size=32, 35 | embedding_exclude_overlap=True, 36 | ) 37 | self.diarization.instantiate({ 38 | "clustering": { 39 | "method": "centroid", 40 | "min_cluster_size": 15, 41 | "threshold": 0.7153814381597874, 42 | }, 43 | "segmentation": { 44 | "min_duration_off": 0.5817029604921046, 45 | "threshold": 0.4442333667381752, 46 | }, 47 | }) 48 | self.diarization_post = DiarizationPostProcessor() 49 | 50 | with open(f"/data/whisper/medium.en.pt", "rb") as f: 51 | checkpoint = torch.load(f, map_location="cpu") 52 | dims = ModelDimensions(**checkpoint["dims"]) 53 | self.whisper = Whisper(dims) 54 | self.whisper.load_state_dict(checkpoint["model_state_dict"]) 55 | 56 | def run_diarization(self): 57 | closure = {'embeddings': None} 58 | 59 | def hook(name, *args, **kwargs): 60 | if name == "embeddings" and len(args) > 0: 61 | closure['embeddings'] = args[0] 62 | 63 | print('diarizing audio file...') 64 | diarization = self.diarization(self.audio_pre.output_path, hook=hook) 65 | embeddings = { 66 | 'data': closure['embeddings'], 67 | 'chunk_duration': self.diarization.segmentation_duration, 68 | 'chunk_offset': self.diarization.segmentation_step * self.diarization.segmentation_duration, 69 | } 70 | return self.diarization_post.process(diarization, embeddings) 71 | 72 | def run_transcription(self, audio, segments, whisper_prompt): 73 | print('transcribing segments...') 74 | if whisper_prompt: 75 | print('using prompt:', repr(whisper_prompt)) 76 | 77 | self.whisper.to("cuda") 78 | trimmer = Audio(sample_rate=16000, mono=True) 79 | for seg in segments: 80 | start = seg['start'] 81 | stop = seg['stop'] 82 | print( 83 | f"transcribing segment {format_ts(start)} to {format_ts(stop)}") 84 | frames, _ = trimmer.crop(audio, Segment(start, stop)) 85 | # audio data was already downmixed to mono, so exract the first (only) channel 86 | frames = frames[0] 87 | seg['transcript'] = self.transcribe_segment(frames, start, whisper_prompt) 88 | 89 | def 
transcribe_segment(self, audio, ctx_start, whisper_prompt): 90 | # `temperature`: temperature to use for sampling 91 | # `temperature_increment_on_fallback``: temperature to increase when 92 | # falling back when the decoding fails to meet either of the thresholds below 93 | temperature = 0 94 | temperature_increment_on_fallback = 0.2 95 | temperature = tuple( 96 | np.arange(temperature, 1.0 + 1e-6, 97 | temperature_increment_on_fallback)) 98 | 99 | # `patience`: "optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424 100 | # the default (1.0) is equivalent to conventional beam search 101 | patience = None 102 | 103 | # `suppress_tokens`: "comma-separated list of token ids to suppress during sampling 104 | # '-1' will suppress most special characters except common punctuations 105 | suppress_tokens = "-1" 106 | 107 | # `initial_prompt`: optional text to provide as a prompt for the first window 108 | initial_prompt = whisper_prompt 109 | 110 | # `condition_on_previous_text`: if True, provide the previous output of the model 111 | # as a prompt for the next window; disabling may make the text inconsistent across windows, 112 | # but the model becomes less prone to getting stuck in a failure loop 113 | condition_on_previous_text = True 114 | 115 | # `compression_ratio_threshold`: if the gzip compression ratio is higher than this value, 116 | # treat the decoding as failed 117 | compression_ratio_threshold = 2.4 118 | 119 | # `logprob_threshold`: if the average log probability is lower than this value, 120 | # treat the decoding as failed 121 | logprob_threshold = -1.0 122 | 123 | # `no_speech_threshold`: if the probability of the <|nospeech|> token is higher than this value 124 | # AND the decoding has failed due to `logprob_threshold`, consider the segment as silence 125 | no_speech_threshold = 0.6 126 | 127 | args = { 128 | "language": "en", # this is an English-only model 129 | "patience": patience, 130 | "suppress_tokens": suppress_tokens, 131 | "initial_prompt": initial_prompt, 132 | "condition_on_previous_text": condition_on_previous_text, 133 | "compression_ratio_threshold": compression_ratio_threshold, 134 | "logprob_threshold": logprob_threshold, 135 | "no_speech_threshold": no_speech_threshold, 136 | } 137 | trs = self.whisper.transcribe( 138 | audio, temperature=temperature, **args) 139 | 140 | result = [] 141 | for s in trs.get('segments', []): 142 | timestamp = ctx_start + s['start'] 143 | result.append({ 144 | 'start': format_ts(timestamp), 145 | 'text': s['text'] 146 | }) 147 | return result 148 | 149 | def predict( 150 | self, 151 | audio: Path = Input(description="Audio file"), 152 | prompt: str = Input( 153 | default=None, 154 | description="Optional text to provide as a prompt for each Whisper model call.", 155 | ), 156 | ) -> Path: 157 | """Run a single prediction on the model""" 158 | 159 | self.audio_pre.process(audio) 160 | 161 | if self.audio_pre.error: 162 | print(self.audio_pre.error) 163 | result = self.diarization_post.empty_result() 164 | else: 165 | result = self.run_diarization() 166 | 167 | # transcribe segments 168 | self.run_transcription(self.audio_pre.output_path, result["segments"], prompt) 169 | 170 | # format segments 171 | result["segments"] = self.diarization_post.format_segments( 172 | result["segments"]) 173 | 174 | # cleanup 175 | self.audio_pre.cleanup() 176 | 177 | # write output 178 | output = Path(tempfile.mkdtemp()) / "output.json" 179 | with open(output, "w") as f: 180 | f.write(json.dumps(result, indent=2)) 
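        # Cog uploads the returned Path (the temp-dir output.json) as the prediction artifact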
181 | return output 182 | --------------------------------------------------------------------------------
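Once a prediction has completed (for example via `cog predict -i audio=@example.mp3 -i prompt="..."` against the built image, or through the Replicate API), the resulting `output.json` can be flattened into a readable transcript. A small sketch; the file name and print format are arbitrary choices:

```python
import json

with open("output.json") as f:
    result = json.load(f)

speakers = result["speakers"]
print(f"{speakers['count']} speaker(s): {', '.join(speakers['labels'])}")

for segment in result["segments"]:
    # each segment has a speaker label, start/stop timestamps and a list of
    # Whisper chunks, each with its own start timestamp and text
    text = " ".join(chunk["text"].strip() for chunk in segment["transcript"])
    print(f"[{segment['start']} - {segment['stop']}] {segment['speaker']}: {text}")
```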