├── README.md
├── auto_diarize.py
├── cleaner.py
├── convert_wav.py
├── extract_speaker_sentences.py
├── fetch_youtube_mp3.py
├── num_to_words.py
├── pyannote_diarize.py
├── realtime_diarize.py
├── requirements.txt
├── speaker_diarize.py
└── split_dataset.py
/README.md:
--------------------------------------------------------------------------------
1 | # WhoSpeaks
2 |
3 | *Toolkit for Enhanced Voice Training Datasets*
4 |
5 | Note: realtime_diarize.py requires changes that were made to RealtimeSTT, so please upgrade RealtimeSTT to the latest version.
6 |
7 | WhoSpeaks emerged from the need for better speaker diarization tools. Existing libraries are heavyweight and often fall short in reliability, speed and efficiency. So this project offers a more refined alternative.
8 |
9 | > **Hint:** *Anybody interested in state-of-the-art voice solutions please also have a look at [Linguflex](https://github.com/KoljaB/Linguflex). It lets you control your environment by speaking and is one of the most capable and sophisticated open-source assistants currently available.*
10 |
11 | Here's the core concept:
12 | - **Voice Characteristic Extraction**: For each sentence in your audio, unique voice characteristics are extracted, creating audio embeddings.
13 | - **Sentence Similarity Comparison**: Cosine similarity is then used to compare each sentence's embedding against every other one, identifying similar-sounding sentences.
14 | - **Grouping and Averaging**: Similar sounding sentences are grouped together. This approach averages out anomalies and minimizes errors from individual data points.
15 | - **Identification of Distinct Groups**: By analyzing these groups, we can isolate the most distinct ones, which represent unique speaker characteristics.
16 |
17 | These steps allow us to match any sentence against the established speaker profiles with remarkable precision.
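
Below is a minimal sketch of the grouping idea, assuming the per-sentence embeddings are already available as 1-D NumPy vectors (as produced by `speaker_embedding.view(-1).cpu().detach().numpy()` in the scripts); `group_by_similarity` is an illustrative helper, not part of the repo's API:

```python
import numpy as np
from scipy.spatial.distance import cosine

def group_by_similarity(embeddings: list[np.ndarray], top_fraction: float = 0.1) -> list[np.ndarray]:
    """Average each sentence embedding with its most similar neighbours into a 'speech group' embedding."""
    k = max(1, int(len(embeddings) * top_fraction))
    group_embeddings = []
    for i, emb in enumerate(embeddings):
        # cosine() is a distance, so 1 - cosine() is the similarity
        sims = [(1 - cosine(emb, other), j) for j, other in enumerate(embeddings) if j != i]
        neighbours = [embeddings[j] for _, j in sorted(sims, reverse=True)[:k]]
        # averaging over similar sentences smooths out per-sentence anomalies
        group_embeddings.append(np.mean([emb] + neighbours, axis=0))
    return group_embeddings
```

In the actual scripts the embeddings come from Coqui XTTS via `tts.get_conditioning_latents()`, and are then either clustered (`auto_diarize.py`) or ranked by cosine similarity against a reference sentence (`extract_speaker_sentences.py`).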
18 |
19 | ### Feature Modules
20 |
21 | - **fetch_youtube_mp3.py**: Extracts and converts YouTube audio, like podcasts, to MP3 for voice analysis.
22 | - **split_dataset.py**: This tool divides your input audio into distinct sentences.
23 | - **convert_wav.py**: Converts the sentence-based MP3 files into WAV format.
24 | - **auto_diarize.py**/**speaker_diarize.py**: The heart of WhoSpeaks. Categorizes sentences into speaker groups and selects training sentences using the algorithm described above.
25 | - **pyannote_diarize.py**: Used for comparison against pyannote.audio, a current state-of-the-art speaker diarization model.
26 |
27 | > **Note**: *auto_diarize is for multiple speakers, speaker_diarize is for two speakers only*
28 |
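A rough driver sketch of the intended order; the folder names come from the defaults hard-coded at the top of each script, and it is assumed that `split_dataset.py` writes its sentence MP3s to the folder `convert_wav.py` reads from (`output_sentences`):

```python
import subprocess

# Each script reads its input/output folders from the variables at the top of its file,
# so this driver simply runs them in sequence.
for step in [
    "fetch_youtube_mp3.py",   # download the source audio as MP3 (into "input")
    "split_dataset.py",       # split the audio into per-sentence MP3s
    "convert_wav.py",         # convert the sentence MP3s to 24 kHz mono WAV ("output_sentences_wav")
    "auto_diarize.py",        # cluster the WAV sentences into "output_speakers/speaker_<n>"
]:
    subprocess.run(["python", step], check=True)
```
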
29 | I initially developed this as a personal project, but was astounded by its effectiveness. In my first tests it outperformed existing solutions like pyannote.audio in both reliability and speed while being the more lightweight approach. It could be a significant step up in voice diarization capabilities, which is why I've decided to release this rather raw, yet powerful code for others to experiment with.
30 |
31 | ## Performance and Testing
32 |
33 | To demonstrate WhoSpeaks' capabilities, I ran a test using a challenging audio sample: the 4:38 Coin Toss scene from "No Country for Old Men". In this scene, the two male speakers have very similar voice profiles, presenting a difficult scenario for diarization libraries.
34 |
35 | ### Process:
36 |
37 | 1. **Download**: Using `fetch_youtube_mp3.py`, download the MP3 from the scene's YouTube video.
38 | 2. **Diarization Comparison**: Run the scene through `pyannote_diarize.py` (which uses pyannote.audio) with the number of speakers set to 2.
39 | - Pyannote's output was inaccurate, incorrectly assigning most sentences to a single speaker.
40 | 3. **WhoSpeaks Analysis**:
41 | - **Sentence Splitting**: Use `split_dataset.py` with `tiny.en` for efficiency, though `large-v2` offers higher accuracy.
42 | - **Conversion**: The MP3 segments are converted to WAV format using `convert_wav.py`.
43 | - **Diarization**: Then run `auto_diarize.py` and visually inspect the dendrogram file to confirm the presence of two speakers.
44 |
45 | To run `auto_diarize.py` and `speaker_diarize.py`, the environment variable COQUI_MODEL_PATH must be set to the path containing the "v2.0.2" model folder of Coqui XTTS.
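
For example, export it in your shell before running the scripts (e.g. `export COQUI_MODEL_PATH=/path/to/xtts/models`, the path being a placeholder); the scripts then resolve the model roughly like this:

```python
import os

# How auto_diarize.py resolves the model location (it falls back to a local "models" folder when unset):
local_models_path = os.environ.get("COQUI_MODEL_PATH", "models")
config_path = os.path.join(local_models_path, "v2.0.2", "config.json")  # the checkpoint files sit next to config.json
```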
46 |
47 |
48 | ### Results:
49 |
50 | - WhoSpeaks' algorithm assigned 53 sentences correctly to Javier Bardem's voice with only 2 minor errors.
51 | - Of the 33 sentences assigned to the other actor, only one was incorrect.
52 | - The overall error rate was approximately 3.5%, i.e. roughly 96.5% of sentences were assigned to the correct speaker.
53 |
54 | The effectiveness of WhoSpeaks in this test, particularly against pyannote audio, showcases its potential in handling complex diarization scenarios with high accuracy and efficiency.
55 |
--------------------------------------------------------------------------------
/auto_diarize.py:
--------------------------------------------------------------------------------
1 | print("Auto Speaker Diarization with Hierarchical Clustering")
2 |
3 | from sklearn.cluster import AgglomerativeClustering, KMeans
4 | from sklearn.preprocessing import StandardScaler
5 | from sklearn.metrics import silhouette_score
6 | import scipy.cluster.hierarchy as sch
7 | from TTS.tts.models import setup_model as setup_tts_model
8 | from TTS.config import load_config
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import librosa
12 | import shutil
13 | import torch
14 | import os
15 |
16 | # Setup
17 | input_directory = 'output_sentences_wav'
18 | output_directory = 'output_speakers'
19 | max_sentences = 1000
20 | minimum_duration = 0.5
21 | two_speaker_threshold = 19
22 | silhouette_diff_threshold = 0.01 # Adjust as needed for your data
23 | data = []
24 |
25 | print("Loading TTS model")
26 | device = torch.device("cuda")
27 | local_models_path = os.environ.get("COQUI_MODEL_PATH")
28 | if local_models_path is None:
29 | local_models_path = "models"
30 | checkpoint = os.path.join(local_models_path, "v2.0.2")
31 | config = load_config((os.path.join(checkpoint, "config.json")))
32 | tts = setup_tts_model(config)
33 | tts.load_checkpoint(
34 | config,
35 | checkpoint_dir=checkpoint,
36 | checkpoint_path=None,
37 | vocab_path=None,
38 | eval=True,
39 | use_deepspeed=False,
40 | )
41 | tts.to(device)
42 | print("TTS model loaded")
43 |
44 |
45 | def get_speaker_embedding(audio_path):
46 | _, speaker_embedding = tts.get_conditioning_latents(audio_path=audio_path, gpt_cond_len=30, max_ref_length=60)
47 | return speaker_embedding
48 |
49 |
50 | # Create 1D embeddings from sentences
51 | count = 0
52 | embeddings = []
53 | for filename in os.listdir(input_directory):
54 | if filename.endswith(".wav") and count < max_sentences:
55 | y, sr = librosa.load(os.path.join(input_directory, filename))
56 | if librosa.get_duration(y=y, sr=sr) >= minimum_duration:
57 | full_path = os.path.join(input_directory, filename)
58 | speaker_embedding = get_speaker_embedding(full_path)
59 | speaker_embedding_1D = speaker_embedding.view(-1).cpu().detach().numpy()
60 | embeddings.append(speaker_embedding_1D)
61 | data.append({'filename': filename, 'speaker_embeds_1D': speaker_embedding_1D})
62 | count += 1
63 |
64 |
65 | # Standard scaling
66 | embeddings_array = np.array(embeddings)
67 | scaler = StandardScaler()
68 | embeddings_scaled = scaler.fit_transform(embeddings_array)
69 |
70 | # Hierarchical Clustering
71 | linked = sch.linkage(embeddings_scaled, method='ward')
72 |
73 | # Safety check using KMeans for initial speaker detection
74 | def determine_optimal_cluster_count(embeddings_scaled):
75 | num_embeddings = len(embeddings_scaled)
76 | if num_embeddings <= 1:
77 | # Only one embedding, so only one speaker
78 | return 1
79 | else:
80 | # Determine single or multiple speakers
81 | # K-means Clustering with k=2
82 | kmeans = KMeans(n_clusters=2, random_state=0).fit(embeddings_scaled)
83 | distances = kmeans.transform(embeddings_scaled)
84 | avg_distance = np.mean(np.min(distances, axis=1))
85 | distance_threshold = two_speaker_threshold # Threshold to decide if we have one or multiple speakers
86 |
87 | if avg_distance < distance_threshold:
88 | print(f"Single Speaker: low embedding distance: {avg_distance} < {distance_threshold}.")
89 | return 1
90 | else:
91 | # Hierarchical Clustering for multiple speakers
92 | max_clusters = min(10, num_embeddings)
93 | range_clusters = range(2, max_clusters + 1)
94 | silhouette_scores = []
95 |
96 | for n_clusters in range_clusters:
97 | hc = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
98 | cluster_labels = hc.fit_predict(embeddings_scaled)
99 |
100 | unique_labels = set(cluster_labels)
101 | if 1 < len(unique_labels) < len(embeddings_scaled):
102 | silhouette_avg = silhouette_score(embeddings_scaled, cluster_labels)
103 | silhouette_scores.append(silhouette_avg)
104 | else:
105 | print(f"Inappropriate number of clusters: {len(unique_labels)}.")
106 | silhouette_scores.append(-1)
107 |
108 | # Find the optimal number of clusters based on silhouette scores
109 | optimal_cluster_count = 2
110 | for i in range(1, len(silhouette_scores)):
111 | # Ensure a significant increase in the silhouette score to add a new cluster
112 | if silhouette_scores[i] - silhouette_scores[i - 1] > silhouette_diff_threshold:
113 | optimal_cluster_count = range_clusters[i]
114 | # else:
115 | # print(f"Silhouette score difference too low: {silhouette_scores[i] - silhouette_scores[i - 1]}.")
116 |
117 | # optimal_cluster_count = range_clusters[silhouette_scores.index(max(silhouette_scores))]
118 |
119 | return optimal_cluster_count
120 |
121 | # Determine the optimal number of clusters
122 | optimal_cluster_count = determine_optimal_cluster_count(embeddings_scaled)
123 |
124 | # Plotting the dendrogram
125 | plt.figure(figsize=(10, 7))
126 | dendrogram = sch.dendrogram(linked, orientation='top', distance_sort='descending', show_leaf_counts=True)
127 | plt.title('Dendrogram')
128 |
129 | # Save or show dendrogram
130 | dendrogram_file = 'dendrogram.png'
131 | plt.savefig(dendrogram_file)
132 | print()
133 | print(f"The dendrogram image showing the detected speaker clusters was saved as {dendrogram_file}.")
134 | print()
135 |
136 | # Explanation for the user
137 | print(f"The automatical speaker detection suggested there were {optimal_cluster_count} speakers.")
138 | print(f"Please verify this by inspecting the dendrogram plot that has been saved as {dendrogram_file}.")
139 | print("You should look for the longest vertical lines that are not crossed by any horizontal lines.")
140 | print("These lines suggest a natural separation between different clusters.")
141 | print("A horizontal 'cut' through these long lines will determine the number of clusters.")
142 | print("Count the number of vertical lines intersected by an imaginary horizontal line to decide the cluster count.")
143 | print("This number will be the number of speakers you should input.")
144 | print()
145 | print(f"Automatical speaker count suggestion: {optimal_cluster_count} speakers.")
146 | print()
147 | print("If you have identified a different number of speakers from the dendragram file, please enter the number.")
148 | print(f"If you are satisfied with the automatic suggestion of {optimal_cluster_count} speakers, you can press Enter to proceed.")
149 | print()
150 |
151 | # Ask the user for the number of clusters with a retry mechanism if the input fails
152 | while True:
153 | try:
154 | input_user = input("Please enter the number of speakers (clusters) you have identified: ")
155 | if input_user == "":
156 | cluster_count = optimal_cluster_count
157 | break
158 | cluster_count = int(input_user)
159 | if cluster_count > 0:
160 | break
161 | else:
162 | print("The number of clusters must be a positive integer. Please try again.")
163 | except ValueError:
164 | print("Invalid input; please enter an integer value. Try again.")
165 |
166 |
167 | # Determine clusters from dendrogram
168 | hc = AgglomerativeClustering(n_clusters=cluster_count, linkage='ward')
169 | clusters = hc.fit_predict(embeddings_scaled)
170 |
171 | # Assign sentences to clusters
172 | for i, entry in enumerate(data):
173 | entry['assigned_cluster'] = clusters[i]
174 |
175 | # Copy files to corresponding directories
176 | for cluster_id in range(cluster_count):
177 | speaker_dir = os.path.join(output_directory, f"speaker_{cluster_id}")
178 | os.makedirs(speaker_dir, exist_ok=True)
179 |
180 | for entry in data:
181 | if entry['assigned_cluster'] == cluster_id:
182 | source_path = os.path.join(input_directory, entry['filename'])
183 | destination_path = os.path.join(speaker_dir, entry['filename'])
184 | shutil.copy(source_path, destination_path)
185 |
186 | print("Speaker diarization completed with hierarchical clustering.")
187 |
--------------------------------------------------------------------------------
/cleaner.py:
--------------------------------------------------------------------------------
1 | """
2 |
3 | Heavily borrowed from:
4 | - https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/tokenizer.py
5 | - https://github.com/daswer123/xtts-webui/blob/main/scripts/utils/tokenizer.py
6 |
7 | """
8 |
9 | import re
10 |
11 | from num_to_words import TextNorm as zh_num2words
12 | from num2words import num2words
13 |
14 | _whitespace_re = re.compile(r"\s+")
15 | _number_re = re.compile(r"[0-9]+")
16 | _comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
17 | _dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")  # dot-separated thousands, comma as decimal separator
18 | _decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
19 | _currency_re = {
20 | "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
21 | "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
22 | "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
23 | }
24 | _ordinal_re = {
25 | "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
26 | "es": re.compile(r"([0-9]+)(º|ª|er|o|a|os|as)"),
27 | "fr": re.compile(r"([0-9]+)(º|ª|er|re|e|ème)"),
28 | "de": re.compile(r"([0-9]+)(st|nd|rd|th|º|ª|\.(?=\s|$))"),
29 | "pt": re.compile(r"([0-9]+)(º|ª|o|a|os|as)"),
30 | "it": re.compile(r"([0-9]+)(º|°|ª|o|a|i|e)"),
31 | "pl": re.compile(r"([0-9]+)(º|ª|st|nd|rd|th)"),
32 | "ar": re.compile(r"([0-9]+)(ون|ين|ث|ر|ى)"),
33 | "cs": re.compile(r"([0-9]+)\.(?=\s|$)"), # In Czech, a dot is often used after the number to indicate ordinals.
34 | "ru": re.compile(r"([0-9]+)(-й|-я|-е|-ое|-ье|-го)"),
35 | "nl": re.compile(r"([0-9]+)(de|ste|e)"),
36 | "tr": re.compile(r"([0-9]+)(\.|inci|nci|uncu|üncü|\.)"),
37 | "hu": re.compile(r"([0-9]+)(\.|adik|edik|odik|edik|ödik|ödike|ik)"),
38 | "ko": re.compile(r"([0-9]+)(번째|번|차|째)"),
39 | "ja": re.compile(r"([0-9]+)(番|回|つ|目|等|位)")
40 | }
41 | _abbreviations = {
42 | "en": [
43 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
44 | for x in [
45 | ("mrs", "misess"),
46 | ("mr", "mister"),
47 | ("dr", "doctor"),
48 | ("st", "saint"),
49 | ("co", "company"),
50 | ("jr", "junior"),
51 | ("maj", "major"),
52 | ("gen", "general"),
53 | ("drs", "doctors"),
54 | ("rev", "reverend"),
55 | ("lt", "lieutenant"),
56 | ("hon", "honorable"),
57 | ("sgt", "sergeant"),
58 | ("capt", "captain"),
59 | ("esq", "esquire"),
60 | ("ltd", "limited"),
61 | ("col", "colonel"),
62 | ("ft", "fort"),
63 | ]
64 | ],
65 | "es": [
66 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
67 | for x in [
68 | ("sra", "señora"),
69 | ("sr", "señor"),
70 | ("dr", "doctor"),
71 | ("dra", "doctora"),
72 | ("st", "santo"),
73 | ("co", "compañía"),
74 | ("jr", "junior"),
75 | ("ltd", "limitada"),
76 | ]
77 | ],
78 | "fr": [
79 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
80 | for x in [
81 | ("mme", "madame"),
82 | ("mr", "monsieur"),
83 | ("dr", "docteur"),
84 | ("st", "saint"),
85 | ("co", "compagnie"),
86 | ("jr", "junior"),
87 | ("ltd", "limitée"),
88 | ]
89 | ],
90 | "de": [
91 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
92 | for x in [
93 | ("fr", "frau"),
94 | ("dr", "doktor"),
95 | ("st", "sankt"),
96 | ("co", "firma"),
97 | ("jr", "junior"),
98 | ]
99 | ],
100 | "pt": [
101 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
102 | for x in [
103 | ("sra", "senhora"),
104 | ("sr", "senhor"),
105 | ("dr", "doutor"),
106 | ("dra", "doutora"),
107 | ("st", "santo"),
108 | ("co", "companhia"),
109 | ("jr", "júnior"),
110 | ("ltd", "limitada"),
111 | ]
112 | ],
113 | "it": [
114 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
115 | for x in [
116 | # ("sig.ra", "signora"),
117 | ("sig", "signore"),
118 | ("dr", "dottore"),
119 | ("st", "santo"),
120 | ("co", "compagnia"),
121 | ("jr", "junior"),
122 | ("ltd", "limitata"),
123 | ]
124 | ],
125 | "pl": [
126 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
127 | for x in [
128 | ("p", "pani"),
129 | ("m", "pan"),
130 | ("dr", "doktor"),
131 | ("sw", "święty"),
132 | ("jr", "junior"),
133 | ]
134 | ],
135 | "ar": [
136 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
137 | for x in [
138 | # There are not many common abbreviations in Arabic as in English.
139 | ]
140 | ],
141 | "zh": [
142 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
143 | for x in [
144 | # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts.
145 | ]
146 | ],
147 | "cs": [
148 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
149 | for x in [
150 | ("dr", "doktor"), # doctor
151 | ("ing", "inženýr"), # engineer
152 | ("p", "pan"), # Could also map to pani for woman but no easy way to do it
153 | # Other abbreviations would be specialized and not as common.
154 | ]
155 | ],
156 | "ru": [
157 | (re.compile("\\b%s\\b" % x[0], re.IGNORECASE), x[1])
158 | for x in [
159 | ("г-жа", "госпожа"), # Mrs.
160 | ("г-н", "господин"), # Mr.
161 | ("д-р", "доктор"), # doctor
162 | # Other abbreviations are less common or specialized.
163 | ]
164 | ],
165 | "nl": [
166 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
167 | for x in [
168 | ("dhr", "de heer"), # Mr.
169 | ("mevr", "mevrouw"), # Mrs.
170 | ("dr", "dokter"), # doctor
171 | ("jhr", "jonkheer"), # young lord or nobleman
172 | # Dutch uses more abbreviations, but these are the most common ones.
173 | ]
174 | ],
175 | "tr": [
176 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
177 | for x in [
178 | ("b", "bay"), # Mr.
179 | ("byk", "büyük"), # büyük
180 | ("dr", "doktor"), # doctor
181 | # Add other Turkish abbreviations here if needed.
182 | ]
183 | ],
184 | "hu": [
185 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
186 | for x in [
187 | ("dr", "doktor"), # doctor
188 | ("b", "bácsi"), # Mr.
189 | ("nőv", "nővér"), # nurse
190 | # Add other Hungarian abbreviations here if needed.
191 | ]
192 | ],
193 | "ko": [
194 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
195 | for x in [
196 | # Korean doesn't typically use abbreviations in the same way as Latin-based scripts.
197 | ]
198 | ],
199 | "ja": [
200 | (re.compile("\\b%s\\b" % x[0]), x[1])
201 | for x in [
202 | ("氏", "さん"), # Mr.
203 | ("夫人", "おんなのひと"), # Mrs.
204 | ("博士", "はかせ"), # Doctor or PhD
205 | ("株", "株式会社"), # Corporation
206 | ("有", "有限会社"), # Limited company
207 | ("大学", "だいがく"), # University
208 | ("先生", "せんせい"), # Teacher/Professor/Master
209 | ("君", "くん") # Used at the end of boys' names to express familiarity or affection.
210 | ]
211 | ],
212 | }
213 |
214 | _symbols_multilingual = {
215 | "en": [
216 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
217 | for x in [
218 | ("&", " and "),
219 | ("@", " at "),
220 | ("%", " percent "),
221 | ("#", " hash "),
222 | ("$", " dollar "),
223 | ("£", " pound "),
224 | ("°", " degree "),
225 | ]
226 | ],
227 | "es": [
228 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
229 | for x in [
230 | ("&", " y "),
231 | ("@", " arroba "),
232 | ("%", " por ciento "),
233 | ("#", " numeral "),
234 | ("$", " dolar "),
235 | ("£", " libra "),
236 | ("°", " grados "),
237 | ]
238 | ],
239 | "fr": [
240 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
241 | for x in [
242 | ("&", " et "),
243 | ("@", " arobase "),
244 | ("%", " pour cent "),
245 | ("#", " dièse "),
246 | ("$", " dollar "),
247 | ("£", " livre "),
248 | ("°", " degrés "),
249 | ]
250 | ],
251 | "de": [
252 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
253 | for x in [
254 | ("&", " und "),
255 | ("@", " at "),
256 | ("%", " prozent "),
257 | ("#", " raute "),
258 | ("$", " dollar "),
259 | ("£", " pfund "),
260 | ("°", " grad "),
261 | ]
262 | ],
263 | "pt": [
264 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
265 | for x in [
266 | ("&", " e "),
267 | ("@", " arroba "),
268 | ("%", " por cento "),
269 | ("#", " cardinal "),
270 | ("$", " dólar "),
271 | ("£", " libra "),
272 | ("°", " graus "),
273 | ]
274 | ],
275 | "it": [
276 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
277 | for x in [
278 | ("&", " e "),
279 | ("@", " chiocciola "),
280 | ("%", " per cento "),
281 | ("#", " cancelletto "),
282 | ("$", " dollaro "),
283 | ("£", " sterlina "),
284 | ("°", " gradi "),
285 | ]
286 | ],
287 | "pl": [
288 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
289 | for x in [
290 | ("&", " i "),
291 | ("@", " małpa "),
292 | ("%", " procent "),
293 | ("#", " krzyżyk "),
294 | ("$", " dolar "),
295 | ("£", " funt "),
296 | ("°", " stopnie "),
297 | ]
298 | ],
299 | "ar": [
300 | # Arabic
301 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
302 | for x in [
303 | ("&", " و "),
304 | ("@", " على "),
305 | ("%", " في المئة "),
306 | ("#", " رقم "),
307 | ("$", " دولار "),
308 | ("£", " جنيه "),
309 | ("°", " درجة "),
310 | ]
311 | ],
312 | "zh": [
313 | # Chinese
314 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
315 | for x in [
316 | ("&", " 和 "),
317 | ("@", " 在 "),
318 | ("%", " 百分之 "),
319 | ("#", " 号 "),
320 | ("$", " 美元 "),
321 | ("£", " 英镑 "),
322 | ("°", " 度 "),
323 | ]
324 | ],
325 | "cs": [
326 | # Czech
327 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
328 | for x in [
329 | ("&", " a "),
330 | ("@", " na "),
331 | ("%", " procento "),
332 | ("#", " křížek "),
333 | ("$", " dolar "),
334 | ("£", " libra "),
335 | ("°", " stupně "),
336 | ]
337 | ],
338 | "ru": [
339 | # Russian
340 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
341 | for x in [
342 | ("&", " и "),
343 | ("@", " собака "),
344 | ("%", " процентов "),
345 | ("#", " номер "),
346 | ("$", " доллар "),
347 | ("£", " фунт "),
348 | ("°", " градус "),
349 | ]
350 | ],
351 | "nl": [
352 | # Dutch
353 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
354 | for x in [
355 | ("&", " en "),
356 | ("@", " bij "),
357 | ("%", " procent "),
358 | ("#", " hekje "),
359 | ("$", " dollar "),
360 | ("£", " pond "),
361 | ("°", " graden "),
362 | ]
363 | ],
364 | "tr": [
365 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
366 | for x in [
367 | ("&", " ve "),
368 | ("@", " at "),
369 | ("%", " yüzde "),
370 | ("#", " diyez "),
371 | ("$", " dolar "),
372 | ("£", " sterlin "),
373 | ("°", " derece "),
374 | ]
375 | ],
376 | "hu": [
377 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
378 | for x in [
379 | ("&", " és "),
380 | ("@", " kukac "),
381 | ("%", " százalék "),
382 | ("#", " kettőskereszt "),
383 | ("$", " dollár "),
384 | ("£", " font "),
385 | ("°", " fok "),
386 | ]
387 | ],
388 | "ko": [
389 | # Korean
390 | (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
391 | for x in [
392 | ("&", " 그리고 "),
393 | ("@", " 에 "),
394 | ("%", " 퍼센트 "),
395 | ("#", " 번호 "),
396 | ("$", " 달러 "),
397 | ("£", " 파운드 "),
398 | ("°", " 도 "),
399 | ]
400 | ],
401 | "ja": [
402 | (re.compile(r"%s" % re.escape(x[0])), x[1])
403 | for x in [
404 | ("&", " と "),
405 | ("@", " アットマーク "),
406 | ("%", " パーセント "),
407 | ("#", " ナンバー "),
408 | ("$", " ドル "),
409 | ("£", " ポンド "),
410 | ("°", " 度"),
411 | ]
412 | ],
413 | }
414 |
415 | def _expand_currency(m, lang="en", currency="USD"):
416 | amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
417 | full_amount = num2words(amount, to="currency", currency=currency, lang=lang if lang != "cs" else "cz")
418 |
419 | and_equivalents = {
420 | "en": ", ",
421 | "es": " con ",
422 | "fr": " et ",
423 | "de": " und ",
424 | "pt": " e ",
425 | "it": " e ",
426 | "pl": ", ",
427 | "cs": ", ",
428 | "ru": ", ",
429 | "nl": ", ",
430 | "ar": ", ",
431 | "tr": ", ",
432 | "hu": ", ",
433 | "ko": ", ",
434 | }
435 |
436 | if amount.is_integer():
437 | last_and = full_amount.rfind(and_equivalents[lang])
438 | if last_and != -1:
439 | full_amount = full_amount[:last_and]
440 |
441 | return full_amount
442 |
443 | def _remove_commas(m):
444 | text = m.group(0)
445 | if "," in text:
446 | text = text.replace(",", "")
447 | return text
448 |
449 | def _remove_dots(m):
450 | text = m.group(0)
451 | if "." in text:
452 | text = text.replace(".", "")
453 | return text
454 |
455 | def _expand_decimal_point(m, lang="en"):
456 | amount = m.group(1).replace(",", ".")
457 | return num2words(float(amount), lang=lang if lang != "cs" else "cz")
458 |
459 | def _expand_number(m, lang="en"):
460 | return num2words(int(m.group(0)), lang=lang if lang != "cs" else "cz")
461 |
462 | def expand_numbers_multilingual(text, lang="en"):
463 | if lang == "zh":
464 | text = zh_num2words()(text)
465 | else:
466 | if lang in ["en", "ru"]:
467 | text = re.sub(_comma_number_re, _remove_commas, text)
468 | else:
469 | text = re.sub(_dot_number_re, _remove_dots, text)
470 | try:
471 | text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
472 | text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
473 | text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
474 | except Exception:  # currency expansion can fail for unsupported language/format combinations
475 | pass
476 | if lang != "tr":
477 | text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
478 | text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
479 | text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
480 | return text
481 |
482 | def _expand_ordinal(m, lang="en"):
483 | return num2words(int(m.group(1)), ordinal=True, lang=lang if lang != "cs" else "cz")
484 |
485 | def expand_abbreviations_multilingual(text, lang="en"):
486 | for regex, replacement in _abbreviations[lang]:
487 | text = re.sub(regex, replacement, text)
488 | return text
489 |
490 | def expand_symbols_multilingual(text, lang="en"):
491 | for regex, replacement in _symbols_multilingual[lang]:
492 | text = re.sub(regex, replacement, text)
493 | text = text.replace(" ", " ") # Ensure there are no double spaces
494 | return text.strip()
495 |
496 | def collapse_whitespace(text):
497 | return re.sub(_whitespace_re, " ", text)
498 |
499 | def multilingual_cleaners(text, lang):
500 | text = text.replace('"', "")
501 | if lang == "tr":
502 | text = text.replace("İ", "i")
503 | text = text.replace("Ö", "ö")
504 | text = text.replace("Ü", "ü")
505 | text = text.lower()
506 | text = expand_numbers_multilingual(text, lang)
507 | text = expand_abbreviations_multilingual(text, lang)
508 | text = expand_symbols_multilingual(text, lang=lang)
509 | text = collapse_whitespace(text)
510 | return text
--------------------------------------------------------------------------------
/convert_wav.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ffmpeg
3 |
4 | def convert_mp3_to_wav(source_dir, target_dir, sample_rate=24000):
5 | """
6 | Converts all MP3 files in the source directory to WAV files in the target directory
7 | with a sample rate of 24000 Hz and mono channel.
8 | """
9 | # Create the target directory if it does not exist
10 | if not os.path.exists(target_dir):
11 | os.makedirs(target_dir)
12 |
13 | # Process each file in the source directory
14 | for filename in os.listdir(source_dir):
15 | if filename.endswith('.mp3'):
16 | source_file = os.path.join(source_dir, filename)
17 | target_file = os.path.join(target_dir, os.path.splitext(filename)[0] + '.wav')
18 |
19 | try:
20 | # Convert MP3 to WAV using ffmpeg with specified sample rate and mono channel
21 | ffmpeg.input(source_file).output(target_file, ar=sample_rate, ac=1).run()
22 | print(f'Converted {filename} to WAV ({sample_rate} Hz, mono)')
23 | except ffmpeg.Error as e:
24 | print(f'Error converting {filename}: {e}')
25 |
26 | # Example usage
27 | source_dir = 'output_sentences' # Replace with your source directory path
28 | target_dir = 'output_sentences_wav' # Replace with your target directory path
29 |
30 | convert_mp3_to_wav(source_dir, target_dir)
31 |
32 | print('All conversions complete.')
33 |
--------------------------------------------------------------------------------
/extract_speaker_sentences.py:
--------------------------------------------------------------------------------
1 | """
2 | Speaker Diarization
3 |
4 | Idea:
5 | - create 1D embeddings specs from sentences
6 | - for every sentence
7 | - find most similar 10% other sentences
8 | - average out the 1Ds and make a "speech group" embedding from that
9 | - for every sentence
10 | - compare speech group embedding with all other sentence speech group embeddings
11 | - find the two speech groups with least similar embeddings
12 | - the 2 "speech group" embedding from that will be our "speaker" characteristics 1D embeddings
13 | - for every sentence
14 | - find cosine similarity between the sentence and the two "speaker" characteristics 1D embeddings
15 | - assign to the speaker with higher similarity
16 |
17 | => every sentence is assigned to one of the two speakers
18 |
19 | notes:
20 | - cut out every < 3s file before processing
21 |
22 | """
23 |
24 | from TTS.tts.models import setup_model as setup_tts_model
25 | from scipy.spatial.distance import cosine
26 | from TTS.config import load_config
27 | import librosa.display
28 | import librosa
29 | import numpy as np
30 | import shutil
31 | import torch
32 | import os
33 |
34 | input_directory = 'output_sentences_wav'
35 | output_directory = 'output_speakers'
36 | speaker_reference_file = '0035093-0035300_25_so_it_s_easier_.wav' # no path
37 | max_sentences = 1000000
38 | group_percentage = 0.1
39 | minimum_duration = 1
40 | only_keep_most_confident_percentage = 0.8
41 |
42 | data = []
43 |
44 | device = torch.device("cuda")
45 | local_models_path = os.environ.get("COQUI_MODEL_PATH", "models")  # fall back to a local "models" folder, as in auto_diarize.py
46 | checkpoint = os.path.join(local_models_path, "v2.0.2")
47 | config = load_config((os.path.join(checkpoint, "config.json")))
48 | tts = setup_tts_model(config)
49 | tts.load_checkpoint(
50 | config,
51 | checkpoint_dir=checkpoint,
52 | checkpoint_path=None,
53 | vocab_path=None,
54 | eval=True,
55 | use_deepspeed=False,
56 | )
57 | tts.to(device)
58 | print("TTS model loaded")
59 |
60 | # create 1D embeddings from sentences
61 | count = 0
62 | speaker_embeddings = None
63 | for filename in os.listdir(input_directory):
64 | if filename.endswith(".wav"):
65 | count += 1
66 | if count > max_sentences:
67 | break
68 |
69 | # skip if file is too short
70 | y, sr = librosa.load(os.path.join(input_directory, filename))
71 | if librosa.get_duration(y=y, sr=sr) < minimum_duration:
72 | continue
73 |
74 | full_path = os.path.join(input_directory, filename)
75 | print(full_path)
76 |
77 | gpt_cond_latent, speaker_embedding = tts.get_conditioning_latents(audio_path=full_path, gpt_cond_len=30, max_ref_length=60)
78 | speaker_embedding_list = speaker_embedding.cpu().squeeze().half().tolist()  # list form of the embedding (not used further below)
79 | speaker_embedding_1D = speaker_embedding.view(-1).cpu().detach().numpy() # Reshape to 1D then convert to NumPy
80 |
81 | if speaker_reference_file == filename:
82 | print(f"Speaker reference found: {filename}")
83 | speaker_embeddings = speaker_embedding_1D
84 |
85 | entry = {
86 | 'filename': filename,
87 | 'speaker_embeds_1D': speaker_embedding_1D
88 | }
89 | data.append(entry)
90 | else:
91 | continue
92 |
93 | if speaker_embeddings is None:
94 | raise Exception("Speaker reference not found")
95 |
96 | # Check similarity of each sentence to the speaker reference
97 | for index, entry in enumerate(data):
98 | embedding = entry['speaker_embeds_1D']
99 | similarity = 1 - cosine(embedding, speaker_embeddings)
100 | entry['confidence'] = similarity
101 |
102 | # Sort the data by confidence
103 | data.sort(key=lambda x: x['confidence'], reverse=True)
104 |
105 | # Create subdirectories for each percentile
106 | percentile_directories = []
107 | for i in range(10):
108 | dir_name = os.path.join(output_directory, f'percentile_{i * 10}-{(i + 1) * 10}')
109 | os.makedirs(dir_name, exist_ok=True)
110 | percentile_directories.append(dir_name)
111 |
112 |
113 | # Assign each file to its percentile directory
114 | total_files = len(data)
115 | for index, entry in enumerate(data):
116 | percentile_index = (index * 10) // total_files # Find the correct percentile
117 | destination_dir = percentile_directories[percentile_index]
118 | base_name, extension = os.path.splitext(entry['filename'])
119 | new_filename = f"{base_name}_conf_{entry['confidence']:.2f}{extension}" # Append confidence to filename
120 | source_path = os.path.join(input_directory, entry['filename'])
121 | destination_path = os.path.join(destination_dir, new_filename)
122 |
123 | # Copy the file to the percentile directory with new filename
124 | shutil.copy(source_path, destination_path)
125 | print(f"Copied {entry['filename']} to {destination_path}")
126 |
--------------------------------------------------------------------------------
/fetch_youtube_mp3.py:
--------------------------------------------------------------------------------
1 | from yt_dlp import YoutubeDL
2 | from os.path import exists, join, splitext
3 | import os
4 |
5 | urls = [
6 | "https://www.youtube.com/watch?v=ZY0DG8rUnCA",
7 | #"https://www.youtube.com/watch?v=0tvebuNmp-I"
8 | # "https://www.youtube.com/watch?v=GGoCBAo9N_g"
9 | # "https://www.youtube.com/watch?v=JN3KPFbWCy8", # Elon Musk / Lex Fridman Round 4
10 | # "https://www.youtube.com/watch?v=DxREm3s1scA", # Elon Musk / Lex Fridman Round 3
11 | # "https://www.youtube.com/watch?v=smK9dgdTl40", # Elon Musk / Lex Fridman Round 2
12 | # "https://www.youtube.com/watch?v=dEv99vxKjVI", # Elon Musk / Lex Fridman Round 1
13 | ]
14 |
15 |
16 | directory = "input"
17 |
18 |
19 | def fetch_youtube(
20 | url: str,
21 | filetype: str,
22 | directory: str = "downloaded_files"
23 | ):
24 |
25 | """
26 | Downloads a specific type of file (video, audio, or muted video)
27 | from the provided YouTube URL.
28 |
29 | Args:
30 | url (str): The URL of the YouTube video to be downloaded.
31 | filetype (str): Type of file to download - 'video', 'audio',
32 | 'mp3_audio', or 'muted_video'.
33 | directory (str): The directory to download the file to.
34 |
35 | Returns:
36 | str: The filename of the downloaded file.
37 | """
38 | if directory and not exists(directory):
39 | os.makedirs(directory)
40 |
41 | if filetype == 'video':
42 | # Download video with audio
43 | outtmpl = join(directory, '%(title)s.%(ext)s')
44 | ydl_opts = {
45 | 'format': 'best',
46 | 'outtmpl': outtmpl,
47 | 'noplaylist': True,
48 | }
49 | elif filetype == 'audio':
50 | # Download audio only
51 | outtmpl = join(directory, '%(title)s_audio.%(ext)s')
52 | ydl_opts = {
53 | 'format': 'bestaudio/best',
54 | 'outtmpl': outtmpl,
55 | 'noplaylist': True,
56 | }
57 | elif filetype == 'mp3_audio':
58 | # Download audio as MP3
59 | outtmpl = join(directory, '%(title)s.%(ext)s')
60 | ydl_opts = {
61 | 'format': 'bestaudio/best',
62 | 'outtmpl': outtmpl,
63 | 'postprocessors': [{
64 | 'key': 'FFmpegExtractAudio',
65 | 'preferredcodec': 'mp3',
66 | 'preferredquality': '192',
67 | }],
68 | 'noplaylist': True,
69 | }
70 | elif filetype == 'muted_video':
71 | # Download video without audio
72 | outtmpl = join(directory, '%(title)s_mutedvideo.%(ext)s')
73 | ydl_opts = {
74 | 'format': 'bestvideo',
75 | 'outtmpl': outtmpl,
76 | 'noplaylist': True,
77 | }
78 | else:
79 | raise ValueError(
80 | "Invalid filetype. Choose 'video', 'audio', or 'muted_video'."
81 | )
82 |
83 | with YoutubeDL(ydl_opts) as ydl:
84 | info = ydl.extract_info(url, download=True)
85 | downloaded_file = ydl.prepare_filename(info)
86 |
87 | return downloaded_file
88 |
89 | for url in urls:
90 | audio_file = fetch_youtube(url, 'mp3_audio', directory)
91 |
--------------------------------------------------------------------------------
/num_to_words.py:
--------------------------------------------------------------------------------
1 | # Authors:
2 | # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
3 | # 2019.9 - 2022 Jiayu DU
4 |
5 | import argparse
6 | import csv
7 | import os
8 | import re
9 | import string
10 | import sys
11 |
12 | # fmt: off
13 |
14 | # ================================================================================ #
15 | # basic constant
16 | # ================================================================================ #
17 | CHINESE_DIGIS = "零一二三四五六七八九"
18 | BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
19 | BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
20 | SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
21 | SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
22 | LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
23 | LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
24 | SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
25 | SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
26 |
27 | ZERO_ALT = "〇"
28 | ONE_ALT = "幺"
29 | TWO_ALTS = ["两", "兩"]
30 |
31 | POSITIVE = ["正", "正"]
32 | NEGATIVE = ["负", "負"]
33 | POINT = ["点", "點"]
34 | # PLUS = [u'加', u'加']
35 | # SIL = [u'杠', u'槓']
36 |
37 | FILLER_CHARS = ["呃", "啊"]
38 |
39 | ER_WHITELIST = (
40 | "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
41 | "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
42 | "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
43 | "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
44 | )
45 | ER_WHITELIST_PATTERN = re.compile(ER_WHITELIST)
46 |
47 | # 中文数字系统类型
48 | NUMBERING_TYPES = ["low", "mid", "high"]
49 |
50 | CURRENCY_NAMES = "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|" "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
51 | CURRENCY_UNITS = "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
52 | COM_QUANTIFIERS = (
53 | "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
54 | "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
55 | "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
56 | "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
57 | "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
58 | "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
59 | )
60 |
61 |
62 | # Punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
63 | CN_PUNCS_STOP = "!?。。"
64 | CN_PUNCS_NONSTOP = ""#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏·〈〉-"
65 | CN_PUNCS = CN_PUNCS_STOP + CN_PUNCS_NONSTOP
66 |
67 | PUNCS = CN_PUNCS + string.punctuation
68 | PUNCS_TRANSFORM = str.maketrans(PUNCS, "," * len(PUNCS), "") # replace puncs with English comma
69 |
70 |
71 | # https://zh.wikipedia.org/wiki/全行和半行
72 | # Fullwidth (全角) to halfwidth (半角) character mapping:
73 | # U+FF01..U+FF5E are the fullwidth forms of ASCII '!' .. '~', and U+3000 is the ideographic space.
74 | QJ2BJ = {chr(0xFF01 + i): chr(0x21 + i) for i in range(0x7E - 0x21 + 1)}
75 | QJ2BJ["\u3000"] = " "
76 | QJ2BJ_TRANSFORM = str.maketrans("".join(QJ2BJ.keys()), "".join(QJ2BJ.values()), "")
170 |
171 |
172 | # 2013 China National Standard: https://zh.wikipedia.org/wiki/通用规范汉字表, raw resources:
173 | # https://github.com/mozillazg/pinyin-data/blob/master/kMandarin_8105.txt with 8105 chinese chars in total
174 | CN_CHARS_COMMON = (
175 | "一丁七万丈三上下不与丏丐丑专且丕世丘丙业丛东丝丞丢两严丧个丫中丰串临丸丹为主丽举"
176 | "乂乃久么义之乌乍乎乏乐乒乓乔乖乘乙乜九乞也习乡书乩买乱乳乸乾了予争事二亍于亏云互"
177 | "亓五井亘亚些亟亡亢交亥亦产亨亩享京亭亮亲亳亵亶亸亹人亿什仁仂仃仄仅仆仇仉今介仍从"
178 | "仑仓仔仕他仗付仙仝仞仟仡代令以仨仪仫们仰仲仳仵件价任份仿企伈伉伊伋伍伎伏伐休众优"
179 | "伙会伛伞伟传伢伣伤伥伦伧伪伫伭伯估伲伴伶伸伺似伽伾佁佃但位低住佐佑体何佖佗佘余佚"
180 | "佛作佝佞佟你佣佤佥佩佬佯佰佳佴佶佸佺佻佼佽佾使侁侂侃侄侈侉例侍侏侑侔侗侘供依侠侣"
181 | "侥侦侧侨侩侪侬侮侯侴侵侹便促俄俅俊俍俎俏俐俑俗俘俙俚俜保俞俟信俣俦俨俩俪俫俭修俯"
182 | "俱俳俵俶俸俺俾倌倍倏倒倓倔倕倘候倚倜倞借倡倥倦倧倨倩倪倬倭倮倴债倻值倾偁偃假偈偌"
183 | "偎偏偓偕做停偡健偬偭偰偲偶偷偻偾偿傀傃傅傈傉傍傒傕傣傥傧储傩催傲傺傻僇僎像僔僖僚"
184 | "僦僧僬僭僮僰僳僵僻儆儇儋儒儡儦儳儴儿兀允元兄充兆先光克免兑兔兕兖党兜兢入全八公六"
185 | "兮兰共关兴兵其具典兹养兼兽冀冁内冈冉册再冏冒冔冕冗写军农冠冢冤冥冬冮冯冰冱冲决况"
186 | "冶冷冻冼冽净凄准凇凉凋凌减凑凓凘凛凝几凡凤凫凭凯凰凳凶凸凹出击凼函凿刀刁刃分切刈"
187 | "刊刍刎刑划刖列刘则刚创初删判刨利别刬刭刮到刳制刷券刹刺刻刽刿剀剁剂剃剅削剋剌前剐"
188 | "剑剔剕剖剜剞剟剡剥剧剩剪副割剽剿劁劂劄劈劐劓力劝办功加务劢劣动助努劫劬劭励劲劳劼"
189 | "劾势勃勇勉勋勍勐勒勔勖勘勚募勠勤勰勺勾勿匀包匆匈匍匏匐匕化北匙匜匝匠匡匣匦匪匮匹"
190 | "区医匼匾匿十千卅升午卉半华协卑卒卓单卖南博卜卞卟占卡卢卣卤卦卧卫卬卮卯印危即却卵"
191 | "卷卸卺卿厂厄厅历厉压厌厍厕厖厘厚厝原厢厣厥厦厨厩厮去厾县叁参叆叇又叉及友双反发叔"
192 | "叕取受变叙叚叛叟叠口古句另叨叩只叫召叭叮可台叱史右叵叶号司叹叻叼叽吁吃各吆合吉吊"
193 | "同名后吏吐向吒吓吕吖吗君吝吞吟吠吡吣否吧吨吩含听吭吮启吱吲吴吵吸吹吻吼吽吾呀呃呆"
194 | "呇呈告呋呐呒呓呔呕呖呗员呙呛呜呢呣呤呦周呱呲味呵呶呷呸呻呼命咀咂咄咆咇咉咋和咍咎"
195 | "咏咐咒咔咕咖咙咚咛咝咡咣咤咥咦咧咨咩咪咫咬咯咱咳咴咸咺咻咽咿哀品哂哃哄哆哇哈哉哌"
196 | "响哎哏哐哑哒哓哔哕哗哙哚哝哞哟哢哥哦哧哨哩哪哭哮哱哲哳哺哼哽哿唁唆唇唉唏唐唑唔唛"
197 | "唝唠唢唣唤唧唪唬售唯唰唱唳唵唷唼唾唿啁啃啄商啉啊啐啕啖啜啡啤啥啦啧啪啫啬啭啮啰啴"
198 | "啵啶啷啸啻啼啾喀喁喂喃善喆喇喈喉喊喋喏喑喔喘喙喜喝喟喤喧喱喳喵喷喹喻喽喾嗄嗅嗉嗌"
199 | "嗍嗐嗑嗒嗓嗔嗖嗜嗝嗞嗟嗡嗣嗤嗥嗦嗨嗪嗫嗬嗯嗲嗳嗵嗷嗽嗾嘀嘁嘈嘉嘌嘎嘏嘘嘚嘛嘞嘟嘡"
200 | "嘣嘤嘧嘬嘭嘱嘲嘴嘶嘹嘻嘿噀噂噇噌噍噎噔噗噘噙噜噢噤器噩噪噫噬噱噶噻噼嚄嚅嚆嚎嚏嚓"
201 | "嚚嚣嚭嚯嚷嚼囊囔囚四回囟因囡团囤囫园困囱围囵囷囹固国图囿圃圄圆圈圉圊圌圐圙圜土圢"
202 | "圣在圩圪圫圬圭圮圯地圲圳圹场圻圾址坂均坉坊坋坌坍坎坏坐坑坒块坚坛坜坝坞坟坠坡坤坥"
203 | "坦坨坩坪坫坬坭坯坰坳坷坻坼坽垂垃垄垆垈型垌垍垎垏垒垓垕垙垚垛垞垟垠垡垢垣垤垦垧垩"
204 | "垫垭垮垯垱垲垴垵垸垺垾垿埂埃埆埇埋埌城埏埒埔埕埗埘埙埚埝域埠埤埪埫埭埯埴埵埸培基"
205 | "埼埽堂堃堆堇堉堋堌堍堎堐堑堕堙堞堠堡堤堧堨堪堰堲堵堼堽堾塄塅塆塌塍塑塔塘塝塞塥填"
206 | "塬塱塾墀墁境墅墈墉墐墒墓墕墘墙墚增墟墡墣墦墨墩墼壁壅壑壕壤士壬壮声壳壶壸壹处备复"
207 | "夏夐夔夕外夙多夜够夤夥大天太夫夬夭央夯失头夷夸夹夺夼奁奂奄奇奈奉奋奎奏契奓奔奕奖"
208 | "套奘奚奠奡奢奥奭女奴奶奸她好妁如妃妄妆妇妈妊妍妒妓妖妗妘妙妞妣妤妥妧妨妩妪妫妭妮"
209 | "妯妲妹妻妾姆姈姊始姐姑姒姓委姗姘姚姜姝姞姣姤姥姨姬姮姱姶姹姻姽姿娀威娃娄娅娆娇娈"
210 | "娉娌娑娓娘娜娟娠娣娥娩娱娲娴娵娶娼婀婆婉婊婌婍婕婘婚婞婠婢婤婧婪婫婳婴婵婶婷婺婻"
211 | "婼婿媂媄媆媒媓媖媚媛媞媪媭媱媲媳媵媸媾嫁嫂嫄嫉嫌嫒嫔嫕嫖嫘嫚嫜嫠嫡嫣嫦嫩嫪嫫嫭嫱"
212 | "嫽嬉嬖嬗嬛嬥嬬嬴嬷嬿孀孅子孑孓孔孕孖字存孙孚孛孜孝孟孢季孤孥学孩孪孬孰孱孳孵孺孽"
213 | "宁它宄宅宇守安宋完宏宓宕宗官宙定宛宜宝实宠审客宣室宥宦宧宪宫宬宰害宴宵家宸容宽宾"
214 | "宿寁寂寄寅密寇富寐寒寓寝寞察寡寤寥寨寮寰寸对寺寻导寿封射将尉尊小少尔尕尖尘尚尜尝"
215 | "尢尤尥尧尨尪尬就尴尸尹尺尻尼尽尾尿局屁层屃居屈屉届屋屎屏屐屑展屙属屠屡屣履屦屯山"
216 | "屹屺屼屾屿岁岂岈岊岌岍岐岑岔岖岗岘岙岚岛岜岞岠岢岣岨岩岫岬岭岱岳岵岷岸岽岿峁峂峃"
217 | "峄峋峒峗峘峙峛峡峣峤峥峦峧峨峪峭峰峱峻峿崀崁崂崃崄崆崇崌崎崒崔崖崚崛崞崟崡崤崦崧"
218 | "崩崭崮崴崶崽崾崿嵁嵅嵇嵊嵋嵌嵎嵖嵘嵚嵛嵝嵩嵫嵬嵯嵲嵴嶂嶅嶍嶒嶓嶙嶝嶟嶦嶲嶷巅巇巉"
219 | "巍川州巡巢工左巧巨巩巫差巯己已巳巴巷巽巾币市布帅帆师希帏帐帑帔帕帖帘帙帚帛帜帝帡"
220 | "带帧帨席帮帱帷常帻帼帽幂幄幅幌幔幕幖幛幞幡幢幪干平年并幸幺幻幼幽广庄庆庇床庋序庐"
221 | "庑库应底庖店庙庚府庞废庠庤庥度座庭庱庳庵庶康庸庹庼庾廆廉廊廋廑廒廓廖廙廛廨廪延廷"
222 | "建廿开弁异弃弄弆弇弈弊弋式弑弓引弗弘弛弟张弢弥弦弧弨弩弭弯弱弶弸弹强弼彀归当录彖"
223 | "彗彘彝彟形彤彦彧彩彪彬彭彰影彳彷役彻彼往征徂径待徇很徉徊律徐徒徕得徘徙徛徜御徨循"
224 | "徭微徵德徼徽心必忆忉忌忍忏忐忑忒忖志忘忙忝忞忠忡忤忧忪快忭忮忱忳念忸忺忻忽忾忿怀"
225 | "态怂怃怄怅怆怊怍怎怏怒怔怕怖怙怛怜思怠怡急怦性怨怩怪怫怯怵总怼怿恁恂恃恋恍恐恒恓"
226 | "恔恕恙恚恝恢恣恤恧恨恩恪恫恬恭息恰恳恶恸恹恺恻恼恽恿悃悄悆悈悉悌悍悒悔悖悚悛悝悟"
227 | "悠悢患悦您悫悬悭悯悰悱悲悴悸悻悼情惆惇惊惋惎惑惔惕惘惙惚惛惜惝惟惠惦惧惨惩惫惬惭"
228 | "惮惯惰想惴惶惹惺愀愁愃愆愈愉愍愎意愐愔愕愚感愠愣愤愦愧愫愭愿慆慈慊慌慎慑慕慝慢慥"
229 | "慧慨慬慭慰慵慷憋憎憔憕憙憧憨憩憬憭憷憺憾懂懈懊懋懑懒懔懦懵懿戆戈戊戋戌戍戎戏成我"
230 | "戒戕或戗战戚戛戟戡戢戣戤戥截戬戭戮戳戴户戽戾房所扁扂扃扅扆扇扈扉扊手才扎扑扒打扔"
231 | "托扛扞扣扦执扩扪扫扬扭扮扯扰扳扶批扺扼扽找承技抃抄抉把抑抒抓抔投抖抗折抚抛抟抠抡"
232 | "抢护报抨披抬抱抵抹抻押抽抿拂拃拄担拆拇拈拉拊拌拍拎拐拒拓拔拖拗拘拙招拜拟拢拣拤拥"
233 | "拦拧拨择括拭拮拯拱拳拴拶拷拼拽拾拿持挂指挈按挎挑挓挖挚挛挝挞挟挠挡挣挤挥挦挨挪挫"
234 | "振挲挹挺挽捂捃捅捆捉捋捌捍捎捏捐捕捞损捡换捣捧捩捭据捯捶捷捺捻捽掀掂掇授掉掊掌掎"
235 | "掏掐排掖掘掞掠探掣接控推掩措掬掭掮掰掳掴掷掸掺掼掾揄揆揉揍描提插揕揖揠握揣揩揪揭"
236 | "揳援揶揸揽揿搀搁搂搅搋搌搏搐搒搓搔搛搜搞搠搡搦搪搬搭搴携搽摁摄摅摆摇摈摊摏摒摔摘"
237 | "摛摞摧摩摭摴摸摹摽撂撄撅撇撑撒撕撖撙撞撤撩撬播撮撰撵撷撸撺撼擀擂擅操擎擐擒擘擞擢"
238 | "擤擦擿攀攉攒攘攥攫攮支收攸改攻攽放政故效敉敌敏救敔敕敖教敛敝敞敢散敦敩敫敬数敲整"
239 | "敷文斋斌斐斑斓斗料斛斜斝斟斠斡斤斥斧斩斫断斯新斶方於施旁旃旄旅旆旋旌旎族旐旒旖旗"
240 | "旞无既日旦旧旨早旬旭旮旯旰旱旴旵时旷旸旺旻旿昀昂昃昄昆昇昈昉昊昌明昏昒易昔昕昙昝"
241 | "星映昡昣昤春昧昨昪昫昭是昱昳昴昵昶昺昼昽显晁晃晅晊晋晌晏晐晒晓晔晕晖晗晙晚晞晟晡"
242 | "晢晤晦晨晪晫普景晰晱晴晶晷智晾暂暄暅暇暌暑暕暖暗暝暧暨暮暲暴暵暶暹暾暿曈曌曙曛曜"
243 | "曝曦曩曰曲曳更曷曹曼曾替最月有朋服朏朐朓朔朕朗望朝期朦木未末本札术朱朳朴朵朸机朽"
244 | "杀杂权杄杆杈杉杌李杏材村杓杕杖杙杜杞束杠条来杧杨杩杪杭杯杰杲杳杵杷杻杼松板极构枅"
245 | "枇枉枋枍析枕林枘枚果枝枞枢枣枥枧枨枪枫枭枯枰枲枳枵架枷枸枹柁柃柄柈柊柏某柑柒染柔"
246 | "柖柘柙柚柜柝柞柠柢查柩柬柯柰柱柳柴柷柽柿栀栅标栈栉栊栋栌栎栏栐树栒栓栖栗栝栟校栩"
247 | "株栲栳栴样核根栻格栽栾桀桁桂桃桄桅框案桉桊桌桎桐桑桓桔桕桠桡桢档桤桥桦桧桨桩桫桯"
248 | "桲桴桶桷桹梁梃梅梆梌梏梓梗梠梢梣梦梧梨梭梯械梳梴梵梼梽梾梿检棁棂棉棋棍棐棒棓棕棘"
249 | "棚棠棣棤棨棪棫棬森棰棱棵棹棺棻棼棽椀椁椅椆椋植椎椐椑椒椓椟椠椤椪椭椰椴椸椹椽椿楂"
250 | "楒楔楗楙楚楝楞楠楣楦楩楪楫楮楯楷楸楹楼概榃榄榅榆榇榈榉榍榑榔榕榖榛榜榧榨榫榭榰榱"
251 | "榴榷榻槁槃槊槌槎槐槔槚槛槜槟槠槭槱槲槽槿樊樗樘樟模樨横樯樱樵樽樾橄橇橐橑橘橙橛橞"
252 | "橡橥橦橱橹橼檀檄檎檐檑檗檞檠檩檫檬櫆欂欠次欢欣欤欧欲欸欹欺欻款歃歅歆歇歉歌歙止正"
253 | "此步武歧歪歹死歼殁殂殃殄殆殇殉殊残殍殒殓殖殚殛殡殣殪殳殴段殷殿毁毂毅毋毌母每毐毒"
254 | "毓比毕毖毗毙毛毡毪毫毯毳毵毹毽氅氆氇氍氏氐民氓气氕氖氘氙氚氛氟氡氢氤氦氧氨氩氪氮"
255 | "氯氰氲水永氾氿汀汁求汆汇汈汉汊汋汐汔汕汗汛汜汝汞江池污汤汧汨汩汪汫汭汰汲汴汶汹汽"
256 | "汾沁沂沃沄沅沆沇沈沉沌沏沐沓沔沘沙沚沛沟没沣沤沥沦沧沨沩沪沫沭沮沱河沸油沺治沼沽"
257 | "沾沿泂泃泄泅泇泉泊泌泐泓泔法泖泗泙泚泛泜泞泠泡波泣泥注泪泫泮泯泰泱泳泵泷泸泺泻泼"
258 | "泽泾洁洄洇洈洋洌洎洑洒洓洗洘洙洚洛洞洢洣津洧洨洪洫洭洮洱洲洳洴洵洸洹洺活洼洽派洿"
259 | "流浃浅浆浇浈浉浊测浍济浏浐浑浒浓浔浕浙浚浛浜浞浟浠浡浣浥浦浩浪浬浭浮浯浰浲浴海浸"
260 | "浼涂涄涅消涉涌涍涎涐涑涓涔涕涘涛涝涞涟涠涡涢涣涤润涧涨涩涪涫涮涯液涴涵涸涿淀淄淅"
261 | "淆淇淋淌淏淑淖淘淙淜淝淞淟淠淡淤淦淫淬淮淯深淳淴混淹添淼清渊渌渍渎渐渑渔渗渚渝渟"
262 | "渠渡渣渤渥温渫渭港渰渲渴游渺渼湃湄湉湍湎湑湓湔湖湘湛湜湝湟湣湫湮湲湴湾湿溁溃溅溆"
263 | "溇溉溍溏源溘溚溜溞溟溠溢溥溦溧溪溯溱溲溴溵溶溷溹溺溻溽滁滂滃滆滇滉滋滍滏滑滓滔滕"
264 | "滗滘滚滞滟滠满滢滤滥滦滧滨滩滪滫滴滹漂漆漈漉漋漏漓演漕漖漠漤漦漩漪漫漭漯漱漳漴漶"
265 | "漷漹漻漼漾潆潇潋潍潏潖潘潜潞潟潢潦潩潭潮潲潴潵潸潺潼潽潾澂澄澈澉澌澍澎澛澜澡澥澧"
266 | "澪澭澳澴澶澹澼澽激濂濉濋濑濒濞濠濡濩濮濯瀌瀍瀑瀔瀚瀛瀣瀱瀵瀹瀼灈灌灏灞火灭灯灰灵"
267 | "灶灸灼灾灿炀炅炆炉炊炌炎炒炔炕炖炘炙炜炝炟炣炫炬炭炮炯炱炳炷炸点炻炼炽烀烁烂烃烈"
268 | "烊烔烘烙烛烜烝烟烠烤烦烧烨烩烫烬热烯烶烷烹烺烻烽焆焉焊焌焐焓焕焖焗焘焙焚焜焞焦焯"
269 | "焰焱然煁煃煅煊煋煌煎煓煜煞煟煤煦照煨煮煲煳煴煸煺煽熄熇熊熏熔熘熙熛熜熟熠熥熨熬熵"
270 | "熹熻燃燊燋燎燏燔燕燚燠燥燧燮燹爆爇爔爚爝爟爨爪爬爰爱爵父爷爸爹爻爽爿牁牂片版牌牍"
271 | "牒牖牙牚牛牝牟牡牢牤牥牦牧物牮牯牲牵特牺牻牾牿犀犁犄犇犊犋犍犏犒犟犨犬犯犰犴状犷"
272 | "犸犹狁狂狃狄狈狉狍狎狐狒狗狙狝狞狠狡狨狩独狭狮狯狰狱狲狳狴狷狸狺狻狼猁猃猄猇猊猎"
273 | "猕猖猗猛猜猝猞猡猢猥猩猪猫猬献猯猰猱猴猷猹猺猾猿獍獐獒獗獠獬獭獯獴獾玃玄率玉王玎"
274 | "玑玒玓玕玖玘玙玚玛玞玟玠玡玢玤玥玦玩玫玭玮环现玱玲玳玶玷玹玺玻玼玿珀珂珅珇珈珉珊"
275 | "珋珌珍珏珐珑珒珕珖珙珛珝珞珠珢珣珥珦珧珩珪珫班珰珲珵珷珸珹珺珽琀球琄琅理琇琈琉琊"
276 | "琎琏琐琔琚琛琟琡琢琤琥琦琨琪琫琬琭琮琯琰琲琳琴琵琶琼瑀瑁瑂瑃瑄瑅瑆瑑瑓瑔瑕瑖瑗瑙"
277 | "瑚瑛瑜瑝瑞瑟瑢瑧瑨瑬瑭瑰瑱瑳瑶瑷瑾璀璁璃璆璇璈璋璎璐璒璘璜璞璟璠璥璧璨璩璪璬璮璱"
278 | "璲璺瓀瓒瓖瓘瓜瓞瓠瓢瓣瓤瓦瓮瓯瓴瓶瓷瓻瓿甄甍甏甑甓甗甘甚甜生甡甥甦用甩甪甫甬甭甯"
279 | "田由甲申电男甸町画甾畀畅畈畋界畎畏畔畖留畚畛畜畤略畦番畬畯畲畴畸畹畿疁疃疆疍疏疐"
280 | "疑疔疖疗疙疚疝疟疠疡疢疣疤疥疫疬疭疮疯疰疱疲疳疴疵疸疹疼疽疾痂痃痄病症痈痉痊痍痒"
281 | "痓痔痕痘痛痞痢痣痤痦痧痨痪痫痰痱痴痹痼痿瘀瘁瘃瘅瘆瘊瘌瘐瘕瘗瘘瘙瘛瘟瘠瘢瘤瘥瘦瘩"
282 | "瘪瘫瘭瘰瘳瘴瘵瘸瘼瘾瘿癀癃癌癍癔癖癗癜癞癣癫癯癸登白百癿皂的皆皇皈皋皎皑皓皕皖皙"
283 | "皛皞皤皦皭皮皱皲皴皿盂盅盆盈盉益盍盎盏盐监盒盔盖盗盘盛盟盥盦目盯盱盲直盷相盹盼盾"
284 | "省眄眇眈眉眊看眍眙眚真眠眢眦眨眩眬眭眯眵眶眷眸眺眼着睁睃睄睇睎睐睑睚睛睡睢督睥睦"
285 | "睨睫睬睹睽睾睿瞀瞄瞅瞋瞌瞍瞎瞑瞒瞟瞠瞢瞥瞧瞩瞪瞫瞬瞭瞰瞳瞵瞻瞽瞿矍矗矛矜矞矢矣知"
286 | "矧矩矫矬短矮矰石矶矸矻矼矾矿砀码砂砄砆砉砌砍砑砒研砖砗砘砚砜砝砟砠砣砥砧砫砬砭砮"
287 | "砰破砵砷砸砹砺砻砼砾础硁硅硇硊硌硍硎硐硒硔硕硖硗硙硚硝硪硫硬硭确硼硿碃碇碈碉碌碍"
288 | "碎碏碑碓碗碘碚碛碜碟碡碣碥碧碨碰碱碲碳碴碶碹碾磁磅磉磊磋磏磐磔磕磙磜磡磨磬磲磴磷"
289 | "磹磻礁礅礌礓礞礴礵示礼社祀祁祃祆祇祈祉祊祋祎祏祐祓祕祖祗祚祛祜祝神祟祠祢祥祧票祭"
290 | "祯祲祷祸祺祼祾禀禁禄禅禊禋福禒禔禘禚禛禤禧禳禹禺离禽禾秀私秃秆秉秋种科秒秕秘租秣"
291 | "秤秦秧秩秫秬秭积称秸移秽秾稀稂稃稆程稌稍税稑稔稗稙稚稞稠稣稳稷稹稻稼稽稿穄穆穑穗"
292 | "穙穜穟穰穴究穷穸穹空穿窀突窃窄窅窈窊窍窎窑窒窕窖窗窘窜窝窟窠窣窥窦窨窬窭窳窸窿立"
293 | "竑竖竘站竞竟章竣童竦竫竭端竹竺竽竿笃笄笆笈笊笋笏笑笔笕笙笛笞笠笤笥符笨笪笫第笮笯"
294 | "笱笳笸笺笼笾筀筅筇等筋筌筏筐筑筒答策筘筚筛筜筝筠筢筤筥筦筮筱筲筵筶筷筹筻筼签简箅"
295 | "箍箐箓箔箕箖算箜管箢箦箧箨箩箪箫箬箭箱箴箸篁篆篇篌篑篓篙篚篝篡篥篦篪篮篯篱篷篼篾"
296 | "簃簇簉簋簌簏簕簖簝簟簠簧簪簰簸簿籀籁籍籥米籴类籼籽粉粑粒粕粗粘粜粝粞粟粢粤粥粪粮"
297 | "粱粲粳粹粼粽精粿糁糅糇糈糊糌糍糒糕糖糗糙糜糟糠糨糯糵系紊素索紧紫累絜絮絷綦綮縠縢"
298 | "縻繁繄繇纂纛纠纡红纣纤纥约级纨纩纪纫纬纭纮纯纰纱纲纳纴纵纶纷纸纹纺纻纼纽纾线绀绁"
299 | "绂练组绅细织终绉绊绋绌绍绎经绐绑绒结绔绕绖绗绘给绚绛络绝绞统绠绡绢绣绤绥绦继绨绩"
300 | "绪绫续绮绯绰绱绲绳维绵绶绷绸绹绺绻综绽绾绿缀缁缂缃缄缅缆缇缈缉缊缌缎缐缑缒缓缔缕"
301 | "编缗缘缙缚缛缜缝缞缟缠缡缢缣缤缥缦缧缨缩缪缫缬缭缮缯缰缱缲缳缴缵缶缸缺罂罄罅罍罐"
302 | "网罔罕罗罘罚罟罡罢罨罩罪置罱署罴罶罹罽罾羁羊羌美羑羓羔羕羖羚羝羞羟羡群羧羯羰羱羲"
303 | "羸羹羼羽羿翀翁翂翃翅翈翊翌翎翔翕翘翙翚翛翟翠翡翥翦翩翮翯翰翱翳翷翻翼翾耀老考耄者"
304 | "耆耇耋而耍耏耐耑耒耔耕耖耗耘耙耜耠耢耤耥耦耧耨耩耪耰耱耳耵耶耷耸耻耽耿聂聃聆聊聋"
305 | "职聍聒联聘聚聩聪聱聿肃肄肆肇肉肋肌肓肖肘肚肛肝肟肠股肢肤肥肩肪肫肭肮肯肱育肴肷肸"
306 | "肺肼肽肾肿胀胁胂胃胄胆胈背胍胎胖胗胙胚胛胜胝胞胠胡胣胤胥胧胨胩胪胫胬胭胯胰胱胲胳"
307 | "胴胶胸胺胼能脂脆脉脊脍脎脏脐脑脒脓脔脖脘脚脞脟脩脬脯脱脲脶脸脾脿腆腈腊腋腌腐腑腒"
308 | "腓腔腕腘腙腚腠腥腧腨腩腭腮腯腰腱腴腹腺腻腼腽腾腿膀膂膈膊膏膑膘膙膛膜膝膦膨膳膺膻"
309 | "臀臂臃臆臊臌臑臜臣臧自臬臭至致臻臼臾舀舁舂舄舅舆舌舍舐舒舔舛舜舞舟舠舢舣舥航舫般"
310 | "舭舯舰舱舲舳舴舵舶舷舸船舻舾艄艅艇艉艋艎艏艘艚艟艨艮良艰色艳艴艺艽艾艿节芃芄芈芊"
311 | "芋芍芎芏芑芒芗芘芙芜芝芟芠芡芣芤芥芦芨芩芪芫芬芭芮芯芰花芳芴芷芸芹芼芽芾苁苄苇苈"
312 | "苉苊苋苌苍苎苏苑苒苓苔苕苗苘苛苜苞苟苠苡苣苤若苦苧苫苯英苴苷苹苻苾茀茁茂范茄茅茆"
313 | "茈茉茋茌茎茏茑茓茔茕茗茚茛茜茝茧茨茫茬茭茯茱茳茴茵茶茸茹茺茼茽荀荁荃荄荆荇草荏荐"
314 | "荑荒荓荔荖荙荚荛荜荞荟荠荡荣荤荥荦荧荨荩荪荫荬荭荮药荷荸荻荼荽莅莆莉莎莒莓莘莙莛"
315 | "莜莝莞莠莨莩莪莫莰莱莲莳莴莶获莸莹莺莼莽莿菀菁菂菅菇菉菊菌菍菏菔菖菘菜菝菟菠菡菥"
316 | "菩菪菰菱菲菹菼菽萁萃萄萆萋萌萍萎萏萑萘萚萜萝萣萤营萦萧萨萩萱萳萸萹萼落葆葎葑葖著"
317 | "葙葚葛葜葡董葩葫葬葭葰葱葳葴葵葶葸葺蒂蒄蒇蒈蒉蒋蒌蒎蒐蒗蒙蒜蒟蒡蒨蒯蒱蒲蒴蒸蒹蒺"
318 | "蒻蒽蒿蓁蓂蓄蓇蓉蓊蓍蓏蓐蓑蓓蓖蓝蓟蓠蓢蓣蓥蓦蓬蓰蓼蓿蔀蔃蔈蔊蔌蔑蔓蔗蔚蔟蔡蔫蔬蔷"
319 | "蔸蔹蔺蔻蔼蔽蕃蕈蕉蕊蕖蕗蕙蕞蕤蕨蕰蕲蕴蕹蕺蕻蕾薁薄薅薇薏薛薜薢薤薨薪薮薯薰薳薷薸"
320 | "薹薿藁藉藏藐藓藕藜藟藠藤藦藨藩藻藿蘅蘑蘖蘘蘧蘩蘸蘼虎虏虐虑虒虓虔虚虞虢虤虫虬虮虱"
321 | "虷虸虹虺虻虼虽虾虿蚀蚁蚂蚄蚆蚊蚋蚌蚍蚓蚕蚜蚝蚣蚤蚧蚨蚩蚪蚬蚯蚰蚱蚲蚴蚶蚺蛀蛃蛄蛆"
322 | "蛇蛉蛊蛋蛎蛏蛐蛑蛔蛘蛙蛛蛞蛟蛤蛩蛭蛮蛰蛱蛲蛳蛴蛸蛹蛾蜀蜂蜃蜇蜈蜉蜊蜍蜎蜐蜒蜓蜕蜗"
323 | "蜘蜚蜜蜞蜡蜢蜣蜥蜩蜮蜱蜴蜷蜻蜾蜿蝇蝈蝉蝌蝎蝓蝗蝘蝙蝠蝣蝤蝥蝮蝰蝲蝴蝶蝻蝼蝽蝾螂螃"
324 | "螅螈螋融螗螟螠螣螨螫螬螭螯螱螳螵螺螽蟀蟆蟊蟋蟏蟑蟒蟛蟠蟥蟪蟫蟮蟹蟾蠃蠊蠋蠓蠕蠖蠡"
325 | "蠢蠲蠹蠼血衃衄衅行衍衎衒衔街衙衠衡衢衣补表衩衫衬衮衰衲衷衽衾衿袁袂袄袅袆袈袋袍袒"
326 | "袖袗袜袢袤袪被袭袯袱袷袼裁裂装裆裈裉裎裒裔裕裘裙裛裟裢裣裤裥裨裰裱裳裴裸裹裼裾褂"
327 | "褊褐褒褓褕褙褚褛褟褡褥褪褫褯褰褴褶襁襄襕襚襜襞襟襦襫襻西要覃覆见观觃规觅视觇览觉"
328 | "觊觋觌觎觏觐觑角觖觚觜觞觟解觥触觫觭觯觱觳觿言訄訇訚訾詈詟詹誉誊誓謇警譬计订讣认"
329 | "讥讦讧讨让讪讫训议讯记讱讲讳讴讵讶讷许讹论讻讼讽设访诀证诂诃评诅识诇诈诉诊诋诌词"
330 | "诎诏诐译诒诓诔试诖诗诘诙诚诛诜话诞诟诠诡询诣诤该详诧诨诩诫诬语诮误诰诱诲诳说诵请"
331 | "诸诹诺读诼诽课诿谀谁谂调谄谅谆谇谈谊谋谌谍谎谏谐谑谒谓谔谕谖谗谙谚谛谜谝谞谟谠谡"
332 | "谢谣谤谥谦谧谨谩谪谫谬谭谮谯谰谱谲谳谴谵谶谷谼谿豁豆豇豉豌豕豚象豢豨豪豫豮豳豸豹"
333 | "豺貂貅貆貉貊貌貔貘贝贞负贡财责贤败账货质贩贪贫贬购贮贯贰贱贲贳贴贵贶贷贸费贺贻贼"
334 | "贽贾贿赀赁赂赃资赅赆赇赈赉赊赋赌赍赎赏赐赑赒赓赔赕赖赗赘赙赚赛赜赝赞赟赠赡赢赣赤"
335 | "赦赧赪赫赭走赳赴赵赶起趁趄超越趋趑趔趟趣趯趱足趴趵趸趺趼趾趿跂跃跄跆跋跌跎跏跐跑"
336 | "跖跗跚跛距跞跟跣跤跨跪跬路跱跳践跶跷跸跹跺跻跽踅踉踊踌踏踒踔踝踞踟踢踣踦踩踪踬踮"
337 | "踯踱踵踶踹踺踽蹀蹁蹂蹄蹅蹇蹈蹉蹊蹋蹐蹑蹒蹙蹚蹜蹢蹦蹩蹬蹭蹯蹰蹲蹴蹶蹼蹽蹾蹿躁躅躇"
338 | "躏躐躔躜躞身躬躯躲躺车轧轨轩轪轫转轭轮软轰轱轲轳轴轵轶轷轸轹轺轻轼载轾轿辀辁辂较"
339 | "辄辅辆辇辈辉辊辋辌辍辎辏辐辑辒输辔辕辖辗辘辙辚辛辜辞辟辣辨辩辫辰辱边辽达辿迁迂迄"
340 | "迅过迈迎运近迓返迕还这进远违连迟迢迤迥迦迨迩迪迫迭迮述迳迷迸迹迺追退送适逃逄逅逆"
341 | "选逊逋逍透逐逑递途逖逗通逛逝逞速造逡逢逦逭逮逯逴逵逶逸逻逼逾遁遂遄遆遇遍遏遐遑遒"
342 | "道遗遘遛遢遣遥遨遭遮遴遵遹遽避邀邂邃邈邋邑邓邕邗邘邙邛邝邠邡邢那邦邨邪邬邮邯邰邱"
343 | "邲邳邴邵邶邸邹邺邻邽邾邿郁郃郄郅郇郈郊郎郏郐郑郓郗郚郛郜郝郡郢郤郦郧部郪郫郭郯郴"
344 | "郸都郾郿鄀鄂鄃鄄鄅鄌鄑鄗鄘鄙鄚鄜鄞鄠鄢鄣鄫鄯鄱鄹酂酃酅酆酉酊酋酌配酎酏酐酒酗酚酝"
345 | "酞酡酢酣酤酥酦酩酪酬酮酯酰酱酲酴酵酶酷酸酹酺酽酾酿醅醇醉醋醌醍醐醑醒醚醛醢醨醪醭"
346 | "醮醯醴醵醺醾采釉释里重野量釐金釜鉴銎銮鋆鋈錾鍪鎏鏊鏖鐾鑫钆钇针钉钊钋钌钍钎钏钐钒"
347 | "钓钔钕钖钗钘钙钚钛钜钝钞钟钠钡钢钣钤钥钦钧钨钩钪钫钬钭钮钯钰钱钲钳钴钵钷钹钺钻钼"
348 | "钽钾钿铀铁铂铃铄铅铆铈铉铊铋铌铍铎铏铐铑铒铕铖铗铘铙铚铛铜铝铞铟铠铡铢铣铤铥铧铨"
349 | "铩铪铫铬铭铮铯铰铱铲铳铴铵银铷铸铹铺铻铼铽链铿销锁锂锃锄锅锆锇锈锉锊锋锌锍锎锏锐"
350 | "锑锒锓锔锕锖锗锘错锚锛锜锝锞锟锡锢锣锤锥锦锧锨锩锪锫锬锭键锯锰锱锲锳锴锵锶锷锸锹"
351 | "锺锻锼锽锾锿镀镁镂镃镄镅镆镇镈镉镊镋镌镍镎镏镐镑镒镓镔镕镖镗镘镚镛镜镝镞镠镡镢镣"
352 | "镤镥镦镧镨镩镪镫镬镭镮镯镰镱镲镳镴镵镶长门闩闪闫闭问闯闰闱闲闳间闵闶闷闸闹闺闻闼"
353 | "闽闾闿阀阁阂阃阄阅阆阇阈阉阊阋阌阍阎阏阐阑阒阔阕阖阗阘阙阚阜队阡阪阮阱防阳阴阵阶"
354 | "阻阼阽阿陀陂附际陆陇陈陉陋陌降陎限陑陔陕陛陞陟陡院除陧陨险陪陬陲陴陵陶陷隃隅隆隈"
355 | "隋隍随隐隔隗隘隙障隧隩隰隳隶隹隺隼隽难雀雁雄雅集雇雉雊雌雍雎雏雒雕雠雨雩雪雯雱雳"
356 | "零雷雹雾需霁霄霅霆震霈霉霍霎霏霓霖霜霞霨霪霭霰露霸霹霾青靓靖静靛非靠靡面靥革靬靰"
357 | "靳靴靶靸靺靼靽靿鞁鞅鞋鞍鞑鞒鞔鞘鞠鞡鞣鞧鞨鞫鞬鞭鞮鞯鞲鞳鞴韂韦韧韨韩韪韫韬韭音韵"
358 | "韶页顶顷顸项顺须顼顽顾顿颀颁颂颃预颅领颇颈颉颊颋颌颍颎颏颐频颓颔颖颗题颙颚颛颜额"
359 | "颞颟颠颡颢颤颥颦颧风飏飐飑飒飓飔飕飗飘飙飞食飧飨餍餐餮饔饕饥饧饨饩饪饫饬饭饮饯饰"
360 | "饱饲饳饴饵饶饷饸饹饺饻饼饽饿馁馃馄馅馆馇馈馉馊馋馌馍馏馐馑馒馓馔馕首馗馘香馝馞馥"
361 | "馧馨马驭驮驯驰驱驲驳驴驵驶驷驸驹驺驻驼驽驾驿骀骁骂骃骄骅骆骇骈骉骊骋验骍骎骏骐骑"
362 | "骒骓骕骖骗骘骙骚骛骜骝骞骟骠骡骢骣骤骥骦骧骨骰骱骶骷骸骺骼髀髁髂髃髅髋髌髎髑髓高"
363 | "髡髢髦髫髭髯髹髻髽鬃鬈鬏鬒鬓鬘鬟鬣鬯鬲鬶鬷鬻鬼魁魂魃魄魅魆魇魈魉魋魍魏魑魔鱼鱽鱾"
364 | "鱿鲀鲁鲂鲃鲅鲆鲇鲈鲉鲊鲋鲌鲍鲎鲏鲐鲑鲒鲔鲕鲖鲗鲘鲙鲚鲛鲜鲝鲞鲟鲠鲡鲢鲣鲤鲥鲦鲧鲨"
365 | "鲩鲪鲫鲬鲭鲮鲯鲰鲱鲲鲳鲴鲵鲷鲸鲹鲺鲻鲼鲽鲾鲿鳀鳁鳂鳃鳄鳅鳇鳈鳉鳊鳌鳍鳎鳏鳐鳑鳒鳓"
366 | "鳔鳕鳖鳗鳘鳙鳚鳛鳜鳝鳞鳟鳠鳡鳢鳣鳤鸟鸠鸡鸢鸣鸤鸥鸦鸧鸨鸩鸪鸫鸬鸭鸮鸯鸰鸱鸲鸳鸵鸶"
367 | "鸷鸸鸹鸺鸻鸼鸽鸾鸿鹀鹁鹂鹃鹄鹅鹆鹇鹈鹉鹊鹋鹌鹍鹎鹏鹐鹑鹒鹔鹕鹖鹗鹘鹙鹚鹛鹜鹝鹞鹟"
368 | "鹠鹡鹢鹣鹤鹦鹧鹨鹩鹪鹫鹬鹭鹮鹯鹰鹱鹲鹳鹴鹾鹿麀麂麇麈麋麑麒麓麖麝麟麦麸麹麻麽麾黄"
369 | "黇黉黍黎黏黑黔默黛黜黝黟黠黡黢黥黧黩黪黯黹黻黼黾鼋鼍鼎鼐鼒鼓鼗鼙鼠鼢鼩鼫鼬鼯鼱鼷"
370 | "鼹鼻鼽鼾齁齇齉齐齑齿龀龁龂龃龄龅龆龇龈龉龊龋龌龙龚龛龟龠龢鿍鿎鿏㑇㑊㕮㘎㙍㙘㙦㛃"
371 | "㛚㛹㟃㠇㠓㤘㥄㧐㧑㧟㫰㬊㬎㬚㭎㭕㮾㰀㳇㳘㳚㴔㵐㶲㸆㸌㺄㻬㽏㿠䁖䂮䃅䃎䅟䌹䎃䎖䏝䏡"
372 | "䏲䐃䓖䓛䓨䓫䓬䗖䗛䗪䗴䜣䝙䢺䢼䣘䥽䦃䲟䲠䲢䴓䴔䴕䴖䴗䴘䴙䶮𠅤𠙶𠳐𡎚𡐓𣗋𣲗𣲘𣸣𤧛𤩽"
373 | "𤫉𥔲𥕢𥖨𥻗𦈡𦒍𦙶𦝼𦭜𦰡𧿹𨐈𨙸𨚕𨟠𨭉𨱇𨱏𨱑𨱔𨺙𩽾𩾃𩾌𪟝𪣻𪤗𪨰𪨶𪩘𪾢𫄧𫄨𫄷𫄸𫇭𫌀𫍣𫍯"
374 | "𫍲𫍽𫐄𫐐𫐓𫑡𫓧𫓯𫓶𫓹𫔍𫔎𫔶𫖮𫖯𫖳𫗧𫗴𫘜𫘝𫘦𫘧𫘨𫘪𫘬𫚕𫚖𫚭𫛭𫞩𫟅𫟦𫟹𫟼𫠆𫠊𫠜𫢸𫫇𫭟"
375 | "𫭢𫭼𫮃𫰛𫵷𫶇𫷷𫸩𬀩𬀪𬂩𬃊𬇕𬇙𬇹𬉼𬊈𬊤𬌗𬍛𬍡𬍤𬒈𬒔𬒗𬕂𬘓𬘘𬘡𬘩𬘫𬘬𬘭𬘯𬙂𬙊𬙋𬜬𬜯𬞟"
376 | "𬟁𬟽𬣙𬣞𬣡𬣳𬤇𬤊𬤝𬨂𬨎𬩽𬪩𬬩𬬭𬬮𬬱𬬸𬬹𬬻𬬿𬭁𬭊𬭎𬭚𬭛𬭤𬭩𬭬𬭯𬭳𬭶𬭸𬭼𬮱𬮿𬯀𬯎𬱖𬱟"
377 | "𬳵𬳶𬳽𬳿𬴂𬴃𬴊𬶋𬶍𬶏𬶐𬶟𬶠𬶨𬶭𬶮𬷕𬸘𬸚𬸣𬸦𬸪𬹼𬺈𬺓"
378 | )
379 | CN_CHARS_EXT = "吶诶屌囧飚屄"
380 |
381 | CN_CHARS = CN_CHARS_COMMON + CN_CHARS_EXT
382 | IN_CH_CHARS = {c: True for c in CN_CHARS}
383 |
384 | EN_CHARS = string.ascii_letters + string.digits
385 | IN_EN_CHARS = {c: True for c in EN_CHARS}
386 |
387 | VALID_CHARS = CN_CHARS + EN_CHARS + " "
388 | IN_VALID_CHARS = {c: True for c in VALID_CHARS}
389 |
390 |
391 | # ================================================================================ #
392 | # basic class
393 | # ================================================================================ #
394 | class ChineseChar(object):
395 | """
396 | 中文字符
397 | 每个字符对应简体和繁体,
398 | e.g. 简体 = '负', 繁体 = '負'
399 | 转换时可转换为简体或繁体
400 | """
401 |
402 | def __init__(self, simplified, traditional):
403 | self.simplified = simplified
404 | self.traditional = traditional
405 | # self.__repr__ = self.__str__
406 |
407 | def __str__(self):
408 | return self.simplified or self.traditional or None
409 |
410 | def __repr__(self):
411 | return self.__str__()
412 |
413 |
414 | class ChineseNumberUnit(ChineseChar):
415 | """
416 | 中文数字/数位字符
417 | 每个字符除繁简体外还有一个额外的大写字符
418 | e.g. '陆' 和 '陸'
419 | """
420 |
421 | def __init__(self, power, simplified, traditional, big_s, big_t):
422 | super(ChineseNumberUnit, self).__init__(simplified, traditional)
423 | self.power = power
424 | self.big_s = big_s
425 | self.big_t = big_t
426 |
427 | def __str__(self):
428 | return "10^{}".format(self.power)
429 |
430 | @classmethod
431 | def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
432 | if small_unit:
433 | return ChineseNumberUnit(
434 | power=index + 1, simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]
435 | )
436 | elif numbering_type == NUMBERING_TYPES[0]:
437 | return ChineseNumberUnit(
438 | power=index + 8, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
439 | )
440 | elif numbering_type == NUMBERING_TYPES[1]:
441 | return ChineseNumberUnit(
442 | power=(index + 2) * 4, simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
443 | )
444 | elif numbering_type == NUMBERING_TYPES[2]:
445 | return ChineseNumberUnit(
446 | power=pow(2, index + 3), simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]
447 | )
448 | else:
449 | raise ValueError("Counting type should be in {0} ({1} provided).".format(NUMBERING_TYPES, numbering_type))
450 |
451 |
452 | class ChineseNumberDigit(ChineseChar):
453 | """
454 | 中文数字字符
455 | """
456 |
457 | def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
458 | super(ChineseNumberDigit, self).__init__(simplified, traditional)
459 | self.value = value
460 | self.big_s = big_s
461 | self.big_t = big_t
462 | self.alt_s = alt_s
463 | self.alt_t = alt_t
464 |
465 | def __str__(self):
466 | return str(self.value)
467 |
468 | @classmethod
469 | def create(cls, i, v):
470 | return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
471 |
472 |
473 | class ChineseMath(ChineseChar):
474 | """
475 | 中文数位字符
476 | """
477 |
478 | def __init__(self, simplified, traditional, symbol, expression=None):
479 | super(ChineseMath, self).__init__(simplified, traditional)
480 | self.symbol = symbol
481 | self.expression = expression
482 | self.big_s = simplified
483 | self.big_t = traditional
484 |
485 |
486 | CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
487 |
488 |
489 | class NumberSystem(object):
490 | """
491 | 中文数字系统
492 | """
493 |
494 | pass
495 |
496 |
497 | class MathSymbol(object):
498 | """
499 | 用于中文数字系统的数学符号 (繁/简体), e.g.
500 | positive = ['正', '正']
501 | negative = ['负', '負']
502 | point = ['点', '點']
503 | """
504 |
505 | def __init__(self, positive, negative, point):
506 | self.positive = positive
507 | self.negative = negative
508 | self.point = point
509 |
510 | def __iter__(self):
511 | for v in self.__dict__.values():
512 | yield v
513 |
514 |
515 | # class OtherSymbol(object):
516 | # """
517 | # 其他符号
518 | # """
519 | #
520 | # def __init__(self, sil):
521 | # self.sil = sil
522 | #
523 | # def __iter__(self):
524 | # for v in self.__dict__.values():
525 | # yield v
526 |
527 |
528 | # ================================================================================ #
529 | # basic utils
530 | # ================================================================================ #
531 | def create_system(numbering_type=NUMBERING_TYPES[1]):
532 | """
533 | 根据数字系统类型返回创建相应的数字系统,默认为 mid
534 | NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型
535 | low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc.
536 | mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
537 | high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
538 | 返回对应的数字系统
539 | """
540 |
541 | # chinese number units of '亿' and larger
542 | all_larger_units = zip(LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)
543 | larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
544 | # chinese number units of '十, 百, 千, 万'
545 | all_smaller_units = zip(SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)
546 | smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
547 | # digis
548 | chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
549 | digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
550 | digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
551 | digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
552 | digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
553 |
554 | # symbols
555 | positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
556 | negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
557 | point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
558 | # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
559 | system = NumberSystem()
560 | system.units = smaller_units + larger_units
561 | system.digits = digits
562 | system.math = MathSymbol(positive_cn, negative_cn, point_cn)
563 | # system.symbols = OtherSymbol(sil_cn)
564 | return system
565 |
566 |
567 | def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
568 | def get_symbol(char, system):
569 | for u in system.units:
570 | if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
571 | return u
572 | for d in system.digits:
573 | if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
574 | return d
575 | for m in system.math:
576 | if char in [m.traditional, m.simplified]:
577 | return m
578 |
579 | def string2symbols(chinese_string, system):
580 | int_string, dec_string = chinese_string, ""
581 | for p in [system.math.point.simplified, system.math.point.traditional]:
582 | if p in chinese_string:
583 | int_string, dec_string = chinese_string.split(p)
584 | break
585 | return [get_symbol(c, system) for c in int_string], [get_symbol(c, system) for c in dec_string]
586 |
587 | def correct_symbols(integer_symbols, system):
588 | """
589 |         Expand a trailing digit into a full unit:          一百八 -> 一百八十
590 |         Distribute a trailing unit over preceding groups:  一亿一千三百万 -> 一亿 一千万 三百万
591 | """
592 |
593 | if integer_symbols and isinstance(integer_symbols[0], CNU):
594 | if integer_symbols[0].power == 1:
595 | integer_symbols = [system.digits[1]] + integer_symbols
596 |
597 | if len(integer_symbols) > 1:
598 | if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
599 | integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
600 |
601 | result = []
602 | unit_count = 0
603 | for s in integer_symbols:
604 | if isinstance(s, CND):
605 | result.append(s)
606 | unit_count = 0
607 | elif isinstance(s, CNU):
608 | current_unit = CNU(s.power, None, None, None, None)
609 | unit_count += 1
610 |
611 | if unit_count == 1:
612 | result.append(current_unit)
613 | elif unit_count > 1:
614 | for i in range(len(result)):
615 | if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
616 | result[-i - 1] = CNU(result[-i - 1].power + current_unit.power, None, None, None, None)
617 | return result
618 |
619 | def compute_value(integer_symbols):
620 | """
621 |         Compute the integer value of a symbol sequence.
622 |         When the current unit is larger than the previous one, every previously accumulated value is multiplied by the current unit.
623 |         e.g. '两千万' = 2000 * 10000, not 2000 + 10000
624 | """
625 | value = [0]
626 | last_power = 0
627 | for s in integer_symbols:
628 | if isinstance(s, CND):
629 | value[-1] = s.value
630 | elif isinstance(s, CNU):
631 | value[-1] *= pow(10, s.power)
632 | if s.power > last_power:
633 | value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
634 | last_power = s.power
635 | value.append(0)
636 | return sum(value)
637 |
638 | system = create_system(numbering_type)
639 | int_part, dec_part = string2symbols(chinese_string, system)
640 | int_part = correct_symbols(int_part, system)
641 | int_str = str(compute_value(int_part))
642 | dec_str = "".join([str(d.value) for d in dec_part])
643 | if dec_part:
644 | return "{0}.{1}".format(int_str, dec_str)
645 | else:
646 | return int_str
647 |
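# Illustrative usage of chn2num (an added sketch, not part of the original file; results
# assume the default 'mid' numbering system created by create_system above):
#   chn2num("一百二十三")  ->  "123"
#   chn2num("两千万")      ->  "20000000"   # parsed as 2000 * 10000, see compute_value above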
648 |
649 | def num2chn(
650 | number_string,
651 | numbering_type=NUMBERING_TYPES[1],
652 | big=False,
653 | traditional=False,
654 | alt_zero=False,
655 | alt_one=False,
656 | alt_two=True,
657 | use_zeros=True,
658 | use_units=True,
659 | ):
660 | def get_value(value_string, use_zeros=True):
661 |         stripped_string = value_string.lstrip("0")
662 |
663 |         # record nothing if all zeros
664 |         if not stripped_string:
665 |             return []
666 |
667 |         # record a single digit
668 |         elif len(stripped_string) == 1:
669 |             if use_zeros and len(value_string) != len(stripped_string):
670 |                 return [system.digits[0], system.digits[int(stripped_string)]]
671 |             else:
672 |                 return [system.digits[int(stripped_string)]]
673 |
674 |         # recursively record multiple digits
675 |         else:
676 |             result_unit = next(u for u in reversed(system.units) if u.power < len(stripped_string))
677 |             result_string = value_string[: -result_unit.power]
678 |             return get_value(result_string) + [result_unit] + get_value(stripped_string[-result_unit.power :])
679 |
680 | system = create_system(numbering_type)
681 |
682 | int_dec = number_string.split(".")
683 | if len(int_dec) == 1:
684 | int_string = int_dec[0]
685 | dec_string = ""
686 | elif len(int_dec) == 2:
687 | int_string = int_dec[0]
688 | dec_string = int_dec[1]
689 | else:
690 | raise ValueError("invalid input num string with more than one dot: {}".format(number_string))
691 |
692 | if use_units and len(int_string) > 1:
693 | result_symbols = get_value(int_string)
694 | else:
695 | result_symbols = [system.digits[int(c)] for c in int_string]
696 | dec_symbols = [system.digits[int(c)] for c in dec_string]
697 | if dec_string:
698 | result_symbols += [system.math.point] + dec_symbols
699 |
700 | if alt_two:
701 | liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, system.digits[2].big_s, system.digits[2].big_t)
702 | for i, v in enumerate(result_symbols):
703 | if isinstance(v, CND) and v.value == 2:
704 | next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
705 | previous_symbol = result_symbols[i - 1] if i > 0 else None
706 | if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
707 | if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
708 | result_symbols[i] = liang
709 |
710 | # if big is True, '两' will not be used and `alt_two` has no impact on output
711 | if big:
712 | attr_name = "big_"
713 | if traditional:
714 | attr_name += "t"
715 | else:
716 | attr_name += "s"
717 | else:
718 | if traditional:
719 | attr_name = "traditional"
720 | else:
721 | attr_name = "simplified"
722 |
723 | result = "".join([getattr(s, attr_name) for s in result_symbols])
724 |
725 | # if not use_zeros:
726 | # result = result.strip(getattr(system.digits[0], attr_name))
727 |
728 | if alt_zero:
729 | result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
730 |
731 | if alt_one:
732 | result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
733 |
734 | for i, p in enumerate(POINT):
735 | if result.startswith(p):
736 | return CHINESE_DIGIS[0] + result
737 |
738 |     # drop the leading '一' for 10..19 at the start of the result, e.g. 一十二 -> 十二
739 | if (
740 | len(result) >= 2
741 | and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]]
742 | and result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
743 | ):
744 | result = result[1:]
745 |
746 | return result
747 |
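# Illustrative usage of num2chn (an added sketch, not part of the original file):
#   num2chn("123")                   ->  "一百二十三"
#   num2chn("123", use_units=False)  ->  "一二三"   # digit-by-digit reading
#   num2chn("12")                    ->  "十二"     # the leading '一' of 10-19 is dropped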
748 |
749 | # ================================================================================ #
750 | # different types of rewriters
751 | # ================================================================================ #
752 | class Cardinal:
753 | """
754 |     CARDINAL class.
755 | """
756 |
757 | def __init__(self, cardinal=None, chntext=None):
758 | self.cardinal = cardinal
759 | self.chntext = chntext
760 |
761 | def chntext2cardinal(self):
762 | return chn2num(self.chntext)
763 |
764 | def cardinal2chntext(self):
765 | return num2chn(self.cardinal)
766 |
767 |
768 | class Digit:
769 | """
770 |     DIGIT class (numbers read digit by digit).
771 | """
772 |
773 | def __init__(self, digit=None, chntext=None):
774 | self.digit = digit
775 | self.chntext = chntext
776 |
777 | # def chntext2digit(self):
778 | # return chn2num(self.chntext)
779 |
780 | def digit2chntext(self):
781 | return num2chn(self.digit, alt_two=False, use_units=False)
782 |
783 |
784 | class TelePhone:
785 | """
786 |     TELEPHONE class.
787 | """
788 |
789 | def __init__(self, telephone=None, raw_chntext=None, chntext=None):
790 | self.telephone = telephone
791 | self.raw_chntext = raw_chntext
792 | self.chntext = chntext
793 |
794 | # def chntext2telephone(self):
795 | # sil_parts = self.raw_chntext.split('')
796 | # self.telephone = '-'.join([
797 | # str(chn2num(p)) for p in sil_parts
798 | # ])
799 | # return self.telephone
800 |
801 | def telephone2chntext(self, fixed=False):
802 | if fixed:
803 | sil_parts = self.telephone.split("-")
804 | self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sil_parts])
805 | self.chntext = self.raw_chntext.replace("", "")
806 | else:
807 | sp_parts = self.telephone.strip("+").split()
808 | self.raw_chntext = "".join([num2chn(part, alt_two=False, use_units=False) for part in sp_parts])
809 | self.chntext = self.raw_chntext.replace("", "")
810 | return self.chntext
811 |
812 |
813 | class Fraction:
814 | """
815 |     FRACTION class.
816 | """
817 |
818 | def __init__(self, fraction=None, chntext=None):
819 | self.fraction = fraction
820 | self.chntext = chntext
821 |
822 | def chntext2fraction(self):
823 | denominator, numerator = self.chntext.split("分之")
824 | return chn2num(numerator) + "/" + chn2num(denominator)
825 |
826 | def fraction2chntext(self):
827 | numerator, denominator = self.fraction.split("/")
828 | return num2chn(denominator) + "分之" + num2chn(numerator)
829 |
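# Note the argument order above: a fraction "numerator/denominator" is read as
# "denominator 分之 numerator". An added sketch for clarity:
#   Fraction(fraction="3/4").fraction2chntext()    ->  "四分之三"
#   Fraction(chntext="四分之三").chntext2fraction() ->  "3/4"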
830 |
831 | class Date:
832 | """
833 |     DATE class.
834 | """
835 |
836 | def __init__(self, date=None, chntext=None):
837 | self.date = date
838 | self.chntext = chntext
839 |
840 | # def chntext2date(self):
841 | # chntext = self.chntext
842 | # try:
843 | # year, other = chntext.strip().split('年', maxsplit=1)
844 | # year = Digit(chntext=year).digit2chntext() + '年'
845 | # except ValueError:
846 | # other = chntext
847 | # year = ''
848 | # if other:
849 | # try:
850 | # month, day = other.strip().split('月', maxsplit=1)
851 | # month = Cardinal(chntext=month).chntext2cardinal() + '月'
852 | # except ValueError:
853 | # day = chntext
854 | # month = ''
855 | # if day:
856 | # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
857 | # else:
858 | # month = ''
859 | # day = ''
860 | # date = year + month + day
861 | # self.date = date
862 | # return self.date
863 |
864 | def date2chntext(self):
865 | date = self.date
866 | try:
867 | year, other = date.strip().split("年", 1)
868 | year = Digit(digit=year).digit2chntext() + "年"
869 | except ValueError:
870 | other = date
871 | year = ""
872 | if other:
873 | try:
874 | month, day = other.strip().split("月", 1)
875 | month = Cardinal(cardinal=month).cardinal2chntext() + "月"
876 | except ValueError:
877 | day = date
878 | month = ""
879 | if day:
880 | day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
881 | else:
882 | month = ""
883 | day = ""
884 | chntext = year + month + day
885 | self.chntext = chntext
886 | return self.chntext
887 |
888 |
889 | class Money:
890 | """
891 |     MONEY class.
892 | """
893 |
894 | def __init__(self, money=None, chntext=None):
895 | self.money = money
896 | self.chntext = chntext
897 |
898 | # def chntext2money(self):
899 | # return self.money
900 |
901 | def money2chntext(self):
902 | money = self.money
903 | pattern = re.compile(r"(\d+(\.\d+)?)")
904 | matchers = pattern.findall(money)
905 | if matchers:
906 | for matcher in matchers:
907 | money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
908 | self.chntext = money
909 | return self.chntext
910 |
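# Added sketch: Money.money2chntext() only rewrites the numeric part and leaves the
# currency unit untouched, e.g.
#   Money(money="3.5元").money2chntext()  ->  "三点五元"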
911 |
912 | class Percentage:
913 | """
914 |     PERCENTAGE class.
915 | """
916 |
917 | def __init__(self, percentage=None, chntext=None):
918 | self.percentage = percentage
919 | self.chntext = chntext
920 |
921 | def chntext2percentage(self):
922 | return chn2num(self.chntext.strip().strip("百分之")) + "%"
923 |
924 | def percentage2chntext(self):
925 | return "百分之" + num2chn(self.percentage.strip().strip("%"))
926 |
927 |
928 | def normalize_nsw(raw_text):
929 | text = "^" + raw_text + "$"
930 |
931 |     # normalize dates
932 | pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
933 | matchers = pattern.findall(text)
934 | if matchers:
935 | # print('date')
936 | for matcher in matchers:
937 | text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
938 |
939 |     # normalize money amounts
940 | pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
941 | matchers = pattern.findall(text)
942 | if matchers:
943 | # print('money')
944 | for matcher in matchers:
945 | text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
946 |
947 |     # normalize landline / mobile phone numbers
948 |     # mobile number prefixes
949 |     # http://www.jihaoba.com/news/show/13680
950 |     # China Mobile: 139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198
951 |     # China Unicom: 130、131、132、156、155、186、185、176
952 |     # China Telecom: 133、153、189、180、181、177
953 | pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
954 | matchers = pattern.findall(text)
955 | if matchers:
956 | # print('telephone')
957 | for matcher in matchers:
958 | text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
959 |     # landline
960 | pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
961 | matchers = pattern.findall(text)
962 | if matchers:
963 | # print('fixed telephone')
964 | for matcher in matchers:
965 | text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
966 |
967 |     # normalize fractions
968 | pattern = re.compile(r"(\d+/\d+)")
969 | matchers = pattern.findall(text)
970 | if matchers:
971 | # print('fraction')
972 | for matcher in matchers:
973 | text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
974 |
975 |     # normalize percentages
976 | text = text.replace("%", "%")
977 | pattern = re.compile(r"(\d+(\.\d+)?%)")
978 | matchers = pattern.findall(text)
979 | if matchers:
980 | # print('percentage')
981 | for matcher in matchers:
982 | text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
983 |
984 |     # normalize plain numbers followed by quantifiers
985 | pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
986 | matchers = pattern.findall(text)
987 | if matchers:
988 | # print('cardinal+quantifier')
989 | for matcher in matchers:
990 | text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
991 |
992 |     # normalize digit sequences (serial numbers / IDs, read digit by digit)
993 | pattern = re.compile(r"(\d{4,32})")
994 | matchers = pattern.findall(text)
995 | if matchers:
996 | # print('digit')
997 | for matcher in matchers:
998 | text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
999 |
1000 |     # normalize remaining plain numbers
1001 | pattern = re.compile(r"(\d+(\.\d+)?)")
1002 | matchers = pattern.findall(text)
1003 | if matchers:
1004 | # print('cardinal')
1005 | for matcher in matchers:
1006 | text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
1007 |
1008 | # restore P2P, O2O, B2C, B2B etc
1009 | pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
1010 | matchers = pattern.findall(text)
1011 | if matchers:
1012 | # print('particular')
1013 | for matcher in matchers:
1014 | text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
1015 |
1016 | return text.lstrip("^").rstrip("$")
1017 |
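# Illustrative examples for normalize_nsw (an added sketch, not part of the original file):
#   normalize_nsw("2002年3月")  ->  "二零零二年三月"   # years are read digit by digit
#   normalize_nsw("涨幅3%")     ->  "涨幅百分之三"     # percentages become 百分之...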
1018 |
1019 | def remove_erhua(text):
1020 | """
1021 |     Remove the erhua suffix '儿' from r-colored words (whitelisted words are kept):
1022 | 他女儿在那边儿 -> 他女儿在那边
1023 | """
1024 |
1025 | new_str = ""
1026 | while re.search("儿", text):
1027 | a = re.search("儿", text).span()
1028 | remove_er_flag = 0
1029 |
1030 | if ER_WHITELIST_PATTERN.search(text):
1031 | b = ER_WHITELIST_PATTERN.search(text).span()
1032 | if b[0] <= a[0]:
1033 | remove_er_flag = 1
1034 |
1035 | if remove_er_flag == 0:
1036 | new_str = new_str + text[0 : a[0]]
1037 | text = text[a[1] :]
1038 | else:
1039 | new_str = new_str + text[0 : b[1]]
1040 | text = text[b[1] :]
1041 |
1042 | text = new_str + text
1043 | return text
1044 |
1045 |
1046 | def remove_space(text):
1047 | tokens = text.split()
1048 | new = []
1049 | for k, t in enumerate(tokens):
1050 | if k != 0:
1051 | if IN_EN_CHARS.get(tokens[k - 1][-1]) and IN_EN_CHARS.get(t[0]):
1052 | new.append(" ")
1053 | new.append(t)
1054 | return "".join(new)
1055 |
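# Added sketch (assuming IN_EN_CHARS, defined earlier in this file, covers ASCII letters):
# remove_space keeps a space only between two adjacent English tokens, e.g.
#   remove_space("你好 world hello 世界")  ->  "你好world hello世界"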
1056 |
1057 | class TextNorm:
1058 | def __init__(
1059 | self,
1060 | to_banjiao: bool = False,
1061 | to_upper: bool = False,
1062 | to_lower: bool = False,
1063 | remove_fillers: bool = False,
1064 | remove_erhua: bool = False,
1065 | check_chars: bool = False,
1066 | remove_space: bool = False,
1067 | cc_mode: str = "",
1068 | ):
1069 | self.to_banjiao = to_banjiao
1070 | self.to_upper = to_upper
1071 | self.to_lower = to_lower
1072 | self.remove_fillers = remove_fillers
1073 | self.remove_erhua = remove_erhua
1074 | self.check_chars = check_chars
1075 | self.remove_space = remove_space
1076 |
1077 | self.cc = None
1078 | if cc_mode:
1079 | from opencc import OpenCC # Open Chinese Convert: pip install opencc
1080 |
1081 | self.cc = OpenCC(cc_mode)
1082 |
1083 | def __call__(self, text):
1084 | if self.cc:
1085 | text = self.cc.convert(text)
1086 |
1087 | if self.to_banjiao:
1088 | text = text.translate(QJ2BJ_TRANSFORM)
1089 |
1090 | if self.to_upper:
1091 | text = text.upper()
1092 |
1093 | if self.to_lower:
1094 | text = text.lower()
1095 |
1096 | if self.remove_fillers:
1097 | for c in FILLER_CHARS:
1098 | text = text.replace(c, "")
1099 |
1100 | if self.remove_erhua:
1101 | text = remove_erhua(text)
1102 |
1103 | text = normalize_nsw(text)
1104 |
1105 | text = text.translate(PUNCS_TRANSFORM)
1106 |
1107 | if self.check_chars:
1108 | for c in text:
1109 | if not IN_VALID_CHARS.get(c):
1110 | print(f"WARNING: illegal char {c} in: {text}", file=sys.stderr)
1111 | return ""
1112 |
1113 | if self.remove_space:
1114 | text = remove_space(text)
1115 |
1116 | return text
1117 |
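# Minimal usage sketch for TextNorm (added for clarity, not part of the original file).
# __call__ applies, in this order: optional OpenCC conversion, full-width to half-width
# folding, upper/lower casing, filler and erhua removal, normalize_nsw(), punctuation
# stripping, and the optional illegal-char check and space removal.
#   tn = TextNorm(to_banjiao=True)
#   print(tn("涨幅3%"))   # expected output: 涨幅百分之三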
1118 |
1119 | if __name__ == "__main__":
1120 | p = argparse.ArgumentParser()
1121 |
1122 | # normalizer options
1123 | p.add_argument("--to_banjiao", action="store_true", help="convert quanjiao chars to banjiao")
1124 | p.add_argument("--to_upper", action="store_true", help="convert to upper case")
1125 | p.add_argument("--to_lower", action="store_true", help="convert to lower case")
1126 | p.add_argument("--remove_fillers", action="store_true", help='remove filler chars such as "呃, 啊"')
1127 | p.add_argument("--remove_erhua", action="store_true", help='remove erhua chars such as "他女儿在那边儿 -> 他女儿在那边"')
1128 | p.add_argument("--check_chars", action="store_true", help="skip sentences containing illegal chars")
1129 | p.add_argument("--remove_space", action="store_true", help="remove whitespace")
1130 | p.add_argument(
1131 |         "--cc_mode", choices=["", "t2s", "s2t"], default="", help="convert between traditional and simplified Chinese"
1132 | )
1133 |
1134 | # I/O options
1135 | p.add_argument("--log_interval", type=int, default=10000, help="log interval in number of processed lines")
1136 | p.add_argument("--has_key", action="store_true", help="will be deprecated, set --format ark instead")
1137 | p.add_argument("--format", type=str, choices=["txt", "ark", "tsv"], default="txt", help="input format")
1138 | p.add_argument("ifile", help="input filename, assume utf-8 encoding")
1139 | p.add_argument("ofile", help="output filename")
1140 |
1141 | args = p.parse_args()
1142 |
1143 | if args.has_key:
1144 | args.format = "ark"
1145 |
1146 | normalizer = TextNorm(
1147 | to_banjiao=args.to_banjiao,
1148 | to_upper=args.to_upper,
1149 | to_lower=args.to_lower,
1150 | remove_fillers=args.remove_fillers,
1151 | remove_erhua=args.remove_erhua,
1152 | check_chars=args.check_chars,
1153 | remove_space=args.remove_space,
1154 | cc_mode=args.cc_mode,
1155 | )
1156 |
1168 | ndone = 0
1169 | with open(args.ifile, "r", encoding="utf8") as istream, open(args.ofile, "w+", encoding="utf8") as ostream:
1170 | if args.format == "tsv":
1171 | reader = csv.DictReader(istream, delimiter="\t")
1172 | assert "TEXT" in reader.fieldnames
1173 | print("\t".join(reader.fieldnames), file=ostream)
1174 |
1175 | for item in reader:
1176 | text = item["TEXT"]
1177 |
1178 | if text:
1179 | text = normalizer(text)
1180 |
1181 | if text:
1182 | item["TEXT"] = text
1183 | print("\t".join([item[f] for f in reader.fieldnames]), file=ostream)
1184 |
1185 | ndone += 1
1186 | if ndone % args.log_interval == 0:
1187 | print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
1188 | else:
1189 | for l in istream:
1190 | key, text = "", ""
1191 | if args.format == "ark": # KALDI archive, line format: "key text"
1192 | cols = l.strip().split(maxsplit=1)
1193 | key, text = cols[0], cols[1] if len(cols) == 2 else ""
1194 | else:
1195 | text = l.strip()
1196 |
1197 | if text:
1198 | text = normalizer(text)
1199 |
1200 | if text:
1201 | if args.format == "ark":
1202 | print(key + "\t" + text, file=ostream)
1203 | else:
1204 | print(text, file=ostream)
1205 |
1206 | ndone += 1
1207 | if ndone % args.log_interval == 0:
1208 | print(f"text norm: {ndone} lines done.", file=sys.stderr, flush=True)
1209 | print(f"text norm: {ndone} lines done in total.", file=sys.stderr, flush=True)
--------------------------------------------------------------------------------
/pyannote_diarize.py:
--------------------------------------------------------------------------------
1 | from pyannote.audio import Pipeline
2 | from collections import defaultdict
3 | import torch
4 | import os
5 | import re
6 | import gc
7 |
8 | access_token = os.getenv("HF_ACCESS_TOKEN")
9 |
10 | audio_file = "input/CoinToss.mp3"
11 | exact_speakers = 2
12 | min_speakers = 2
13 | max_speakers = 2
14 | output_directory = "output_pyannote"
15 |
16 |
17 | def diarize(audio_file, num_speakers=0, min_speakers=0, max_speakers=0):
18 | """
19 | Perform speaker diarization on an audio file
20 | using a pre-trained model from pyannote.audio.
21 |
22 | Parameters:
23 | - audio_file (str): Path to the audio file.
24 | - num_speakers (int, optional): The number of speakers in the audio.
25 | Default is 0 (unknown).
26 | - min_speakers (int, optional): The minimum number of speakers
27 | expected in the audio. Default is 0.
28 | - max_speakers (int, optional): The maximum number of speakers
29 | expected in the audio. Default is 0.
30 |
31 | Returns:
32 | - list: A sorted list of dictionaries with speaker information
33 | ('name', 'total_time', 'segments').
34 | """
35 |
36 | print(f"Running diarization on {audio_file}...")
37 |
38 | access_token = os.getenv("HF_ACCESS_TOKEN")
39 |
40 | pipeline = Pipeline.from_pretrained(
41 | "pyannote/speaker-diarization-3.1",
42 | use_auth_token=access_token
43 | )
44 |
45 | # Send pipeline to GPU (when available)
46 |     pipeline.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
47 |     print("Model moved to GPU." if torch.cuda.is_available() else "CUDA not available, running on CPU.")
48 |
49 | # Prepare diarization options
50 | options = {}
51 | if num_speakers > 0:
52 | options['num_speakers'] = num_speakers
53 | if min_speakers > 0:
54 | options['min_speakers'] = min_speakers
55 | if max_speakers > 0:
56 | options['max_speakers'] = max_speakers
57 |
58 | # Apply diarization
59 | diarization = pipeline(audio_file, **options)
60 | results = [
61 | (turn.start, turn.end, speaker)
62 | for turn, _, speaker in diarization.itertracks(yield_label=True)
63 | ]
64 |
65 | # Print results
66 | for start, end, speaker in results:
67 | print(f"Speaker {speaker} ({start:.1f} - {end:.1f})")
68 |
69 | # Create a defaultdict to temporarily store the data
70 | temp_speaker_data = defaultdict(lambda: {"total_time": 0, "segments": []})
71 |
72 | # Process results
73 | for start, end, speaker_name in results:
74 | duration = end - start
75 | temp_speaker_data[speaker_name]["total_time"] += duration
76 | temp_speaker_data[speaker_name]["segments"].append(
77 | {"start": start, "end": end}
78 | )
79 |
80 | # Sort and format speaker data
81 | speakers = sorted(
82 | [
83 | {
84 | "name": name,
85 | "total_time": data["total_time"],
86 | "segments": data["segments"]
87 | }
88 | for name, data in temp_speaker_data.items()
89 | ],
90 | key=lambda x: x["total_time"],
91 | reverse=True,
92 | )
93 |
94 | # Clean-up resources
95 | del pipeline
96 | torch.cuda.empty_cache()
97 | gc.collect()
98 |
99 | return speakers
100 |
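# Added note: the returned structure is a list sorted by total speaking time, e.g.
#   [{"name": "SPEAKER_00", "total_time": 123.4,
#     "segments": [{"start": 0.0, "end": 5.2}, ...]}, ...]
# where the speaker names are the labels assigned by the pyannote pipeline.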
101 |
102 | def print_speakers(speakers):
103 | """
104 | Prints the details of speakers detected in an audio file.
105 |
106 | This function lists each speaker along with their total time spoken.
107 | It also displays the time segments for when each speaker was talking.
108 |
109 | Args:
110 | speakers (list): A list of dictionaries where each dictionary
111 | represents a speaker. Each dictionary contains the total time
112 | spoken by the speaker and their speaking segments.
113 | """
114 |
115 | # Check the number of speakers and print it
116 | print(f"\nThere were {len(speakers)} speakers detected in the audio file. "
117 | "List of the speakers sorted by their total time spoken "
118 | "in descending order:")
119 |
120 | # Iterate over each speaker and print their details
121 | for speaker_number, speaker in enumerate(speakers, start=1):
122 | # Print total time spoken by the speaker
123 | print(f'Speaker {speaker_number} total: {speaker["total_time"]:.1f}s')
124 |
125 | # Iterate over each segment and print the start and end times
126 | for index, seg in enumerate(speaker["segments"]):
127 | print(f' [{seg["start"]:.1f}s - {seg["end"]:.1f}s]', end="")
128 |
129 | # Add a new line after every 5 segments for better readability
130 | if (index + 1) % 5 == 0:
131 | print()
132 |
133 | # Print a newline for separation between speakers
134 | print()
135 |
136 |
137 | def speaker_files_exist(speakers):
138 | """
139 | Checks if text files exist for each speaker.
140 |
141 | This function iterates over a list of speakers and checks if
142 | a corresponding text file named "speaker{speaker_number}.txt"
143 | exists for each speaker.
144 |
145 | Args:
146 | speakers (list): A list of speaker information, used to
147 | determine the number of speakers.
148 |
149 | Returns:
150 | bool: True if all speaker files exist, False otherwise.
151 | """
152 |
153 | # Iterate over each speaker
154 | for speaker_number, _ in enumerate(speakers, start=1):
155 | # Construct the filename for each speaker
156 | timefile = f"speaker{speaker_number}.txt"
157 |
158 | # Check if the file for the current speaker exists
159 | if not os.path.exists(timefile):
160 | # Return False immediately if a file is missing
161 | return False
162 |
163 | # Return True if all files are found
164 | return True
165 |
166 |
167 | def import_time_file(timefile):
168 | """
169 | Imports a time file and returns a list of tuples.
170 |
171 | Parameters:
172 | - timefile (str): The path to the time file.
173 |
174 | Returns:
175 | - list: A list of tuples in the format of (start_time, end_time).
176 | """
177 | with open(timefile, "r", encoding='utf-8') as f:
178 | lines = f.readlines()
179 |
180 | time_list = []
181 |
182 | for line in lines:
183 | # Remove whitespace and find time strings enclosed in square brackets
184 | line = line.strip()
185 | time_strings = re.findall(r"\[(.*?)\]", line)
186 |
187 | for time_string in time_strings:
188 | # Split each time string into start and end times
189 | start_time, end_time = time_string.split("-")
190 |
191 | # Convert time strings to seconds and add to list
192 | start_time = time_to_seconds(start_time)
193 | end_time = time_to_seconds(end_time)
194 | time_list.append((start_time, end_time))
195 |
196 | return time_list
197 |
198 |
199 | def read_speaker_timefiles(directory):
200 | """
201 | Reads speaker time files and returns a list of dictionaries
202 | with speaker information.
203 |
204 | Each dictionary in the list contains 'name', 'total_time',
205 | and 'segments' keys.
206 |
207 | Returns:
208 | - list: A list of dictionaries where each dictionary contains the
209 | name, total time spoken, and speaking segments of a speaker.
210 | """
211 |
212 | speakers = []
213 |
214 | speaker_number = 1
215 | while True:
216 | timefile = f"speaker{speaker_number}.txt"
217 | timefile = os.path.join(directory, timefile)
218 | if not os.path.exists(timefile):
219 | break
220 |
221 | segments = import_time_file(timefile)
222 | total_time = sum(end - start for start, end in segments)
223 | speakers.append({
224 | "name": f"Speaker{speaker_number}",
225 | "total_time": total_time,
226 | "segments": segments
227 | })
228 |
229 | speaker_number += 1
230 |
231 | return speakers
232 |
233 |
234 | def write_speaker_timefiles(speakers, directory):
235 | """
236 | Writes time information of speakers to individual text files.
237 |
238 | For each speaker, this function creates a text file named
239 | "speaker{speaker_number}.txt". The file contains the total time
240 | spoken by the speaker and the time segments of their speech.
241 |
242 | Args:
243 | speakers (list): A list of dictionaries where each dictionary
244 | represents a speaker. Each dictionary contains the total time
245 | spoken by the speaker and their speaking segments.
246 | """
247 |
248 | # Iterate over each speaker
249 | for speaker_number, speaker in enumerate(speakers, start=1):
250 | # Define the filename for each speaker
251 | timefile = f"speaker{speaker_number}.txt"
252 | timefile = os.path.join(directory, timefile)
253 | print(f"Writing time file for speaker {speaker_number} "
254 | f"to {timefile}...")
255 |
256 | # Open the file in write mode
257 | with open(timefile, "w", encoding='utf-8') as f:
258 | # Write the total time spoken by the speaker
259 | f.write(f"Speaker {speaker_number} total: "
260 | f"{speaker['total_time']:.1f}s\n\n")
261 |
262 | # Write each time segment of the speaker
263 | for segment in speaker["segments"]:
264 | f.write(f"[{segment['start']:.1f}-{segment['end']:.1f}]\n")
265 |
266 |
267 | def time_to_seconds(time_str):
268 | """
269 | Converts a time string in various formats to seconds, now including hours,
270 | decimal seconds, plain numbers, and decimal seconds without 's'.
271 |
272 | Supported formats:
273 | - 'XhYmZs': X hours, Y minutes and Z seconds (e.g., '1h2m3s')
274 | - 'XmYs': X minutes and Y seconds (e.g., '3m23s')
275 | - 'X.Ys': X.Y seconds with decimal (e.g., '34.4s', '38.92255s')
276 | - 'X.Y': X.Y seconds without 's' (e.g., '34.4', '38.92255')
277 | - 'Xs': X seconds (e.g., '34s')
278 | - 'X:Y:Z': X hours, Y minutes and Z seconds (e.g., '1:02:03')
279 | - 'X:Y': X minutes and Y seconds (e.g., '3:00')
280 | - 'X': X seconds (e.g., '45')
281 |
282 | Parameters:
283 | - time_str (str): The time string to convert.
284 |
285 | Returns:
286 | - float: The number of seconds.
287 | """
288 |
289 | time_str = time_str.strip()
290 |
291 | # Regex patterns for different time formats
292 | pattern_hours_minutes_seconds = r'(\d+)h(\d+)m(\d+)s'
293 | pattern_minutes_seconds = r'(\d+)m(\d+)s'
294 | pattern_decimal_seconds = r'(\d+\.\d+)s?'
295 | pattern_seconds = r'(\d+)s'
296 | pattern_hours_colon = r'(\d+):(\d+):(\d+)'
297 | pattern_minutes_colon = r'(\d+):(\d+)'
298 | pattern_plain_number = r'^\d+$'
299 |
300 | # Match the time string against different patterns
301 | if re.match(pattern_hours_minutes_seconds, time_str):
302 | hours, minutes, seconds = map(
303 | int,
304 | re.findall(pattern_hours_minutes_seconds, time_str)[0]
305 | )
306 | return hours * 3600 + minutes * 60 + seconds
307 | elif re.match(pattern_minutes_seconds, time_str):
308 | minutes, seconds = map(
309 | int,
310 | re.findall(pattern_minutes_seconds, time_str)[0]
311 | )
312 | return minutes * 60 + seconds
313 | elif re.match(pattern_decimal_seconds, time_str):
314 | seconds = float(re.findall(pattern_decimal_seconds, time_str)[0])
315 | return seconds
316 | elif re.match(pattern_seconds, time_str):
317 | seconds = int(re.findall(pattern_seconds, time_str)[0])
318 | return seconds
319 | elif re.match(pattern_hours_colon, time_str):
320 | hours, minutes, seconds = map(
321 | int,
322 | re.findall(pattern_hours_colon, time_str)[0]
323 | )
324 | return hours * 3600 + minutes * 60 + seconds
325 | elif re.match(pattern_minutes_colon, time_str):
326 | minutes, seconds = map(
327 | int,
328 | re.findall(pattern_minutes_colon, time_str)[0]
329 | )
330 | return minutes * 60 + seconds
331 | elif re.match(pattern_plain_number, time_str):
332 | seconds = int(time_str)
333 | return seconds
334 | else:
335 | raise ValueError("Unsupported time format")
336 |
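# Illustrative conversions (an added sketch, not part of the original script):
#   time_to_seconds("1h2m3s")   ->  3723
#   time_to_seconds("3m23s")    ->  203
#   time_to_seconds("34.4")     ->  34.4
#   time_to_seconds("1:02:03")  ->  3723
#   time_to_seconds("3:00")     ->  180
#   time_to_seconds("45")       ->  45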
337 |
338 | def filter_speakers(speakers, time_start=0, time_end=None):
339 | """
340 | Filters a list of speakers based on a specified time window.
341 |
342 | Parameters:
343 | - speakers (list): A list of speaker dictionaries.
344 | - time_start (int): The start time of the filtering window in seconds.
345 |     - time_end (int, optional): The end time of the filtering window in seconds (None means no upper bound).
346 |
347 | Returns:
348 | - list: A new list of speakers filtered by the given time window.
349 | """
350 | filtered_speakers = []
351 |
352 | for speaker in speakers:
353 | # Filter segments within the time window
354 | filtered_segments = []
355 | for segment in speaker['segments']:
356 | start, end = segment['start'], segment['end']
357 | # Check if segment overlaps with the time window
358 |             if (time_end is None or start <= time_end) and end >= time_start:
359 | filtered_segments.append(segment)
360 |
361 | # Update speaker data if there are filtered segments
362 | if filtered_segments:
363 |
364 | total_time = sum(
365 | [seg['end'] - seg['start'] for seg in filtered_segments]
366 | )
367 |
368 | filtered_speaker = {
369 | 'name': speaker['name'],
370 | 'total_time': total_time,
371 | 'segments': filtered_segments
372 | }
373 | filtered_speakers.append(filtered_speaker)
374 |
375 | return filtered_speakers
376 |
377 |
378 | def test_diarization():
379 | # Ensure output directory exists
380 | if not os.path.exists(output_directory):
381 | os.makedirs(output_directory)
382 |
383 | # Perform diarization
384 | speakers = diarize(audio_file, exact_speakers, min_speakers, max_speakers)
385 |
386 | # Write diarization results to files
387 | write_speaker_timefiles(speakers, output_directory)
388 |
389 |     print(f"Diarization test completed. Check the '{output_directory}' directory for results.")
390 |
391 |
392 | # Run the test
393 | test_diarization()
--------------------------------------------------------------------------------
/realtime_diarize.py:
--------------------------------------------------------------------------------
1 | from PyQt6.QtWidgets import QApplication, QTextEdit, QMainWindow, QLabel, QVBoxLayout, QWidget, QDoubleSpinBox, QHBoxLayout, QPushButton, QSpacerItem, QSizePolicy
2 | from PyQt6.QtCore import Qt, pyqtSignal, QThread, QEvent, QTimer
3 | from sklearn.cluster import AgglomerativeClustering, KMeans
4 | from TTS.tts.models import setup_model as setup_tts_model
5 | from sklearn.preprocessing import StandardScaler
6 | from sklearn.metrics import silhouette_score
7 | from RealtimeSTT import AudioToTextRecorder
8 | from TTS.config import load_config
9 | import numpy as np
10 | import pyaudio
11 | import queue
12 | import torch
13 | import wave
14 | import sys
15 | import os
16 |
17 | SILENCE_THRESHS = [0, 0.4]
18 | FINAL_TRANSCRIPTION_MODEL = "large-v2"
19 | FINAL_BEAM_SIZE = 5
20 | REALTIME_TRANSCRIPTION_MODEL = "distil-small.en"
21 | REALTIME_BEAM_SIZE = 5
22 | TRANSCRIPTION_LANGUAGE = "en"
23 | SILERO_SENSITIVITY = 0.4
24 | WEBRTC_SENSITIVITY = 3
25 | MIN_LENGTH_OF_RECORDING = 0.7
26 | PRE_RECORDING_BUFFER_DURATION = 0.35
27 | INIT_TWO_SPEAKER_THRESHOLD = 17
28 | INIT_SILHOUETTE_DIFF_THRESHOLD = 0.0001
29 |
30 | FAST_SENTENCE_END = True
31 | USE_MICROPHONE = True
32 | LOOPBACK_DEVICE_NAME = "stereomix"
33 | LOOPBACK_DEVICE_HOST_API = 0
34 |
35 | FORMAT = pyaudio.paInt16
36 | CHANNELS = 1
37 | SAMPLE_RATE = 16000
38 | BUFFER_SIZE = 512
39 |
40 |
41 | COLOR_TABLE_HEX = [
42 | "#FFFF00", # yellow
43 | "#FF0000", # red
44 | "#00FFFF", # cyan
45 | "#FF00FF", # magenta
46 | "#FFA500", # orange
47 | "#00FF00", # lime
48 | "#800080", # purple
49 | "#FFC0CB", # pink
50 | "#008080", # teal
51 | "#FF7F50", # coral
52 | "#00FFFF", # aqua
53 | "#8A2BE2", # violet
54 | "#FFD700", # gold
55 | "#7FFF00", # chartreuse
56 | "#FF00FF", # fuchsia
57 | "#A0522D", # sienna
58 | "#40E0D0", # turquoise
59 | "#D2691E", # chocolate
60 | "#DC143C", # crimson
61 | "#FA8072", # salmon
62 | "#DA70D6", # orchid
63 | "#DDA0DD", # plum
64 | "#FFBF00", # amber
65 | "#007FFF", # azure
66 | "#F5F5DC", # beige
67 | "#E6E6FA", # lavender
68 | "#CC7722", # ochre
69 | "#FFDAB9", # peach
70 | "#9B111E", # ruby
71 | "#C0C0C0", # silver
72 | "#D2B48C", # tan
73 | "#F5DEB3", # wheat
74 | "#CD7F32", # bronze
75 | "#3EB489", # mint
76 | "#EAE0C8", # pearl
77 | "#0F52BA", # sapphire
78 | "#F28500", # tangerine
79 | "#50C878", # emerald
80 | "#FF007F", # rose
81 | "#9966CC", # amethyst
82 | "#2A52BE", # cerulean
83 | "#B87333", # copper
84 | "#FFFFF0", # ivory
85 | "#C3B091", # khaki
86 | "#E30B5D", # raspberry
87 | "#D9381E", # vermilion
88 | "#36454F", # charcoal
89 | "#FC8EAC", # flamingo
90 | "#00A36C", # jade
91 | "#FFF44F", # lemon
92 | "#D9E650", # quartz
93 | "#FF6347", # tomato
94 | "#0047AB", # cobalt
95 | "#F4C430", # saffron
96 | "#F9543B", # zinnia
97 | "#808000", # olive
98 | "#800000", # maroon
99 | "#000080", # navy
100 | "#008000", # green
101 | "#0000FF", # blue
102 | "#800000", # merlot
103 | "#4B0082", # indigo
104 | ]
105 |
106 | two_speaker_threshold = INIT_TWO_SPEAKER_THRESHOLD
107 | silhouette_diff_threshold = INIT_SILHOUETTE_DIFF_THRESHOLD
108 |
109 |
110 | class TextRetrievalThread(QThread):
111 | textRetrievedFinal = pyqtSignal(str, np.ndarray)
112 | textRetrievedLive = pyqtSignal(str)
113 | recorderStarted = pyqtSignal()
114 |
115 | def __init__(self):
116 | super().__init__()
117 |
118 | def live_text_detected(self, text):
119 | self.textRetrievedLive.emit(text)
120 |
121 | def run(self):
122 | print("Emitted Starting recorder")
123 | recorder_config = {
124 | 'spinner': False,
125 | 'use_microphone': False,
126 | 'model': FINAL_TRANSCRIPTION_MODEL,
127 | 'language': TRANSCRIPTION_LANGUAGE,
128 | 'silero_sensitivity': SILERO_SENSITIVITY,
129 | 'webrtc_sensitivity': WEBRTC_SENSITIVITY,
130 | 'post_speech_silence_duration': SILENCE_THRESHS[1],
131 | 'min_length_of_recording': MIN_LENGTH_OF_RECORDING,
132 | 'pre_recording_buffer_duration': PRE_RECORDING_BUFFER_DURATION,
133 | 'min_gap_between_recordings': 0,
134 | 'enable_realtime_transcription': True,
135 | 'realtime_processing_pause': 0,
136 | 'realtime_model_type': REALTIME_TRANSCRIPTION_MODEL,
137 | 'on_realtime_transcription_update': self.live_text_detected,
138 | 'beam_size': FINAL_BEAM_SIZE,
139 | 'beam_size_realtime': REALTIME_BEAM_SIZE,
140 | 'buffer_size': BUFFER_SIZE,
141 | 'sample_rate': SAMPLE_RATE,
142 | }
143 |
144 | self.recorder = AudioToTextRecorder(**recorder_config)
145 | self.recorderStarted.emit()
146 |
147 | def process_text(text):
148 | bytes = self.recorder.last_transcription_bytes
149 | self.textRetrievedFinal.emit(text, bytes)
150 |
151 | while True:
152 | self.recorder.text(process_text)
153 |
154 |
155 | class TextUpdateThread(QThread):
156 | text_update_signal = pyqtSignal(str)
157 |
158 | def __init__(self, text):
159 | super().__init__()
160 | self.text = text
161 |
162 | def run(self):
163 | self.text_update_signal.emit(self.text)
164 |
165 |
166 | class RecordingThread(QThread):
167 | def __init__(self, recorder):
168 | super().__init__()
169 |
170 | self.audio = pyaudio.PyAudio()
171 |
172 | def find_stereo_mix_index():
173 | devices = ""
174 | for i in range(self.audio.get_device_count()):
175 | dev = self.audio.get_device_info_by_index(i)
176 |                 devices += (f"{dev['index']}: {dev['name']} "
177 |                             f"(hostApi: {dev['hostApi']})\n")
178 |
179 | if (LOOPBACK_DEVICE_NAME in dev['name'].lower()
180 | and dev['hostApi'] == LOOPBACK_DEVICE_HOST_API):
181 | return dev['index'], devices
182 |
183 | return None, devices
184 |
185 | # Selecting the input device based on USE_MICROPHONE flag
186 | if USE_MICROPHONE:
187 | input_device_index = 0 # Default input device (microphone)
188 | else:
189 |             input_device_index, devices = find_stereo_mix_index()
190 | if input_device_index is None:
191 | print("Loopback / Stereo Mix device not found")
192 | print("Available devices:\n", devices)
193 | self.audio.terminate()
194 | exit()
195 | else:
196 | print(f"Stereo Mix device found at index: {input_device_index}")
197 |
198 | self.stream = self.audio.open(
199 | format=FORMAT,
200 | channels=CHANNELS,
201 | rate=SAMPLE_RATE,
202 | input=True,
203 | frames_per_buffer=BUFFER_SIZE,
204 | input_device_index=input_device_index)
205 | self.recorder = recorder
206 | self._is_running = True
207 |
208 | def run(self):
209 | while self._is_running:
210 | data = self.stream.read(BUFFER_SIZE, exception_on_overflow=False)
211 | self.recorder.feed_audio(data)
212 |
213 | def stop(self):
214 | self._is_running = False
215 | self.stream.stop_stream()
216 | self.stream.close()
217 | self.audio.terminate()
218 |
219 |
220 | class SentenceWorker(QThread):
221 | sentence_update_signal = pyqtSignal(list, list)
222 |
223 | def __init__(self, queue, tts_model):
224 | super().__init__()
225 | self.queue = queue
226 | self.tts = tts_model
227 | self._is_running = True
228 | self.full_sentences = []
229 | self.sentence_speakers = []
230 | self.speaker_index = 0
231 | self.speakers = []
232 |
233 | def run(self):
234 | while self._is_running:
235 | try:
236 | text, bytes = self.queue.get(timeout=1)
237 | self.process_item(text, bytes)
238 | except queue.Empty:
239 | continue
240 |
241 |     # Decide how many speakers are present: a KMeans check for one vs. several, then silhouette-scored hierarchical clustering
242 | def determine_optimal_cluster_count(self, embeddings_scaled):
243 | num_embeddings = len(embeddings_scaled)
244 | if num_embeddings <= 1:
245 | # Only one embedding, so only one speaker
246 | return 1
247 |
248 | # Determine single or multiple speakers
249 | # K-means Clustering with k=2
250 | kmeans = KMeans(n_clusters=2, random_state=0).fit(embeddings_scaled)
251 | distances = kmeans.transform(embeddings_scaled)
252 | avg_distance = np.mean(np.min(distances, axis=1))
253 | distance_threshold = two_speaker_threshold # Threshold to decide if we have one or multiple speakers
254 |
255 | # Check if the average distance is below threshold for single speaker
256 | if avg_distance < distance_threshold:
257 | print(f"Single Speaker: low embedding distance: {avg_distance} < {distance_threshold}.")
258 | return 1
259 |
260 | # Hierarchical Clustering for multiple speakers
261 | max_clusters = min(10, num_embeddings)
262 | range_clusters = range(2, max_clusters + 1)
263 | silhouette_scores = []
264 |
265 | for n_clusters in range_clusters:
266 | hc = AgglomerativeClustering(n_clusters=n_clusters, linkage='ward')
267 | cluster_labels = hc.fit_predict(embeddings_scaled)
268 |
269 | unique_labels = set(cluster_labels)
270 | if 1 < len(unique_labels) < len(embeddings_scaled):
271 | silhouette_avg = silhouette_score(embeddings_scaled, cluster_labels)
272 | silhouette_scores.append(silhouette_avg)
273 | else:
274 | print(f"Inappropriate number of clusters: {len(unique_labels)}.")
275 | silhouette_scores.append(-1)
276 |
277 |
278 | # Find the optimal number of clusters
279 | # It's the point before the silhouette score starts to decrease significantly
280 | optimal_cluster_count = 2
281 | for i in range(1, len(silhouette_scores)):
282 | if silhouette_scores[i] < silhouette_scores[i-1] + silhouette_diff_threshold:
283 | optimal_cluster_count = range_clusters[i-1]
284 | break
285 |
286 | return optimal_cluster_count
287 |
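    # Worked example of the selection rule above (added for clarity): with
    # range_clusters = [2, 3, 4] and silhouette_scores = [0.42, 0.35, 0.30], the score
    # already drops when moving from 2 to 3 clusters (0.35 < 0.42 + threshold),
    # so 2 clusters are returned.
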
288 | def process_speakers(self):
289 | embeddings = [speaker_embedding for _, speaker_embedding in self.full_sentences]
290 |
291 | # Standard scaling
292 | embeddings_array = np.array(embeddings)
293 | scaler = StandardScaler()
294 | embeddings_scaled = scaler.fit_transform(embeddings_array)
295 |
296 | optimal_cluster_count = self.determine_optimal_cluster_count(embeddings_scaled)
297 |
298 | if optimal_cluster_count == 1:
299 | self.sentence_speakers = [0] * len(self.full_sentences)
300 | else:
301 | self.sentence_speakers = []
302 |
303 | # Determine clusters
304 | hc = AgglomerativeClustering(n_clusters=optimal_cluster_count, linkage='ward')
305 | clusters = hc.fit_predict(embeddings_scaled)
306 |
307 | # Assign sentences to clusters
308 | # Create a mapping from old to new cluster indices
309 | cluster_mapping = {}
310 | new_index = 0
311 | for cluster in clusters:
312 | if cluster not in cluster_mapping:
313 | cluster_mapping[cluster] = new_index
314 | new_index += 1
315 |
316 | # Assign sentences to clusters with new indices
317 | for cluster in clusters:
318 | self.sentence_speakers.append(cluster_mapping[cluster])
319 |
320 | def process_item(self, text, bytes):
321 | audio_int16 = np.int16(bytes * 32767)
322 |
323 | tempfile = "output.wav"
324 | with wave.open(tempfile, 'w') as wav_file:
325 | wav_file.setnchannels(1)
326 | wav_file.setsampwidth(2)
327 | wav_file.setframerate(16000)
328 | wav_file.writeframes(audio_int16.tobytes())
329 |
330 | for tries in range(3):
331 | try:
332 | _, speaker_embedding = self.tts.get_conditioning_latents(
333 | audio_path=tempfile,
334 | gpt_cond_len=30,
335 | max_ref_length=60)
336 | speaker_embedding = \
337 | speaker_embedding.view(-1).cpu().detach().numpy()
338 | break
339 | except Exception as e:
340 | print(f"Error in try {tries}: {e}")
341 | speaker_embedding = np.zeros(512)
342 |
343 | self.full_sentences.append((text, speaker_embedding))
344 | self.process_speakers()
345 | self.sentence_update_signal.emit(self.full_sentences, self.sentence_speakers)
346 |
347 | def stop(self):
348 | self._is_running = False
349 |
350 |
351 | class MainWindow(QMainWindow):
352 | def __init__(self):
353 | super().__init__()
354 |
355 | self.setWindowTitle("Realtime Speaker Diarization")
356 |
357 | self.tts = None
358 | self.initialized = False
359 | self.displayed_text = ""
360 | self.last_realtime_text = ""
361 | self.full_sentences = []
362 | self.sentence_speakers = []
363 | self.speaker_index = 0
364 | self.pending_sentences = []
365 | # self.speakers = []
366 | self.queue = queue.Queue()
367 |
368 | # Create the main layout as horizontal
369 | self.mainLayout = QHBoxLayout()
370 |
371 | # Add the text edit to the main layout
372 | self.text_edit = QTextEdit(self)
373 | self.mainLayout.addWidget(self.text_edit, 1)
374 |
375 | # Create the right layout for controls and add them to the main layout
376 | self.rightLayout = QVBoxLayout()
377 | self.rightLayout.setAlignment(Qt.AlignmentFlag.AlignTop) # Align controls to the top
378 | self.create_controls()
379 |
380 | # Create a container for the right layout
381 | self.rightContainer = QWidget()
382 | self.rightContainer.setLayout(self.rightLayout)
383 | self.mainLayout.addWidget(self.rightContainer, 0) # Controls get the space they need
384 |
385 | # Set the main layout to the central widget
386 | self.centralWidget = QWidget()
387 | self.centralWidget.setLayout(self.mainLayout)
388 | self.setCentralWidget(self.centralWidget)
389 |
390 | self.setStyleSheet("""
391 | QLabel {
392 | color: #ddd;
393 | }
394 | QDoubleSpinBox {
395 | background: #333;
396 | color: #ddd;
397 | border: 1px solid #555;
398 | margin-bottom: 22px;
399 | }
400 | QTextEdit {
401 | background-color: #1e1e1e;
402 | color: #ffffff;
403 | font-family: 'Arial';
404 | font-size: 16pt;
405 | }
406 | """)
407 |
408 | def create_controls(self):
409 |
410 |         self.two_speaker_threshold_desc = QLabel("Distinguishing between one and two speakers:")
411 | self.two_speaker_threshold_label = QLabel("Two cluster similarity (0.1-100)")
412 | self.two_speaker_threshold_spinbox = QDoubleSpinBox()
413 | self.two_speaker_threshold_spinbox.setRange(0.1, 100)
414 | self.two_speaker_threshold_spinbox.setSingleStep(0.1)
415 | self.two_speaker_threshold_spinbox.setValue(two_speaker_threshold)
416 | self.two_speaker_threshold_spinbox.valueChanged.connect(self.update_two_speaker_threshold)
417 |
418 |         self.silhouette_diff_threshold_desc = QLabel("Distinguishing between more than two speakers:")
419 | self.silhouette_diff_threshold_label = QLabel("Silhouette similarity (0.001-1)")
420 | self.silhouette_diff_threshold_spinbox = QDoubleSpinBox()
421 | self.silhouette_diff_threshold_spinbox.setDecimals(5)
422 | self.silhouette_diff_threshold_spinbox.setRange(0, 0.01)
423 | self.silhouette_diff_threshold_spinbox.setSingleStep(0.00001)
424 | self.silhouette_diff_threshold_spinbox.setValue(silhouette_diff_threshold)
425 | self.silhouette_diff_threshold_spinbox.valueChanged.connect(self.update_silhouette_diff_threshold)
426 |
427 | self.two_speaker_threshold_label.setToolTip(
428 | "Adjust this threshold to control how the program differentiates between one or two speakers. "
429 | "Lower values mean only highly distinct voices are considered separate speakers. "
430 | "Higher values allow more leniency in identifying different speakers."
431 | )
432 |
433 | self.silhouette_diff_threshold_spinbox.setToolTip(
434 | "This value determines the required increase in similarity score to identify an additional speaker. "
435 | "Lower values make it easier to identify more speakers. "
436 | "Higher values prevent too many speakers from being identified, especially in noisy conditions."
437 | )
438 |
439 | # Add the controls to the right layout
440 | self.rightLayout.addWidget(self.two_speaker_threshold_desc)
441 | self.rightLayout.addWidget(self.two_speaker_threshold_label)
442 | self.rightLayout.addWidget(self.two_speaker_threshold_spinbox)
443 |
444 | self.rightLayout.addWidget(self.silhouette_diff_threshold_desc)
445 | self.rightLayout.addWidget(self.silhouette_diff_threshold_label)
446 | self.rightLayout.addWidget(self.silhouette_diff_threshold_spinbox)
447 |
448 | self.clear_button = QPushButton("Clear")
449 | self.clear_button.clicked.connect(self.clear_state)
450 | self.rightLayout.addWidget(self.clear_button)
451 |
452 | def clear_state(self):
453 | # Clear text edit
454 | self.text_edit.clear()
455 |
456 | # Reset state variables
457 | self.displayed_text = ""
458 | self.last_realtime_text = ""
459 | self.full_sentences = []
460 | self.sentence_speakers = []
461 | self.pending_sentences = []
462 | self.worker_thread.full_sentences = []
463 | self.worker_thread.sentence_speakers = []
464 | self.worker_thread.speakers = []
465 |
466 | # Optional: Provide a message in text edit to indicate clearing
467 | self.text_edit.setHtml("All cleared. Ready for new input.")
468 |
469 | def update_ui(self):
470 | self.worker_thread.process_speakers()
471 | self.sentence_updated(
472 | self.worker_thread.full_sentences,
473 | self.worker_thread.sentence_speakers)
474 |
475 | def update_two_speaker_threshold(self, value):
476 | global two_speaker_threshold
477 | two_speaker_threshold = value
478 | self.update_ui()
479 |
480 | def update_silhouette_diff_threshold(self, value):
481 | global silhouette_diff_threshold
482 | silhouette_diff_threshold = value
483 | self.update_ui()
484 |
485 | def showEvent(self, event):
486 | super().showEvent(event)
487 | if event.type() == QEvent.Type.Show:
488 | if not self.initialized:
489 | self.initialized = True
490 | self.resize(1200, 800)
491 | self.update_text("Please wait until app is loaded")
492 |
493 | QTimer.singleShot(500, self.init)
494 |
495 | def process_live_text(self, text):
496 | text = text.strip()
497 |
498 | if text:
499 | sentence_delimiters = '.?!。'
500 | prob_sentence_end = (
501 | len(self.last_realtime_text) > 0
502 | and text[-1] in sentence_delimiters
503 | and self.last_realtime_text[-1] in sentence_delimiters
504 | )
505 |
506 | self.last_realtime_text = text
507 |
508 | if prob_sentence_end:
509 | if FAST_SENTENCE_END:
510 | self.text_retrieval_thread.recorder.stop()
511 | else:
512 | self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[0]
513 | else:
514 | self.text_retrieval_thread.recorder.post_speech_silence_duration = SILENCE_THRESHS[1]
515 |
516 | self.last_realtime_text = text
517 |
518 | self.text_detected(text)
519 |
520 | def text_detected(self, text):
521 |
522 | try:
523 | sentences_with_style = []
524 | for i, sentence in enumerate(self.full_sentences):
525 | sentence_text, speaker_embedding = sentence
526 | sentence_tshort = sentence_text[:40]
527 | if i >= len(self.sentence_speakers):
528 | print(f"Index {i} out of range")
529 | color = "#FFFFFF"
530 | else:
531 | speaker_index = self.sentence_speakers[i]
532 | color = COLOR_TABLE_HEX[speaker_index % len(COLOR_TABLE_HEX)]
533 |
534 | sentences_with_style.append(
535 |                     f'<span style="color:{color};">{sentence_text}</span>')
536 |
537 | for pending_sentence in self.pending_sentences:
538 | sentences_with_style.append(
539 | f'{pending_sentence}')
540 |
541 | new_text = " ".join(sentences_with_style).strip() + " " + text if len(sentences_with_style) > 0 else text
542 |
543 | if new_text != self.displayed_text:
544 | self.displayed_text = new_text
545 | self.update_text(new_text)
546 | except Exception as e:
547 | print(f"Error: {e}")
548 |
549 |
550 | def process_final(self, text, bytes):
551 | text = text.strip()
552 | if text:
553 | try:
554 | self.pending_sentences.append(text)
555 | self.queue.put((text, bytes))
556 | except Exception as e:
557 | print(f"Error: {e}")
558 |
559 | def recording_thread(self, stream):
560 | while True:
561 | data = stream.read(BUFFER_SIZE)
562 | self.text_retrieval_thread.recorder.feed(data)
563 |
564 | def capture_output_and_feed_to_recorder(self):
565 | self.recording_thread = RecordingThread(
566 | self.text_retrieval_thread.recorder)
567 | self.recording_thread.start()
568 |
569 | def recorder_ready(self):
570 | print("Recorder ready")
571 | self.update_text("Ready to record")
572 |
573 | self.capture_output_and_feed_to_recorder()
574 |
575 | def init(self):
576 | self.start_tts()
577 |
578 | print("Starting recorder thread")
579 | self.text_retrieval_thread = TextRetrievalThread()
580 | self.text_retrieval_thread.recorderStarted.connect(
581 | self.recorder_ready)
582 | self.text_retrieval_thread.textRetrievedLive.connect(
583 | self.process_live_text)
584 | self.text_retrieval_thread.textRetrievedFinal.connect(
585 | self.process_final)
586 | self.text_retrieval_thread.start()
587 |
588 | self.worker_thread = SentenceWorker(self.queue, self.tts)
589 | self.worker_thread.sentence_update_signal.connect(
590 | self.sentence_updated)
591 | self.worker_thread.start()
592 |
593 |     def sentence_updated(self, full_sentences, sentence_speakers):
594 | self.pending_text = ""
595 | self.full_sentences = full_sentences
596 | # self.speakers = speakers
597 | self.sentence_speakers = sentence_speakers
598 | for sentence in self.full_sentences:
599 | sentence_text, speaker_embedding = sentence
600 | if sentence_text in self.pending_sentences:
601 | self.pending_sentences.remove(sentence_text)
602 | self.text_detected("")
603 |
604 | def start_tts(self):
605 | print("Loading TTS model")
606 | device = torch.device("cuda")
607 | local_models_path = os.environ.get("COQUI_MODEL_PATH")
608 | checkpoint = os.path.join(local_models_path, "v2.0.2")
609 |         config = load_config(os.path.join(checkpoint, "config.json"))
610 | self.tts = setup_tts_model(config)
611 | self.tts.load_checkpoint(
612 | config,
613 | checkpoint_dir=checkpoint,
614 | checkpoint_path=None,
615 | vocab_path=None,
616 | eval=True,
617 | use_deepspeed=False,
618 | )
619 | self.tts.to(device)
620 | print("TTS model loaded")
621 |
622 | def set_text(self, text):
623 | self.update_thread = TextUpdateThread(text)
624 | self.update_thread.text_update_signal.connect(self.update_text)
625 | self.update_thread.start()
626 |
627 | def update_text(self, text):
628 | self.text_edit.setHtml(text)
629 | self.text_edit.verticalScrollBar().setValue(
630 | self.text_edit.verticalScrollBar().maximum())
631 |
632 |
633 | def main():
634 | app = QApplication(sys.argv)
635 |
636 | dark_stylesheet = """
637 | QMainWindow {
638 | background-color: #323232;
639 | }
640 | QTextEdit {
641 | background-color: #1e1e1e;
642 | color: #ffffff;
643 | }
644 | """
645 | app.setStyleSheet(dark_stylesheet)
646 |
647 | main_window = MainWindow()
648 | main_window.show()
649 |
650 | sys.exit(app.exec())
651 |
652 |
653 | if __name__ == "__main__":
654 | main()
655 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | aiohttp==3.9.3
3 | aiosignal==1.3.1
4 | annotated-types==0.6.0
5 | anyascii==0.3.2
6 | anyio==4.3.0
7 | asttokens==2.4.1
8 | attrs==23.2.0
9 | audioread==3.0.1
10 | av==11.0.0
11 | azure-cognitiveservices-speech==1.36.0
12 | Babel==2.14.0
13 | bangla==0.0.2
14 | blinker==1.7.0
15 | blis==0.7.11
16 | bnnumerizer==0.0.2
17 | bnunicodenormalizer==0.1.6
18 | catalogue==2.0.10
19 | certifi==2024.2.2
20 | cffi==1.16.0
21 | charset-normalizer==3.3.2
22 | click==8.1.7
23 | cloudpathlib==0.16.0
24 | colorama==0.4.6
25 | coloredlogs==15.0.1
26 | comtypes==1.3.1
27 | confection==0.1.4
28 | contourpy==1.2.0
29 | coqpit==0.0.17
30 | ctranslate2==4.1.0
31 | cycler==0.12.1
32 | cymem==2.0.8
33 | Cython==3.0.9
34 | dateparser==1.1.8
35 | decorator==5.1.1
36 | distro==1.9.0
37 | docopt==0.6.2
38 | einops==0.7.0
39 | elevenlabs==0.2.27
40 | emoji==2.8.0
41 | encodec==0.1.1
42 | enum34==1.1.10
43 | executing==2.0.1
44 | faster-whisper==1.0.1
45 | ffmpeg-python==0.2.0
46 | filelock==3.9.0
47 | Flask==3.0.2
48 | flatbuffers==24.3.25
49 | fonttools==4.50.0
50 | frozenlist==1.4.1
51 | fsspec==2024.3.1
52 | future==1.0.0
53 | g2pkk==0.1.2
54 | grpcio==1.62.1
55 | gruut==2.2.3
56 | gruut-ipa==0.13.0
57 | gruut_lang_de==2.0.0
58 | gruut_lang_en==2.0.0
59 | gruut_lang_es==2.0.0
60 | gruut_lang_fr==2.0.2
61 | h11==0.14.0
62 | halo==0.0.31
63 | hangul-romanize==0.1.0
64 | httpcore==1.0.5
65 | httpx==0.27.0
66 | huggingface-hub==0.22.2
67 | humanfriendly==10.0
68 | idna==3.6
69 | inflect==7.0.0
70 | ipython==8.22.2
71 | itsdangerous==2.1.2
72 | jamo==0.4.1
73 | jedi==0.19.1
74 | jieba==0.42.1
75 | Jinja2==3.1.2
76 | joblib==1.3.2
77 | jsonlines==1.2.0
78 | kiwisolver==1.4.5
79 | langcodes==3.3.0
80 | lazy_loader==0.3
81 | librosa==0.10.1
82 | llvmlite==0.42.0
83 | log-symbols==0.0.14
84 | Markdown==3.6
85 | MarkupSafe==2.1.3
86 | matplotlib==3.8.3
87 | matplotlib-inline==0.1.6
88 | more-itertools==10.2.0
89 | mpmath==1.3.0
90 | msgpack==1.0.8
91 | multidict==6.0.5
92 | murmurhash==1.0.10
93 | networkx==2.8.8
94 | nltk==3.8.1
95 | num2words==0.5.13
96 | numba==0.59.1
97 | numpy==1.26.4
98 | onnxruntime==1.17.1
99 | openai==1.13.3
100 | openai-whisper==20231117
101 | packaging==24.0
102 | pandas==1.5.3
103 | parso==0.8.3
104 | pillow==10.2.0
105 | platformdirs==4.2.0
106 | pooch==1.8.1
107 | preshed==3.0.9
108 | prompt-toolkit==3.0.43
109 | protobuf==5.26.1
110 | psutil==5.9.8
111 | pure-eval==0.2.2
112 | pvporcupine==1.9.5
113 | pyannote-audio==3.1.1
114 | PyAudio==0.2.14
115 | pycparser==2.22
116 | pydantic==2.6.4
117 | pydantic_core==2.16.3
118 | pydub==0.25.1
119 | Pygments==2.17.2
120 | pynndescent==0.5.12
121 | pyparsing==3.1.2
122 | pypinyin==0.51.0
123 | pypiwin32==223
124 | pyreadline3==3.4.1
125 | pysbd==0.3.4
126 | python-crfsuite==0.9.10
127 | python-dateutil==2.9.0.post0
128 | pyttsx3==2.90
129 | pytz==2024.1
130 | pywin32==306
131 | PyYAML==6.0.1
132 | RealTimeSTT==0.1.13
133 | RealTimeTTS==0.3.44
134 | regex==2023.12.25
135 | requests==2.31.0
136 | safetensors==0.4.2
137 | scikit-learn==1.4.1.post1
138 | scipy==1.12.0
139 | six==1.16.0
140 | smart-open==6.4.0
141 | sniffio==1.3.1
142 | soundfile==0.12.1
143 | soxr==0.3.7
144 | spacy==3.7.4
145 | spacy-legacy==3.0.12
146 | spacy-loggers==1.0.5
147 | spinners==0.0.24
148 | srsly==2.4.8
149 | stable-ts==2.15.10
150 | stack-data==0.6.3
151 | stanza==1.6.1
152 | stream2sentence==0.2.3
153 | SudachiDict-core==20240109
154 | SudachiPy==0.6.8
155 | sympy==1.12
156 | tensorboard==2.16.2
157 | tensorboard-data-server==0.7.2
158 | termcolor==2.4.0
159 | thinc==8.2.3
160 | threadpoolctl==3.4.0
161 | tiktoken==0.6.0
162 | tokenizers==0.15.2
163 | torch==2.2.2+cu118
164 | torchaudio==2.2.2+cu118
165 | tqdm==4.66.2
166 | trainer==0.0.36
167 | traitlets==5.14.2
168 | transformers==4.39.2
169 | TTS==0.22.0
170 | typer==0.9.4
171 | typing_extensions==4.8.0
172 | tzdata==2024.1
173 | tzlocal==5.2
174 | umap-learn==0.5.5
175 | Unidecode==1.3.8
176 | urllib3==2.2.1
177 | wasabi==1.1.2
178 | wcwidth==0.2.13
179 | weasel==0.3.4
180 | webrtcvad==2.0.10
181 | websockets==12.0
182 | Werkzeug==3.0.1
183 | yarl==1.9.4
184 | yt-dlp==2024.3.10
185 |
--------------------------------------------------------------------------------
/speaker_diarize.py:
--------------------------------------------------------------------------------
1 | """
2 | Speaker Diarization
3 |
4 | Idea:
5 | - create 1D embedding vectors from the sentences
6 | - for every sentence
7 | - find most similar 10% other sentences
8 | - average these 1D embeddings out to form a "speech group" embedding
9 | - for every sentence
10 | - compare speech group embedding with all other sentence speech group embeddings
11 | - find the two speech groups with least similar embeddings
12 | - these two "speech group" embeddings become our "speaker" characteristic 1D embeddings
13 | - for every sentence
14 | - find cosine similarity between the sentence and the two "speaker" characteristics 1D embeddings
15 | - assign the sentence to the speaker with the higher similarity (see the commented sketch below)
16 |
17 | => every sentence is assigned to one of the two speakers
18 |
19 | notes:
20 | - cut out every < 3s file before processing
21 |
22 | """
23 |
24 | from TTS.tts.models import setup_model as setup_tts_model
25 | from scipy.spatial.distance import cosine
26 | from TTS.config import load_config
27 | import librosa.display
28 | import librosa
29 | import numpy as np
30 | import shutil
31 | import torch
32 | import os
33 |
34 | input_directory = 'output_sentences_wav'
35 | output_directory = 'output_speakers'
36 | max_sentences = 1000000
37 | group_percentage = 0.1
38 | minimum_duration = 1
39 | only_keep_most_confident_percentage = 0.8
40 |
41 | data = []
42 |
43 | device = torch.device("cuda")
44 | local_models_path = os.environ.get("COQUI_MODEL_PATH")  # must point to the folder containing the local XTTS checkpoints (v2.0.2 subfolder)
45 | checkpoint = os.path.join(local_models_path, "v2.0.2")
46 | config = load_config((os.path.join(checkpoint, "config.json")))
47 | tts = setup_tts_model(config)
48 | tts.load_checkpoint(
49 | config,
50 | checkpoint_dir=checkpoint,
51 | checkpoint_path=None,
52 | vocab_path=None,
53 | eval=True,
54 | use_deepspeed=False,
55 | )
56 | tts.to(device)
57 | print("TTS model loaded")
58 |
59 | # create 1D embeddings from sentences
60 | count = 0
61 | for filename in os.listdir(input_directory):
62 | if filename.endswith(".wav"):
63 | count += 1
64 | if count > max_sentences:
65 | break
66 |
67 | # skip if file is too short
68 | y, sr = librosa.load(os.path.join(input_directory, filename))
69 | if librosa.get_duration(y=y, sr=sr) < minimum_duration:
70 | continue
71 |
72 | full_path = os.path.join(input_directory, filename)
73 | print(full_path)
74 |
75 | gpt_cond_latent, speaker_embedding = tts.get_conditioning_latents(audio_path=full_path, gpt_cond_len=30, max_ref_length=60)
76 | # flatten the speaker embedding to a 1D NumPy vector for the cosine-similarity comparisons below
77 | speaker_embedding_1D = speaker_embedding.view(-1).cpu().detach().numpy()
78 |
79 | entry = {
80 | 'filename': filename,
81 | 'speaker_embeds_1D': speaker_embedding_1D
82 | }
83 | data.append(entry)
84 | else:
85 | continue
86 |
87 | # Find the most similar 10% of the other sentences for each sentence
88 | # Calculate 10% of the number of sentences, at least 1
89 | num_top_sentences = max(1, int(group_percentage * len(data)))
90 | print(f"Sentences per group: {num_top_sentences}")
91 |
92 | # Find speech group embedding of sentence
93 | for index, entry in enumerate(data):
94 | similarities = []
95 | embedding = entry['speaker_embeds_1D']
96 |
97 | # Compute similarities with other sentences
98 | for index_compare, compare_entry in enumerate(data):
99 | if index_compare != index:
100 | embedding_compare = compare_entry['speaker_embeds_1D']
101 | similarity = 1 - cosine(embedding, embedding_compare)
102 | similarities.append((similarity, embedding_compare))
103 |
104 | # Sort by similarity and pick top 10%
105 | similarities.sort(reverse=True, key=lambda x: x[0])
106 | top_similar_embeddings = [x[1] for x in similarities[:num_top_sentences]]
107 |
108 | # Step 2: Average out the 1D embeddings to make a "speech group" embedding
109 | speech_group_embedding = np.mean(np.array(top_similar_embeddings), axis=0)
110 |
111 | # Step 3: Store the speech group embedding in data
112 | entry['speech_group_embed'] = speech_group_embedding
113 |
114 | # Find speakers by comparing speech group embeddings
115 | for index, entry in enumerate(data):
116 | similarities = []
117 | embedding = entry['speech_group_embed']
118 | for index_compare, compare_entry in enumerate(data):
119 | if index_compare != index:
120 | embedding_compare = compare_entry['speech_group_embed']
121 | similarity = 1 - cosine(embedding, embedding_compare)
122 | similarities.append((similarity, embedding_compare))
123 |
124 | # Sort by similarity and pick least similar
125 | similarities.sort(reverse=False, key=lambda x: x[0])
126 | least_similar_embed = similarities[0][1]
127 | entry['least_similarity'] = similarities[0][0]
128 | entry['least_similar_embed'] = least_similar_embed
129 |
130 | # Find entry with least similarity
131 | data.sort(reverse=False, key=lambda x: x['least_similarity'])
132 | least_similar_entry = data[0]
133 |
134 | embed_speaker_1 = least_similar_entry['speech_group_embed']
135 | embed_speaker_2 = least_similar_entry['least_similar_embed']
136 |
137 | for entry in data:
138 | similarity_1 = 1 - cosine(entry['speaker_embeds_1D'], embed_speaker_1)
139 | similarity_2 = 1 - cosine(entry['speaker_embeds_1D'], embed_speaker_2)
140 |
141 | if similarity_1 > similarity_2:
142 | entry['speaker'] = 1
143 | entry['confidence'] = similarity_1 - similarity_2
144 | else:
145 | entry['speaker'] = 2
146 | entry['confidence'] = similarity_2 - similarity_1
147 |
148 | print(f"Speaker {entry['speaker']} assigned to {entry['filename']} with confidence {entry['confidence']}")
149 |
150 | # Remove the least confident
151 | data.sort(reverse=True, key=lambda x: x['confidence'])
152 | data = data[:int(len(data) * only_keep_most_confident_percentage)]
153 |
154 | # Ensure output directories exist
155 | for speaker_id in [1, 2]:
156 | speaker_dir = os.path.join(output_directory, f"speaker_{speaker_id}")
157 | if not os.path.exists(speaker_dir):
158 | os.makedirs(speaker_dir)
159 |
160 | # Copy files to the corresponding speaker directory
161 | for entry in data:
162 | speaker = entry['speaker']
163 | filename = entry['filename']
164 | source_path = os.path.join(input_directory, filename)
165 | destination_path = os.path.join(output_directory, f"speaker_{speaker}", filename)
166 |
167 | # Copy the file
168 | shutil.copy(source_path, destination_path)
169 | print(f"Copied {filename} to {destination_path}")
170 |
--------------------------------------------------------------------------------
/split_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import tempfile
4 | from typing import List, Tuple
5 | import ffmpeg
6 | import concurrent.futures
7 | import time
8 | import json
9 | from faster_whisper import WhisperModel
10 | import stable_whisper
11 | import multiprocessing
12 | from cleaner import multilingual_cleaners
13 |
14 |
15 | # input audio file
16 | language = "en"
17 | input_audio_files = [
18 | "input/CoinToss.mp3"
19 | # "input/SamuelLJackson.mp3",
20 | # "input/Elon Musk Podcast #49.mp3",
21 | # "input/Elon Musk Podcast #252.mp3",
22 | # "input/Elon Musk Podcast #400.mp3",
23 | ]
24 | output_directory = 'output_sentences'
25 |
26 | # for faster transcription disable transcript refinement
27 | # and use a smaller model like tiny, tiny.en, small, medium
28 | TRANSCRIPT_REFINEMENT = True
29 | whisper_model = "tiny.en"
30 |
31 | extend_detected_borders_start = 0.05
32 | extend_detected_borders_end = 0.15
33 |
34 | # https://github.com/coqui-ai/TTS/blob/dev/TTS/tts/layers/xtts/tokenizer.py#L597
35 | max_text_len = 250
36 | max_processes = 1
37 |
38 | MB = 1024 * 1024 # Bytes in a Megabyte
39 | CHUNK_SIZE_MB = 20 # Desired file chunk size in MB
40 |
41 | def find_optimal_breakpoints(points: List[float], n: int) -> List[float]:
42 | result = []
43 | optimal_length = points[-1] / n  # target chunk length if the total duration were split evenly into n chunks
44 | temp = 0
45 | temp_a = 0
46 | l = len(points)
47 | for i in points[:l - 1]:
48 | if (i - temp_a) >= optimal_length:
49 | if optimal_length - (temp - temp_a) < (i - temp_a) - optimal_length:
50 | result.append(temp)
51 | else:
52 | result.append(i)
53 | temp_a = result[-1]
54 | temp = i
55 | return result
56 |
57 |
58 | def split_audio_into_chunks(input_file: str, max_chunks: int,
59 | silence_threshold: str = "-20dB", silence_duration: float = 2.0) -> Tuple[List[str], List[float]]:
60 | def save_chunk_to_temp_file(input_file: str, start: float, end: float) -> Tuple[str, float]:
61 | temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=file_extension)
62 | temp_file.close()
63 |
64 | in_stream = ffmpeg.input(input_file)
65 | (
66 | ffmpeg.output(in_stream, temp_file.name, ss=start, t=end - start, c="copy")
67 | .overwrite_output()
68 | .run()
69 | )
70 |
71 | return temp_file.name, end - start
72 |
73 | def get_silence_starts(input_file: str) -> List[float]:
74 | silence_starts = [0.0]
75 |
76 | reader = (
77 | ffmpeg.input(input_file)
78 | .filter("silencedetect", n=silence_threshold, d=str(silence_duration))
79 | .output("pipe:", format="null")
80 | .run_async(pipe_stderr=True)
81 | )
82 |
83 | silence_end_re = re.compile(
84 | r" silence_end: (?P[0-9]+(\.?[0-9]*)) \| silence_duration: (?P[0-9]+(\.?[0-9]*))"
85 | )
86 |
87 | while True:
88 | line = reader.stderr.readline().decode("utf-8")
89 | if not line:
90 | break
91 |
92 | match = silence_end_re.search(line)
93 | if match:
94 | silence_end = float(match.group("end"))
95 | silence_dur = float(match.group("dur"))
96 | silence_start = silence_end - silence_dur
97 | silence_starts.append(silence_start)
98 |
99 | return silence_starts
100 |
101 | file_extension = os.path.splitext(input_file)[1]
102 | metadata = ffmpeg.probe(input_file)
103 | duration = float(metadata["format"]["duration"])
104 |
105 | silence_starts = get_silence_starts(input_file)
106 | silence_starts.append(duration)
107 |
108 | temp_files = []
109 | lengths = []
110 | current_chunk_start = 0.0
111 |
112 | n = max_chunks
113 | selected_items = find_optimal_breakpoints(silence_starts, n)
114 | selected_items.append(duration)
115 |
116 | for j in range(0, len(selected_items)):
117 | temp_file_path, length = save_chunk_to_temp_file(input_file, current_chunk_start, selected_items[j])
118 | temp_files.append(temp_file_path)
119 | lengths.append(length)
120 |
121 | current_chunk_start = selected_items[j]
122 |
123 | return temp_files, lengths
124 |
125 |
126 | def transcribe_file(file_name, model):
127 | """
128 | Transcribes an audio file with stable_whisper and
129 | returns the transcript with word timestamps.
130 | """
131 | result = model.transcribe(
132 | file_name,
133 | word_timestamps=True,
134 | vad=True,
135 | language="en",
136 | suppress_silence=True,
137 | regroup=False # disable default regrouping logic
138 | )
139 |
140 | if TRANSCRIPT_REFINEMENT:
141 | result = model.refine(
142 | file_name,
143 | result,
144 | precision=0.05,
145 | )
146 |
147 | result = (
148 | result.clamp_max()
149 | .split_by_punctuation([('.', ' '), '。', '?', '?', (',', ' '), ','])
150 | .split_by_gap(.4)
151 | .merge_by_gap(.2, max_words=3)
152 | .split_by_punctuation([('.', ' '), '。', '?', '?'])
153 | )
154 |
155 | file_name_base, _ = os.path.splitext(file_name)
156 | result.save_as_json(file_name_base + "_transcript.json")
157 | return result, file_name
158 |
159 |
160 | def format_seconds_to_hms(seconds):
161 | hours = int(seconds // 3600)
162 | minutes = int((seconds % 3600) // 60)
163 | seconds = seconds % 60
164 | return f"{hours:02}:{minutes:02}:{seconds:04.1f}"
165 |
166 |
167 | # Function to calculate number of chunks based on file size
168 | def calculate_max_chunks(file_path, chunk_size_mb):
169 | file_size_bytes = os.path.getsize(file_path)
170 | file_size_mb = file_size_bytes / MB
171 | return max(1, int(file_size_mb / chunk_size_mb))
172 |
173 |
174 | def transcribe_audio(input_file: str, max_processes = 0,
175 | silence_threshold: str = "-20dB", silence_duration: float = 2.0, model=None) -> List[Tuple[float, float, str]]:
176 | if max_processes > multiprocessing.cpu_count() or max_processes == 0:
177 | max_processes = multiprocessing.cpu_count()
178 |
179 |
180 | # Calculate max chunks based on file size
181 | max_chunks = calculate_max_chunks(input_file, CHUNK_SIZE_MB)
182 |
183 | # Split the audio into chunks
184 | temp_files_array, lengths = split_audio_into_chunks(input_file, max_chunks, silence_threshold, silence_duration)
185 | print(f"Split audio into {len(temp_files_array)} chunks")
186 | start = time.time()
187 | futures = []
188 |
189 | # Submit each file to the thread pool and store the corresponding future object
190 | with concurrent.futures.ThreadPoolExecutor(max_processes) as executor:
191 | for file_path in temp_files_array:
192 | future = executor.submit(transcribe_file, file_path, model)
193 | futures.append(future)
194 |
195 | offset = 0.0
196 | offsets = []
197 | for index, file_path in enumerate(temp_files_array):
198 | offsets.append(offset)
199 | offset += lengths[index]
200 |
201 | sentences = []
202 | for future in futures:
203 | segments, filename = future.result()
204 |
205 | for segment in segments:
206 | if len(segment.words) == 0:
207 | continue
208 |
209 | sentence_text = ""
210 | sentence_start = -1
211 | sentence_end = -1
212 | for segword in segment.words:
213 | if sentence_start == -1:
214 | sentence_start = segword.start
215 | sentence_text += segword.word
216 | sentence_end = segword.end
217 |
218 | file_index = temp_files_array.index(filename)
219 | sentence_start += offsets[file_index]
220 | sentence_end += offsets[file_index]
221 |
222 | if len(sentence_text) > max_text_len:
223 | print(f"Skipping long sentence: {sentence_text}")
224 | continue
225 | sentences.append((sentence_start, sentence_end, sentence_text))
226 |
227 | end = time.time()
228 | print(end - start)
229 |
230 | # Remove the temporary chunk files now that processing is done
231 | for temp_file in temp_files_array:
232 | os.remove(temp_file)
233 |
234 | return sentences
235 |
236 |
237 | def ends_with_sentence_ending(sentence):
238 | return sentence.strip().endswith(('.', '?', '!'))
239 |
240 |
241 | # Function to merge sentences
242 | def merge_sentences(sentences):
243 | merged_sentences = []
244 | temp_sentence = ""
245 | temp_start = None
246 |
247 | for i in range(len(sentences)):
248 | start, end, text = sentences[i]
249 |
250 | if not temp_sentence:
251 | temp_start = start # Set start time for a new group of sentences
252 |
253 | if temp_sentence:
254 | text = temp_sentence + text
255 | temp_sentence = ""
256 |
257 | if not ends_with_sentence_ending(text):
258 | temp_sentence = text
259 | continue
260 |
261 | merged_sentences.append((temp_start, end, text.strip()))
262 |
263 | # Handle the last sentence if it doesn't end with a sentence-ending character
264 | if temp_sentence:
265 | last_start, last_end, _ = sentences[-1]
266 | merged_sentences.append((last_start, last_end, temp_sentence.strip()))
267 |
268 | return merged_sentences
269 |
270 |
271 | def check_transcription_file(audio_file):
272 | """
273 | Check for an existing transcription file for the given audio file.
274 | Returns the path to the transcription file if it exists, None otherwise.
275 | """
276 | base, _ = os.path.splitext(audio_file)
277 | transcription_file = f"{base}_transcription.json"
278 | if os.path.exists(transcription_file):
279 | return transcription_file
280 | return None
281 |
282 |
283 | def load_transcription(transcription_file):
284 | """
285 | Load transcription from a file.
286 | """
287 | with open(transcription_file, 'r') as file:
288 | return json.load(file)
289 |
290 |
291 | def save_transcription(transcription, audio_file):
292 | """
293 | Save transcription to a file.
294 | """
295 | base, _ = os.path.splitext(audio_file)
296 | transcription_file = f"{base}_transcription.json"
297 | with open(transcription_file, 'w') as file:
298 | json.dump(transcription, file)
299 |
300 |
301 | def format_seconds_to_hms_full_seconds(seconds):
302 | hours = int(seconds // 3600)
303 | minutes = int((seconds % 3600) // 60)
304 | seconds = seconds % 60
305 | fraction_seconds = int((seconds % 1) * 10)
306 | seconds = int(seconds % 60)
307 | return f"{hours:02d}{minutes:02d}{seconds:02d}{fraction_seconds}"
308 |
309 |
310 | def sanitize_filename(text):
311 | """
312 | Sanitize the sentence text to be safe for use in file names.
313 | Replace problematic characters with underscores.
314 | """
315 | return re.sub(r"[\\/*?\"<>|:']", "_", text)
316 |
317 |
318 | def save_audio_segment(input_file, start_seconds, end_seconds, sentence_text, output_dir):
319 | """
320 | Save an audio segment from input_file between start and end times.
321 | The output file name is derived from start, end, and sentence_text.
322 | """
323 |
324 | # Convert seconds to formatted string
325 | start_formatted = format_seconds_to_hms_full_seconds(start_seconds)
326 | end_formatted = format_seconds_to_hms_full_seconds(end_seconds)
327 |
328 | # Sanitize the sentence text for file naming
329 | safe_sentence_text = sanitize_filename(sentence_text)
330 | safe_sentence_text = safe_sentence_text[:25]
331 |
332 | # Prepare the output file name
333 | file_name = f"{start_formatted}-{end_formatted}_{len(safe_sentence_text)}_{safe_sentence_text[:15].replace(' ', '_').replace('/', '_')}.mp3"
334 |
335 | output_path = os.path.join(output_dir, file_name)
336 |
337 | # Use ffmpeg library to cut the audio segment
338 | (
339 | ffmpeg
340 | .input(input_file, ss=start_seconds, to=end_seconds)
341 | .output(output_path, c="copy")
342 | .run(overwrite_output=True)
343 | )
344 |
345 | return output_path
346 |
347 |
348 | def preprocess(sentences):
349 | for index, sentence in enumerate(sentences):
350 | start, end, text = sentence
351 | text_before = text
352 | text = multilingual_cleaners(text_before, language)
353 | if text != text_before:
354 | print(f"Preprocessed {text_before} to {text}")
355 | sentences[index] = (start, end, text)
356 |
357 |
358 | if __name__ == "__main__":
359 | # Check for an existing transcription file
360 | model = None
361 |
362 | for input_audio in input_audio_files:
363 | print(f"Processing {input_audio}")
364 |
365 | transcription_file = check_transcription_file(input_audio)
366 | if transcription_file:
367 | # Load transcription from the file
368 | sentences = load_transcription(transcription_file)
369 | else:
370 | # Perform transcription
371 | if model is None:
372 | model = stable_whisper.load_model(whisper_model)
373 |
374 | sentences = transcribe_audio(
375 | input_audio,
376 | max_processes,
377 | silence_threshold="-20dB",
378 | silence_duration=2,
379 | model=model)
380 |
381 | # Merging sentences
382 | sentences = merge_sentences(sentences)
383 |
384 | # Save transcription to a file
385 | save_transcription(sentences, input_audio)
386 |
387 | # Preprocess sentences
388 | # Prepare texts for optimal training
389 | print("Preprocessing sentences texts")
390 |
391 | preprocess(sentences)
392 |
393 | # Remove sentences with zero or negative duration
394 | sentences = [sentence for sentence in sentences if sentence[1] > sentence[0]]
395 |
396 | for index, sentence in enumerate(sentences):
397 | start, end, text = sentence
398 | if end <= start:
399 | print(f"Pretest Skipping {text} ({start}-{end}) due to negative duration")
400 | if index > 0:
401 | _, prev_end, _ = sentences[index - 1]
402 | if start < prev_end:
403 | print(f"Pretest Skipping {text} ({start}-{end}) due to overlap")
404 |
405 | # Filter out sentences with text longer than max_text_len
406 | final_sentences = []
407 | for sentence in sentences:
408 | if len(sentence[2]) > max_text_len:
409 | print(f"Removed: {sentence[2]} (Text too long: length {len(sentence[2])} > max_text_len {max_text_len})")
410 | continue
411 | final_sentences.append(sentence)
412 |
413 | sentences = final_sentences
414 |
415 | # Write sentences to disk
416 | print("Writing sentences to disk")
417 | # Ensure output directory exists
418 | if not os.path.exists(output_directory):
419 | os.makedirs(output_directory)
420 |
421 |
422 | for index, sentence in enumerate(sentences):
423 | start, end, text = sentence
424 |
425 | new_start = start
426 | if index > 0:
427 | _, prev_end, _ = sentences[index - 1]
428 | middle = start - (start - prev_end) / 2
429 | new_start = middle
430 | if start - middle > extend_detected_borders_start:
431 | new_start = start - extend_detected_borders_start
432 |
433 | new_end = end
434 | if index < len(sentences) - 1:
435 | next_start, _, _ = sentences[index + 1]
436 | middle = end + (next_start - end) / 2
437 | new_end = middle
438 | if middle - end > extend_detected_borders_end:
439 | new_end = end + extend_detected_borders_end
440 |
441 | startf = format_seconds_to_hms(new_start)
442 | endf = format_seconds_to_hms(new_end)
443 |
444 | print(f"{startf}-{endf}: {text}")
445 |
446 | if new_end < new_start:
447 | print(f"Skipping {text} ({new_start}-{new_end}) due to negative duration")
448 | continue
449 |
450 | save_audio_segment(input_audio, new_start, new_end, text, output_directory)
451 |
452 |
--------------------------------------------------------------------------------