├── packages.txt ├── requirements.txt ├── km_phonemizer.npz ├── cog.yaml ├── .github └── workflows │ └── build_cog.yml ├── LICENSE ├── losses.py ├── config.json ├── khmer_phonemizer.py ├── .gitattributes ├── README.md ├── infer.py ├── export_onnx.py ├── infer_onnx.py ├── app.py ├── mel_processing.py ├── .gitignore ├── commons.py ├── utils.py ├── g2p.py ├── transforms.py ├── attentions.py ├── modules.py ├── models.py └── wavfile.py /packages.txt: -------------------------------------------------------------------------------- 1 | libsndfile1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.25.2 2 | librosa 3 | scipy 4 | torch 5 | torchvision 6 | Unidecode 7 | gradio 8 | khmernormalizer 9 | monotonic-align -------------------------------------------------------------------------------- /km_phonemizer.npz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b97935ae20ae5533baa2096a48d0728bf90291a488fceccbf7e700018bc38ffe 3 | size 34352402 4 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | gpu: true 3 | system_packages: 4 | - libsndfile1 5 | - ffmpeg 6 | python_packages: 7 | - numpy==1.25.2 8 | - librosa 9 | - scipy 10 | - torch==2.1.1 11 | - Unidecode 12 | - gradio 13 | - khmernormalizer 14 | - monotonic-align -------------------------------------------------------------------------------- /.github/workflows/build_cog.yml: -------------------------------------------------------------------------------- 1 | name: Build Cog 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Free Disk Space (Ubuntu) 14 | uses: jlumbroso/free-disk-space@main 15 | with: 16 | tool-cache: false 17 | android: false 18 | dotnet: false 19 | haskell: false 20 | large-packages: true 21 | docker-images: true 22 | swap-storage: true 23 | 24 | - name: Check out code 25 | uses: actions/checkout@v3 26 | 27 | - name: Setup Cog 28 | uses: replicate/setup-cog@v1 29 | 30 | - name: download weight 31 | run: curl -L https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth -o G_60000.pth 32 | 33 | - name: Build 34 | run: | 35 | cog build -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Seanghay Yath 4 | Copyright (c) 2021 Jaehyeon Kim 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def feature_loss(fmap_r, fmap_g): 4 | loss = 0 5 | for dr, dg in zip(fmap_r, fmap_g): 6 | for rl, gl in zip(dr, dg): 7 | rl = rl.float().detach() 8 | gl = gl.float() 9 | loss += torch.mean(torch.abs(rl - gl)) 10 | 11 | return loss * 2 12 | 13 | 14 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 15 | loss = 0 16 | r_losses = [] 17 | g_losses = [] 18 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 19 | dr = dr.float() 20 | dg = dg.float() 21 | r_loss = torch.mean((1-dr)**2) 22 | g_loss = torch.mean(dg**2) 23 | loss += (r_loss + g_loss) 24 | r_losses.append(r_loss.item()) 25 | g_losses.append(g_loss.item()) 26 | 27 | return loss, r_losses, g_losses 28 | 29 | 30 | def generator_loss(disc_outputs): 31 | loss = 0 32 | gen_losses = [] 33 | for dg in disc_outputs: 34 | dg = dg.float() 35 | l = torch.mean((1-dg)**2) 36 | gen_losses.append(l) 37 | loss += l 38 | 39 | return loss, gen_losses 40 | 41 | 42 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 43 | """ 44 | z_p, logs_q: [b, h, t_t] 45 | m_p, logs_p: [b, h, t_t] 46 | """ 47 | z_p = z_p.float() 48 | logs_q = logs_q.float() 49 | m_p = m_p.float() 50 | logs_p = logs_p.float() 51 | z_mask = z_mask.float() 52 | 53 | kl = logs_p - logs_q - 0.5 54 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. 
* logs_p) 55 | kl = torch.sum(kl * z_mask) 56 | l = kl / torch.sum(z_mask) 57 | return l 58 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 20000, 7 | "learning_rate": 2e-4, 8 | "betas": [0.8, 0.99], 9 | "eps": 1e-9, 10 | "batch_size": 32, 11 | "fp16_run": true, 12 | "lr_decay": 0.999875, 13 | "segment_size": 8192, 14 | "init_lr_ratio": 1, 15 | "warmup_epochs": 0, 16 | "c_mel": 45, 17 | "c_kl": 1.0 18 | }, 19 | "data": { 20 | "training_files":"filelists/kheng_text_train_filelist.txt.cleaned", 21 | "validation_files":"filelists/kheng_text_test_filelist.txt.cleaned", 22 | "text_cleaners":["english_cleaners2"], 23 | "max_wav_value": 32768.0, 24 | "sampling_rate": 22050, 25 | "filter_length": 1024, 26 | "hop_length": 256, 27 | "win_length": 1024, 28 | "n_mel_channels": 80, 29 | "mel_fmin": 0.0, 30 | "mel_fmax": null, 31 | "add_blank": true, 32 | "n_speakers": 0, 33 | "cleaned_text": true 34 | }, 35 | "model": { 36 | "inter_channels": 192, 37 | "hidden_channels": 192, 38 | "filter_channels": 768, 39 | "n_heads": 2, 40 | "n_layers": 6, 41 | "kernel_size": 3, 42 | "p_dropout": 0.1, 43 | "resblock": "1", 44 | "resblock_kernel_sizes": [3,7,11], 45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], 46 | "upsample_rates": [8,8,2,2], 47 | "upsample_initial_channel": 512, 48 | "upsample_kernel_sizes": [16,16,4,4], 49 | "n_layers_q": 3, 50 | "use_spectral_norm": false 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /khmer_phonemizer.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Khmer Phonemizer - A Free, Standalone and Open-Source Khmer Grapheme-to-Phonemes. 3 | """ 4 | import os 5 | import csv 6 | from g2p import PhonetisaurusGraph 7 | 8 | def _read_lexicon_file(file): 9 | lexicon = {} 10 | with open(file) as infile: 11 | for line in csv.reader(infile, delimiter="\t"): 12 | word, phonemes = line 13 | word, phonemes = word.strip(), phonemes.strip().split() 14 | lexicon[word] = phonemes 15 | return lexicon 16 | 17 | _graph_file = os.path.join(os.path.dirname(__file__), "km_phonemizer.npz") 18 | _lexicon_file = os.path.join(os.path.dirname(__file__), "km_lexicon.tsv") 19 | _lexicon_dict = _read_lexicon_file(_lexicon_file) 20 | _graph = PhonetisaurusGraph.load(_graph_file, preload=False) 21 | 22 | def _phoneticize(word: str, beam: int, min_beam: int, beam_scale: float): 23 | results = _graph.g2p_one(word, beam=beam, min_beam=min_beam, beam_scale=beam_scale) 24 | results = list(results) 25 | if len(results) == 0: 26 | return None 27 | return results[0] 28 | 29 | 30 | def phonemize_single( 31 | word, 32 | beam: int = 500, 33 | min_beam: int = 100, 34 | beam_scale: float = 0.6, 35 | use_lexicon: bool = True, 36 | ): 37 | r""" 38 | Phonemize a single word. 
The word must match [a-zA-Z\u1780-\u17dd]+ 39 | """ 40 | if word is None: 41 | return None 42 | word = word.lower() 43 | if use_lexicon and word in _lexicon_dict: 44 | return _lexicon_dict[word] 45 | return _phoneticize(word, beam=beam, min_beam=min_beam, beam_scale=beam_scale) 46 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.7z filter=lfs diff=lfs merge=lfs -text 2 | *.arrow filter=lfs diff=lfs merge=lfs -text 3 | *.bin filter=lfs diff=lfs merge=lfs -text 4 | *.bz2 filter=lfs diff=lfs merge=lfs -text 5 | *.ckpt filter=lfs diff=lfs merge=lfs -text 6 | *.ftz filter=lfs diff=lfs merge=lfs -text 7 | *.gz filter=lfs diff=lfs merge=lfs -text 8 | *.h5 filter=lfs diff=lfs merge=lfs -text 9 | *.joblib filter=lfs diff=lfs merge=lfs -text 10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text 11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text 12 | *.model filter=lfs diff=lfs merge=lfs -text 13 | *.msgpack filter=lfs diff=lfs merge=lfs -text 14 | *.npy filter=lfs diff=lfs merge=lfs -text 15 | *.npz filter=lfs diff=lfs merge=lfs -text 16 | *.onnx filter=lfs diff=lfs merge=lfs -text 17 | *.ot filter=lfs diff=lfs merge=lfs -text 18 | *.parquet filter=lfs diff=lfs merge=lfs -text 19 | *.pb filter=lfs diff=lfs merge=lfs -text 20 | *.pickle filter=lfs diff=lfs merge=lfs -text 21 | *.pkl filter=lfs diff=lfs merge=lfs -text 22 | *.pt filter=lfs diff=lfs merge=lfs -text 23 | *.pth filter=lfs diff=lfs merge=lfs -text 24 | *.rar filter=lfs diff=lfs merge=lfs -text 25 | *.safetensors filter=lfs diff=lfs merge=lfs -text 26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text 27 | *.tar.* filter=lfs diff=lfs merge=lfs -text 28 | *.tar filter=lfs diff=lfs merge=lfs -text 29 | *.tflite filter=lfs diff=lfs merge=lfs -text 30 | *.tgz filter=lfs diff=lfs merge=lfs -text 31 | *.wasm filter=lfs diff=lfs merge=lfs -text 32 | *.xz filter=lfs diff=lfs merge=lfs -text 33 | *.zip filter=lfs diff=lfs merge=lfs -text 34 | *.zst filter=lfs diff=lfs merge=lfs -text 35 | *tfevents* filter=lfs diff=lfs merge=lfs -text 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KLEA 2 | 3 | An open-source Khmer Word to Speech Model. Just single word not sentence! 4 | 5 | 6 | Open In Colab 7 | 8 | 9 | ### 1. Setup 10 | 11 | ```shell 12 | pip install -r requirements.txt 13 | ``` 14 | 15 | ### 2. Download Checkpoint 16 | 17 | [G_60000.pth](https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth) 18 | 19 | ```shell 20 | wget https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth 21 | ``` 22 | 23 | Place the checkpoint in the current directory. 24 | 25 | ### 3. Inference 26 | 27 | ```shell 28 | python infer.py "មនុស្សខ្មែរ" 29 | ``` 30 | 31 | This will output a file called `audio.wav` in the current directory. Output audio sample rate is 22.05 kHz. 32 | 33 | ### Gradio 34 | 35 | ``` 36 | python app.py 37 | ``` 38 | 39 | 40 | ### Colab 41 | 42 | image 43 | 44 | 45 | 46 | ### Dataset 47 | 48 | This model was trained on kheng.info dataset. 
You can find it on http://kheng.info or at https://hf.co/datasets/seanghay/khmer_kheng_info_speech 49 | 50 | ## Reference 51 | 52 | - [VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://github.com/jaywalnut310/vits) 53 | - [kheng.info](https://kheng.info/about/) is an online audio dictionary for the Khmer language with over 3000 recordings. Kheng.info is backed by multiple dictionaries and a large text corpus, and supports search in English and Khmer with search results ordered by word frequency. 54 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | from models import SynthesizerTrn 2 | from scipy.io.wavfile import write 3 | from khmer_phonemizer import phonemize_single 4 | import utils 5 | import commons 6 | import torch 7 | import sys 8 | 9 | _pad = '_' 10 | _punctuation = '. ' 11 | _letters_ipa = 'acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː' 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa) 15 | 16 | 17 | # Special symbol ids 18 | SPACE_ID = symbols.index(" ") 19 | 20 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | 22 | def text_to_sequence(text): 23 | sequence = [] 24 | for symbol in text: 25 | symbol_id = _symbol_to_id[symbol] 26 | sequence += [symbol_id] 27 | return sequence 28 | 29 | 30 | def get_text(text, hps): 31 | text_norm = text_to_sequence(text) 32 | if hps.data.add_blank: 33 | text_norm = commons.intersperse(text_norm, 0) 34 | text_norm = torch.LongTensor(text_norm) 35 | return text_norm 36 | 37 | 38 | hps = utils.get_hparams_from_file("config.json") 39 | net_g = SynthesizerTrn( 40 | len(symbols), 41 | hps.data.filter_length // 2 + 1, 42 | hps.train.segment_size // hps.data.hop_length, 43 | **hps.model 44 | ) 45 | 46 | _ = net_g.eval() 47 | _ = utils.load_checkpoint("G_60000.pth", net_g, None) 48 | 49 | text = " ".join(phonemize_single(sys.argv[1]) + ["."]) 50 | stn_tst = get_text(text, hps) 51 | 52 | with torch.no_grad(): 53 | x_tst = stn_tst.unsqueeze(0) 54 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]) 55 | audio = ( 56 | net_g.infer( 57 | x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1 58 | )[0][0, 0] 59 | .data.cpu() 60 | .float() 61 | .numpy() 62 | ) 63 | write("audio.wav", rate=hps.data.sampling_rate, data=audio) 64 | print("saved audio.wav") -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import torch 3 | import utils 4 | from models import SynthesizerTrn 5 | 6 | _pad = "_" 7 | _punctuation = ". 
" 8 | _letters_ipa = "acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː" 9 | 10 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa) 11 | 12 | hps = utils.get_hparams_from_file("config.json") 13 | net_g = SynthesizerTrn( 14 | len(symbols), 15 | hps.data.filter_length // 2 + 1, 16 | hps.train.segment_size // hps.data.hop_length, 17 | **hps.model 18 | ) 19 | 20 | ckpt = torch.load("./G_60000.pth", map_location="cpu") 21 | net_g.load_state_dict(ckpt["model"]) 22 | net_g.eval() 23 | net_g.dec.remove_weight_norm() 24 | 25 | 26 | def infer_forward(text, text_lengths, scales, sid=None): 27 | noise_scale = scales[0] 28 | length_scale = scales[1] 29 | noise_scale_w = scales[2] 30 | audio = net_g.infer( 31 | text, 32 | text_lengths, 33 | noise_scale=noise_scale, 34 | length_scale=length_scale, 35 | noise_scale_w=noise_scale_w, 36 | sid=sid, 37 | )[0].unsqueeze(1) 38 | return audio 39 | 40 | 41 | net_g.forward = infer_forward 42 | 43 | dummy_input_length = 50 44 | 45 | num_symbols = len(symbols) 46 | sequences = torch.randint( 47 | low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long 48 | ) 49 | sequence_lengths = torch.LongTensor([sequences.size(1)]) 50 | 51 | # noise, noise_w, length 52 | scales = torch.FloatTensor([0.667, 1.0, 0.8]) 53 | dummy_input = (sequences, sequence_lengths, scales, None) 54 | 55 | torch.onnx.export( 56 | model=net_g, 57 | args=dummy_input, 58 | f=str("output.onnx"), 59 | verbose=False, 60 | opset_version=15, 61 | input_names=["input", "input_lengths", "scales", "sid"], 62 | output_names=["output"], 63 | dynamic_axes={ 64 | "input": {0: "batch_size", 1: "phonemes"}, 65 | "input_lengths": {0: "batch_size"}, 66 | "output": {0: "batch_size", 1: "time"}, 67 | }, 68 | ) 69 | -------------------------------------------------------------------------------- /infer_onnx.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import numpy as np 3 | from wavfile import write as write_wav 4 | from utils import get_hparams_from_file 5 | from commons import intersperse 6 | from khmer_phonemizer import phonemize_single 7 | 8 | def audio_float_to_int16( 9 | audio: np.ndarray, max_wav_value: float = 32767.0 10 | ) -> np.ndarray: 11 | """Normalize audio and convert to int16 range""" 12 | audio_norm = audio * (max_wav_value / max(0.01, np.max(np.abs(audio)))) 13 | audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value) 14 | audio_norm = audio_norm.astype("int16") 15 | return audio_norm 16 | 17 | symbols = [ 18 | "_", 19 | ".", 20 | " ", 21 | "a", 22 | "c", 23 | "e", 24 | "f", 25 | "h", 26 | "i", 27 | "j", 28 | "k", 29 | "l", 30 | "m", 31 | "n", 32 | "o", 33 | "p", 34 | "r", 35 | "s", 36 | "t", 37 | "u", 38 | "w", 39 | "z", 40 | "ĕ", 41 | "ŋ", 42 | "ŏ", 43 | "ŭ", 44 | "ɑ", 45 | "ɓ", 46 | "ɔ", 47 | "ɗ", 48 | "ə", 49 | "ɛ", 50 | "ɡ", 51 | "ɨ", 52 | "ɲ", 53 | "ʋ", 54 | "ʔ", 55 | "ʰ", 56 | "ː", 57 | ] 58 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 59 | 60 | 61 | def text_to_sequence(text): 62 | sequence = [] 63 | for symbol in text: 64 | symbol_id = symbol_to_id[symbol] 65 | sequence += [symbol_id] 66 | return sequence 67 | 68 | 69 | def get_text(text, hps): 70 | text_norm = text_to_sequence(text) 71 | if hps.data.add_blank: 72 | text_norm = intersperse(text_norm, 0) 73 | return text_norm 74 | 75 | 76 | def infer(): 77 | session_options = onnxruntime.SessionOptions() 78 | providers = ["CPUExecutionProvider"] 79 | model = onnxruntime.InferenceSession( 80 | "./output.onnx", sess_options=session_options, 
providers=providers 81 | ) 82 | 83 | hps = get_hparams_from_file("config.json") 84 | text = " ".join(phonemize_single("ទិញបាយ") + ["."]) 85 | stn_tst = get_text(text, hps) 86 | 87 | text = np.expand_dims(np.array(stn_tst, dtype=np.int64), 0) 88 | text_lengths = np.array([text.shape[1]], dtype=np.int64) 89 | scales = np.array( 90 | [0.667, 1, 0.8], 91 | dtype=np.float32, 92 | ) 93 | sample_rate = 22050 94 | sid = None 95 | audio = model.run( 96 | None, 97 | { 98 | "input": text, 99 | "input_lengths": text_lengths, 100 | "scales": scales, 101 | "sid": sid, 102 | }, 103 | )[0].squeeze((0, 1)) 104 | audio = audio_float_to_int16(audio.squeeze()) 105 | write_wav("audio.wav", sample_rate, audio) 106 | 107 | 108 | if __name__ == "__main__": 109 | infer() 110 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import gradio as gr 3 | from models import SynthesizerTrn 4 | from khmer_phonemizer import phonemize_single 5 | import utils 6 | import commons 7 | import torch 8 | import khmernormalizer 9 | 10 | _pad = "_" 11 | _punctuation = ". " 12 | _letters_ipa = "acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː" 13 | 14 | # Export all symbols: 15 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa) 16 | 17 | # Special symbol ids 18 | SPACE_ID = symbols.index(" ") 19 | 20 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | 22 | 23 | def text_to_sequence(text): 24 | sequence = [] 25 | for symbol in text: 26 | symbol_id = _symbol_to_id[symbol] 27 | sequence += [symbol_id] 28 | return sequence 29 | 30 | 31 | def get_text(text, hps): 32 | text_norm = text_to_sequence(text) 33 | 34 | if hps.data.add_blank: 35 | text_norm = commons.intersperse(text_norm, 0) 36 | text_norm = torch.LongTensor(text_norm) 37 | return text_norm 38 | 39 | 40 | hps = utils.get_hparams_from_file("config.json") 41 | net_g = SynthesizerTrn( 42 | len(symbols), 43 | hps.data.filter_length // 2 + 1, 44 | hps.train.segment_size // hps.data.hop_length, 45 | **hps.model 46 | ) 47 | 48 | _ = net_g.eval() 49 | _ = utils.load_checkpoint("G_60000.pth", net_g, None) 50 | 51 | def generate_voice(text): 52 | text = khmernormalizer.normalize(text) 53 | text = " ".join(phonemize_single(text) + ["."]) 54 | stn_tst = get_text(text, hps) 55 | with torch.no_grad(): 56 | x_tst = stn_tst.unsqueeze(0) 57 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)]) 58 | audio = ( 59 | net_g.infer( 60 | x_tst, 61 | x_tst_lengths, 62 | noise_scale=0.667, 63 | noise_scale_w=0.8, 64 | length_scale=1, 65 | )[0][0, 0] 66 | .data.cpu() 67 | .float() 68 | .numpy() 69 | ) 70 | 71 | return (hps.data.sampling_rate, audio) 72 | 73 | 74 | with gr.Blocks( 75 | title="Khmer Word to Speech", 76 | theme=gr.themes.Default( 77 | font=[gr.themes.GoogleFont("Noto Sans Khmer"), "Arial", "sans-serif"] 78 | ), 79 | ) as blocks: 80 | gr.Markdown("# Khmer Word to Speech") 81 | 82 | input_text = gr.Text(label="ពាក្យខ្លី", lines=1) 83 | examples = gr.Examples(examples=["មនុស្សជាតិ", "ភ្នំព្រះ"], inputs=[input_text]) 84 | run_button = gr.Button(value="បង្កើត") 85 | 86 | out_audio = gr.Audio( 87 | label="សំឡេងដែលបានបង្កើត", 88 | type="numpy", 89 | ) 90 | 91 | inputs = [input_text] 92 | outputs = [out_audio] 93 | 94 | run_button.click( 95 | fn=generate_voice, 96 | inputs=inputs, 97 | outputs=outputs, 98 | queue=True, 99 | ) 100 | 101 | 102 | blocks.queue(concurrency_count=1).launch(debug=True) 103 | 
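
A minimal usage sketch, for illustration only: it assumes the repo modules above (`khmer_phonemizer.py`, `commons.py`) are importable from the current directory and mirrors how `infer.py` / `app.py` turn a Khmer word into the integer sequence passed to the synthesizer; the example word is hypothetical.

```python
from khmer_phonemizer import phonemize_single
from commons import intersperse

# Same symbol table as infer.py / app.py / export_onnx.py
_pad, _punctuation = "_", ". "
_letters_ipa = "acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː"
symbols = [_pad] + list(_punctuation) + list(_letters_ipa)
symbol_to_id = {s: i for i, s in enumerate(symbols)}

phonemes = phonemize_single("មនុស្ស")          # list of IPA phoneme strings (example word)
text = " ".join(phonemes + ["."])              # same formatting as infer.py / app.py
ids = [symbol_to_id[s] for s in text]          # text_to_sequence
ids = intersperse(ids, 0)                      # blank tokens, since add_blank is true in config.json
print(text, ids)
```

The resulting `ids` list is what `get_text` wraps in a `torch.LongTensor` before calling `net_g.infer` (or what `infer_onnx.py` feeds to the ONNX session as `input`).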
-------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 8 | """ 9 | PARAMS 10 | ------ 11 | C: compression factor 12 | """ 13 | return torch.log(torch.clamp(x, min=clip_val) * C) 14 | 15 | 16 | def dynamic_range_decompression_torch(x, C=1): 17 | """ 18 | PARAMS 19 | ------ 20 | C: compression factor used to compress 21 | """ 22 | return torch.exp(x) / C 23 | 24 | 25 | def spectral_normalize_torch(magnitudes): 26 | output = dynamic_range_compression_torch(magnitudes) 27 | return output 28 | 29 | 30 | def spectral_de_normalize_torch(magnitudes): 31 | output = dynamic_range_decompression_torch(magnitudes) 32 | return output 33 | 34 | 35 | mel_basis = {} 36 | hann_window = {} 37 | 38 | 39 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 40 | if torch.min(y) < -1.: 41 | print('min value is ', torch.min(y)) 42 | if torch.max(y) > 1.: 43 | print('max value is ', torch.max(y)) 44 | 45 | global hann_window 46 | dtype_device = str(y.dtype) + '_' + str(y.device) 47 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 48 | if wnsize_dtype_device not in hann_window: 49 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 50 | 51 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 52 | y = y.squeeze(1) 53 | 54 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 55 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 56 | 57 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 58 | return spec 59 | 60 | 61 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 62 | global mel_basis 63 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 64 | fmax_dtype_device = str(fmax) + '_' + dtype_device 65 | if fmax_dtype_device not in mel_basis: 66 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 67 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 68 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 69 | spec = spectral_normalize_torch(spec) 70 | return spec 71 | 72 | 73 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 74 | if torch.min(y) < -1.: 75 | print('min value is ', torch.min(y)) 76 | if torch.max(y) > 1.: 77 | print('max value is ', torch.max(y)) 78 | 79 | global mel_basis, hann_window 80 | dtype_device = str(y.dtype) + '_' + str(y.device) 81 | fmax_dtype_device = str(fmax) + '_' + dtype_device 82 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 83 | if fmax_dtype_device not in mel_basis: 84 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 85 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 86 | if wnsize_dtype_device not in hann_window: 87 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 88 | 89 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 90 | y = y.squeeze(1) 91 | 92 | spec = torch.stft(y, n_fft, 
hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 93 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 94 | 95 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 96 | 97 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 98 | spec = spectral_normalize_torch(spec) 99 | 100 | return spec 101 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # General 2 | .DS_Store 3 | .AppleDouble 4 | .LSOverride 5 | 6 | # Icon must end with two \r 7 | Icon 8 | 9 | 10 | # Thumbnails 11 | ._* 12 | 13 | # Files that might appear in the root of a volume 14 | .DocumentRevisions-V100 15 | .fseventsd 16 | .Spotlight-V100 17 | .TemporaryItems 18 | .Trashes 19 | .VolumeIcon.icns 20 | .com.apple.timemachine.donotpresent 21 | 22 | # Directories potentially created on remote AFP share 23 | .AppleDB 24 | .AppleDesktop 25 | Network Trash Folder 26 | Temporary Items 27 | .apdisk 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | share/python-wheels/ 52 | *.egg-info/ 53 | .installed.cfg 54 | *.egg 55 | MANIFEST 56 | 57 | # PyInstaller 58 | # Usually these files are written by a python script from a template 59 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 60 | *.manifest 61 | *.spec 62 | 63 | # Installer logs 64 | pip-log.txt 65 | pip-delete-this-directory.txt 66 | 67 | # Unit test / coverage reports 68 | htmlcov/ 69 | .tox/ 70 | .nox/ 71 | .coverage 72 | .coverage.* 73 | .cache 74 | nosetests.xml 75 | coverage.xml 76 | *.cover 77 | *.py,cover 78 | .hypothesis/ 79 | .pytest_cache/ 80 | cover/ 81 | 82 | # Translations 83 | *.mo 84 | *.pot 85 | 86 | # Django stuff: 87 | *.log 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | # .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # poetry 126 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 127 | # This is especially recommended for binary packages to ensure reproducibility, and is more 128 | # commonly ignored for libraries. 
129 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 130 | #poetry.lock 131 | 132 | # pdm 133 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 134 | #pdm.lock 135 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 136 | # in version control. 137 | # https://pdm.fming.dev/#use-with-ide 138 | .pdm.toml 139 | 140 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 141 | __pypackages__/ 142 | 143 | # Celery stuff 144 | celerybeat-schedule 145 | celerybeat.pid 146 | 147 | # SageMath parsed files 148 | *.sage.py 149 | 150 | # Environments 151 | .env 152 | .venv 153 | env/ 154 | venv/ 155 | ENV/ 156 | env.bak/ 157 | venv.bak/ 158 | 159 | # Spyder project settings 160 | .spyderproject 161 | .spyproject 162 | 163 | # Rope project settings 164 | .ropeproject 165 | 166 | # mkdocs documentation 167 | /site 168 | 169 | # mypy 170 | .mypy_cache/ 171 | .dmypy.json 172 | dmypy.json 173 | 174 | # Pyre type checker 175 | .pyre/ 176 | 177 | # pytype static type analyzer 178 | .pytype/ 179 | 180 | # Cython debug symbols 181 | cython_debug/ 182 | 183 | # PyCharm 184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 186 | # and can be added to the global gitignore or merged into this file. For a more nuclear 187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 188 | #.idea/ 189 | *.wav 190 | *.pth 191 | *.onnx -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | def init_weights(m, mean=0.0, std=0.01): 6 | classname = m.__class__.__name__ 7 | if classname.find("Conv") != -1: 8 | m.weight.data.normal_(mean, std) 9 | 10 | 11 | def get_padding(kernel_size, dilation=1): 12 | return int((kernel_size*dilation - dilation)/2) 13 | 14 | 15 | def convert_pad_shape(pad_shape): 16 | l = pad_shape[::-1] 17 | pad_shape = [item for sublist in l for item in sublist] 18 | return pad_shape 19 | 20 | 21 | def intersperse(lst, item): 22 | result = [item] * (len(lst) * 2 + 1) 23 | result[1::2] = lst 24 | return result 25 | 26 | 27 | def kl_divergence(m_p, logs_p, m_q, logs_q): 28 | """KL(P||Q)""" 29 | kl = (logs_q - logs_p) - 0.5 30 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 31 | return kl 32 | 33 | 34 | def rand_gumbel(shape): 35 | """Sample from the Gumbel distribution, protect from overflows.""" 36 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 37 | return -torch.log(-torch.log(uniform_samples)) 38 | 39 | 40 | def rand_gumbel_like(x): 41 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 42 | return g 43 | 44 | 45 | def slice_segments(x, ids_str, segment_size=4): 46 | ret = torch.zeros_like(x[:, :, :segment_size]) 47 | for i in range(x.size(0)): 48 | idx_str = ids_str[i] 49 | idx_end = idx_str + segment_size 50 | ret[i] = x[i, :, idx_str:idx_end] 51 | return ret 52 | 53 | 54 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 55 | b, d, t = x.size() 56 | if x_lengths is None: 57 | x_lengths = t 58 | ids_str_max = x_lengths - segment_size + 1 59 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 60 | ret = slice_segments(x, ids_str, segment_size) 61 | return ret, ids_str 62 | 63 | 64 | def get_timing_signal_1d( 65 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 66 | position = torch.arange(length, dtype=torch.float) 67 | num_timescales = channels // 2 68 | log_timescale_increment = ( 69 | math.log(float(max_timescale) / float(min_timescale)) / 70 | (num_timescales - 1)) 71 | inv_timescales = min_timescale * torch.exp( 72 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 73 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 74 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 75 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 76 | signal = signal.view(1, channels, length) 77 | return signal 78 | 79 | 80 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 81 | b, channels, length = x.size() 82 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 83 | return x + signal.to(dtype=x.dtype, device=x.device) 84 | 85 | 86 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 87 | b, channels, length = x.size() 88 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 89 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 90 | 91 | 92 | def subsequent_mask(length): 93 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 94 | return mask 95 | 96 | 97 | @torch.jit.script 98 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 99 | n_channels_int = n_channels[0] 100 | in_act = input_a + input_b 101 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 102 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 103 | acts = t_act * s_act 104 | return acts 105 | 106 | 107 | def convert_pad_shape(pad_shape): 108 | l = pad_shape[::-1] 109 | pad_shape = [item for sublist in l for item in sublist] 110 | return pad_shape 111 | 112 | 113 | def shift_1d(x): 114 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 115 | return x 116 | 117 | 118 | def sequence_mask(length, max_length=None): 119 | if max_length is None: 120 | max_length = length.max() 121 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 122 | return x.unsqueeze(0) < length.unsqueeze(1) 123 | 124 | 125 | def generate_path(duration, mask): 126 | """ 127 | duration: [b, 1, t_x] 128 | mask: [b, 1, t_y, t_x] 129 | """ 130 | device = duration.device 131 | 132 | b, _, t_y, t_x = mask.shape 133 | cum_duration = torch.cumsum(duration, -1) 134 | 135 | cum_duration_flat = 
cum_duration.view(b * t_x) 136 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 137 | path = path.view(b, t_x, t_y) 138 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 139 | path = path.unsqueeze(1).transpose(2,3) * mask 140 | return path 141 | 142 | 143 | def clip_grad_value_(parameters, clip_value, norm_type=2): 144 | if isinstance(parameters, torch.Tensor): 145 | parameters = [parameters] 146 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 147 | norm_type = float(norm_type) 148 | if clip_value is not None: 149 | clip_value = float(clip_value) 150 | 151 | total_norm = 0 152 | for p in parameters: 153 | param_norm = p.grad.data.norm(norm_type) 154 | total_norm += param_norm.item() ** norm_type 155 | if clip_value is not None: 156 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 157 | total_norm = total_norm ** (1. / norm_type) 158 | return total_norm 159 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | import numpy as np 9 | from scipy.io.wavfile import read 10 | import torch 11 | 12 | MATPLOTLIB_FLAG = False 13 | 14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 15 | logger = logging 16 | 17 | 18 | def load_checkpoint(checkpoint_path, model, optimizer=None): 19 | assert os.path.isfile(checkpoint_path) 20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 21 | iteration = checkpoint_dict['iteration'] 22 | learning_rate = checkpoint_dict['learning_rate'] 23 | if optimizer is not None: 24 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 25 | saved_state_dict = checkpoint_dict['model'] 26 | if hasattr(model, 'module'): 27 | state_dict = model.module.state_dict() 28 | else: 29 | state_dict = model.state_dict() 30 | new_state_dict= {} 31 | for k, v in state_dict.items(): 32 | try: 33 | new_state_dict[k] = saved_state_dict[k] 34 | except: 35 | logger.info("%s is not in the checkpoint" % k) 36 | new_state_dict[k] = v 37 | if hasattr(model, 'module'): 38 | model.module.load_state_dict(new_state_dict) 39 | else: 40 | model.load_state_dict(new_state_dict) 41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 42 | checkpoint_path, iteration)) 43 | return model, optimizer, learning_rate, iteration 44 | 45 | 46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 47 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 48 | iteration, checkpoint_path)) 49 | if hasattr(model, 'module'): 50 | state_dict = model.module.state_dict() 51 | else: 52 | state_dict = model.state_dict() 53 | torch.save({'model': state_dict, 54 | 'iteration': iteration, 55 | 'optimizer': optimizer.state_dict(), 56 | 'learning_rate': learning_rate}, checkpoint_path) 57 | 58 | 59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 60 | for k, v in scalars.items(): 61 | writer.add_scalar(k, v, global_step) 62 | for k, v in histograms.items(): 63 | writer.add_histogram(k, v, global_step) 64 | for k, v in images.items(): 65 | writer.add_image(k, v, global_step, dataformats='HWC') 66 | for k, v in audios.items(): 67 | writer.add_audio(k, v, global_step, audio_sampling_rate) 68 | 69 | 70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 71 | 
f_list = glob.glob(os.path.join(dir_path, regex)) 72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 73 | x = f_list[-1] 74 | print(x) 75 | return x 76 | 77 | def load_wav_to_torch(full_path): 78 | sampling_rate, data = read(full_path) 79 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate 80 | 81 | 82 | def load_filepaths_and_text(filename, split="|"): 83 | with open(filename, encoding='utf-8') as f: 84 | filepaths_and_text = [line.strip().split(split) for line in f] 85 | return filepaths_and_text 86 | 87 | 88 | def get_hparams(init=True): 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 91 | help='JSON file for configuration') 92 | parser.add_argument('-m', '--model', type=str, required=True, 93 | help='Model name') 94 | 95 | args = parser.parse_args() 96 | model_dir = os.path.join("./logs", args.model) 97 | 98 | if not os.path.exists(model_dir): 99 | os.makedirs(model_dir) 100 | 101 | config_path = args.config 102 | config_save_path = os.path.join(model_dir, "config.json") 103 | if init: 104 | with open(config_path, "r") as f: 105 | data = f.read() 106 | with open(config_save_path, "w") as f: 107 | f.write(data) 108 | else: 109 | with open(config_save_path, "r") as f: 110 | data = f.read() 111 | config = json.loads(data) 112 | 113 | hparams = HParams(**config) 114 | hparams.model_dir = model_dir 115 | return hparams 116 | 117 | 118 | def get_hparams_from_dir(model_dir): 119 | config_save_path = os.path.join(model_dir, "config.json") 120 | with open(config_save_path, "r") as f: 121 | data = f.read() 122 | config = json.loads(data) 123 | 124 | hparams =HParams(**config) 125 | hparams.model_dir = model_dir 126 | return hparams 127 | 128 | 129 | def get_hparams_from_file(config_path): 130 | with open(config_path, "r") as f: 131 | data = f.read() 132 | config = json.loads(data) 133 | 134 | hparams =HParams(**config) 135 | return hparams 136 | 137 | 138 | def check_git_hash(model_dir): 139 | source_dir = os.path.dirname(os.path.realpath(__file__)) 140 | if not os.path.exists(os.path.join(source_dir, ".git")): 141 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 142 | source_dir 143 | )) 144 | return 145 | 146 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 147 | 148 | path = os.path.join(model_dir, "githash") 149 | if os.path.exists(path): 150 | saved_hash = open(path).read() 151 | if saved_hash != cur_hash: 152 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 153 | saved_hash[:8], cur_hash[:8])) 154 | else: 155 | open(path, "w").write(cur_hash) 156 | 157 | 158 | def get_logger(model_dir, filename="train.log"): 159 | global logger 160 | logger = logging.getLogger(os.path.basename(model_dir)) 161 | logger.setLevel(logging.DEBUG) 162 | 163 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 164 | if not os.path.exists(model_dir): 165 | os.makedirs(model_dir) 166 | h = logging.FileHandler(os.path.join(model_dir, filename)) 167 | h.setLevel(logging.DEBUG) 168 | h.setFormatter(formatter) 169 | logger.addHandler(h) 170 | return logger 171 | 172 | 173 | class HParams(): 174 | def __init__(self, **kwargs): 175 | for k, v in kwargs.items(): 176 | if type(v) == dict: 177 | v = HParams(**v) 178 | self[k] = v 179 | 180 | def keys(self): 181 | return self.__dict__.keys() 182 | 183 | def items(self): 184 | return self.__dict__.items() 185 | 186 | def values(self): 187 | return self.__dict__.values() 188 | 189 | def __len__(self): 190 | return len(self.__dict__) 191 | 192 | def __getitem__(self, key): 193 | return getattr(self, key) 194 | 195 | def __setitem__(self, key, value): 196 | return setattr(self, key, value) 197 | 198 | def __contains__(self, key): 199 | return key in self.__dict__ 200 | 201 | def __repr__(self): 202 | return self.__dict__.__repr__() 203 | -------------------------------------------------------------------------------- /g2p.py: -------------------------------------------------------------------------------- 1 | """ 2 | Guess word pronunciations using a Phonetisaurus FST 3 | 4 | See bin/fst2npz.py to convert an FST to a numpy graph. 5 | 6 | Reference: 7 | https://github.com/rhasspy/gruut/blob/master/gruut/g2p_phonetisaurus.py 8 | """ 9 | import typing 10 | from collections import defaultdict 11 | from pathlib import Path 12 | import numpy as np 13 | 14 | NUMPY_GRAPH = typing.Dict[str, np.ndarray] 15 | _NOT_FINAL = object() 16 | 17 | class PhonetisaurusGraph: 18 | """Graph of numpy arrays that represents a Phonetisaurus FST 19 | 20 | Also contains shared cache of edges and final state probabilities. 21 | These caches are necessary to ensure that the .npz file stays small and fast 22 | to load. 
23 | """ 24 | 25 | def __init__(self, graph: NUMPY_GRAPH, preload: bool = False): 26 | self.graph = graph 27 | 28 | self.start_node = int(self.graph["start_node"].item()) 29 | 30 | # edge_index -> (from_node, to_node, ilabel, olabel) 31 | self.edges = self.graph["edges"] 32 | self.edge_probs = self.graph["edge_probs"] 33 | 34 | # int -> [str] 35 | self.symbols = [] 36 | for symbol_str in self.graph["symbols"]: 37 | symbol_list = symbol_str.replace("_", "").split("|") 38 | self.symbols.append((len(symbol_list), symbol_list)) 39 | 40 | # nodes that are accepting states 41 | self.final_nodes = self.graph["final_nodes"] 42 | 43 | # node -> probability 44 | self.final_probs = self.graph["final_probs"] 45 | 46 | # Cache 47 | self.preloaded = preload 48 | self.out_edges: typing.Dict[int, typing.List[int]] = defaultdict(list) 49 | self.final_node_probs: typing.Dict[int, typing.Any] = {} 50 | 51 | if preload: 52 | # Load out edges 53 | for edge_idx, (from_node, *_) in enumerate(self.edges): 54 | self.out_edges[from_node].append(edge_idx) 55 | 56 | # Load final probabilities 57 | self.final_node_probs.update(zip(self.final_nodes, self.final_probs)) 58 | 59 | @staticmethod 60 | def load(graph_path: typing.Union[str, Path], **kwargs) -> "PhonetisaurusGraph": 61 | """Load .npz file with numpy graph""" 62 | np_graph = np.load(graph_path, allow_pickle=True) 63 | return PhonetisaurusGraph(np_graph, **kwargs) 64 | 65 | def g2p_one( 66 | self, 67 | word: typing.Union[str, typing.Sequence[str]], 68 | eps: str = "", 69 | beam: int = 5000, 70 | min_beam: int = 100, 71 | beam_scale: float = 0.6, 72 | grapheme_separator: str = "", 73 | max_guesses: int = 1, 74 | ) -> typing.Iterable[typing.Tuple[typing.Sequence[str], typing.Sequence[str]]]: 75 | """Guess phonemes for word""" 76 | current_beam = beam 77 | graphemes: typing.Sequence[str] = [] 78 | 79 | if isinstance(word, str): 80 | word = word.strip() 81 | 82 | if grapheme_separator: 83 | graphemes = word.split(grapheme_separator) 84 | else: 85 | graphemes = list(word) 86 | else: 87 | graphemes = word 88 | 89 | if not graphemes: 90 | return [] 91 | 92 | # (prob, node, graphemes, phonemes, final, beam) 93 | q: typing.List[ 94 | typing.Tuple[ 95 | float, 96 | typing.Optional[int], 97 | typing.Sequence[str], 98 | typing.List[str], 99 | bool, 100 | ] 101 | ] = [(0.0, self.start_node, graphemes, [], False)] 102 | 103 | q_next: typing.List[ 104 | typing.Tuple[ 105 | float, 106 | typing.Optional[int], 107 | typing.Sequence[str], 108 | typing.List[str], 109 | bool, 110 | ] 111 | ] = [] 112 | 113 | # (prob, phonemes) 114 | best_heap: typing.List[typing.Tuple[float, typing.Sequence[str]]] = [] 115 | 116 | # Avoid duplicate guesses 117 | guessed_phonemes: typing.Set[typing.Tuple[str, ...]] = set() 118 | 119 | while q: 120 | done_with_word = False 121 | q_next = [] 122 | 123 | for prob, node, next_graphemes, output, is_final in q: 124 | if is_final: 125 | # Complete guess 126 | phonemes = tuple(output) 127 | if phonemes not in guessed_phonemes: 128 | best_heap.append((prob, phonemes)) 129 | guessed_phonemes.add(phonemes) 130 | 131 | if len(best_heap) >= max_guesses: 132 | done_with_word = True 133 | break 134 | 135 | continue 136 | 137 | assert node is not None 138 | 139 | if not next_graphemes: 140 | if self.preloaded: 141 | final_prob = self.final_node_probs.get(node, _NOT_FINAL) 142 | else: 143 | final_prob = self.final_node_probs.get(node) 144 | if final_prob is None: 145 | final_idx = int(np.searchsorted(self.final_nodes, node)) 146 | if self.final_nodes[final_idx] == 
node: 147 | # Cache 148 | final_prob = float(self.final_probs[final_idx]) 149 | self.final_node_probs[node] = final_prob 150 | else: 151 | # Not a final state 152 | final_prob = _NOT_FINAL 153 | self.final_node_probs[node] = final_prob 154 | 155 | if final_prob != _NOT_FINAL: 156 | final_prob = typing.cast(float, final_prob) 157 | q_next.append((prob + final_prob, None, [], output, True)) 158 | 159 | len_next_graphemes = len(next_graphemes) 160 | if self.preloaded: 161 | # Was pre-loaded in __init__ 162 | edge_idxs = self.out_edges[node] 163 | else: 164 | # Build cache during search 165 | maybe_edge_idxs = self.out_edges.get(node) 166 | if maybe_edge_idxs is None: 167 | edge_idx = int(np.searchsorted(self.edges[:, 0], node)) 168 | edge_idxs = [] 169 | while self.edges[edge_idx][0] == node: 170 | edge_idxs.append(edge_idx) 171 | edge_idx += 1 172 | 173 | # Cache 174 | self.out_edges[node] = edge_idxs 175 | else: 176 | edge_idxs = maybe_edge_idxs 177 | 178 | for edge_idx in edge_idxs: 179 | _, to_node, ilabel_idx, olabel_idx = self.edges[edge_idx] 180 | out_prob = self.edge_probs[edge_idx] 181 | 182 | len_igraphemes, igraphemes = self.symbols[ilabel_idx] 183 | 184 | if len_igraphemes > len_next_graphemes: 185 | continue 186 | 187 | if igraphemes == [eps]: 188 | item = (prob + out_prob, to_node, next_graphemes, output, False) 189 | q_next.append(item) 190 | else: 191 | sub_graphemes = next_graphemes[:len_igraphemes] 192 | if igraphemes == sub_graphemes: 193 | _, olabel = self.symbols[olabel_idx] 194 | item = ( 195 | prob + out_prob, 196 | to_node, 197 | next_graphemes[len(sub_graphemes) :], 198 | output + olabel, 199 | False, 200 | ) 201 | q_next.append(item) 202 | 203 | if done_with_word: 204 | break 205 | 206 | q_next = sorted(q_next, key=lambda item: item[0])[:current_beam] 207 | q = q_next 208 | 209 | current_beam = max(min_beam, (int(current_beam * beam_scale))) 210 | 211 | # Yield guesses 212 | if best_heap: 213 | for _, guess_phonemes in sorted(best_heap, key=lambda item: item[0])[ 214 | :max_guesses 215 | ]: 216 | yield [p for p in guess_phonemes if p] 217 | else: 218 | # No guesses 219 | yield [] -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, 
eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, 
inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import commons 6 | import modules 7 | from modules import LayerNorm 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 12 | super().__init__() 13 | self.hidden_channels = hidden_channels 14 | self.filter_channels = filter_channels 15 | self.n_heads = n_heads 16 | self.n_layers = n_layers 17 | self.kernel_size = kernel_size 18 | self.p_dropout = p_dropout 19 | self.window_size = window_size 20 | 21 | self.drop = nn.Dropout(p_dropout) 22 | self.attn_layers = nn.ModuleList() 23 | self.norm_layers_1 = nn.ModuleList() 24 | self.ffn_layers = nn.ModuleList() 25 | self.norm_layers_2 = nn.ModuleList() 26 | for i in range(self.n_layers): 27 | 
self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 28 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 29 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 30 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 31 | 32 | def forward(self, x, x_mask): 33 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 34 | x = x * x_mask 35 | for i in range(self.n_layers): 36 | y = self.attn_layers[i](x, x, attn_mask) 37 | y = self.drop(y) 38 | x = self.norm_layers_1[i](x + y) 39 | 40 | y = self.ffn_layers[i](x, x_mask) 41 | y = self.drop(y) 42 | x = self.norm_layers_2[i](x + y) 43 | x = x * x_mask 44 | return x 45 | 46 | 47 | class Decoder(nn.Module): 48 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 49 | super().__init__() 50 | self.hidden_channels = hidden_channels 51 | self.filter_channels = filter_channels 52 | self.n_heads = n_heads 53 | self.n_layers = n_layers 54 | self.kernel_size = kernel_size 55 | self.p_dropout = p_dropout 56 | self.proximal_bias = proximal_bias 57 | self.proximal_init = proximal_init 58 | 59 | self.drop = nn.Dropout(p_dropout) 60 | self.self_attn_layers = nn.ModuleList() 61 | self.norm_layers_0 = nn.ModuleList() 62 | self.encdec_attn_layers = nn.ModuleList() 63 | self.norm_layers_1 = nn.ModuleList() 64 | self.ffn_layers = nn.ModuleList() 65 | self.norm_layers_2 = nn.ModuleList() 66 | for i in range(self.n_layers): 67 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 68 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 69 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 70 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 71 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 72 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 73 | 74 | def forward(self, x, x_mask, h, h_mask): 75 | """ 76 | x: decoder input 77 | h: encoder output 78 | """ 79 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 80 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 81 | x = x * x_mask 82 | for i in range(self.n_layers): 83 | y = self.self_attn_layers[i](x, x, self_attn_mask) 84 | y = self.drop(y) 85 | x = self.norm_layers_0[i](x + y) 86 | 87 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 88 | y = self.drop(y) 89 | x = self.norm_layers_1[i](x + y) 90 | 91 | y = self.ffn_layers[i](x, x_mask) 92 | y = self.drop(y) 93 | x = self.norm_layers_2[i](x + y) 94 | x = x * x_mask 95 | return x 96 | 97 | 98 | class MultiHeadAttention(nn.Module): 99 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 100 | super().__init__() 101 | assert channels % n_heads == 0 102 | 103 | self.channels = channels 104 | self.out_channels = out_channels 105 | self.n_heads = n_heads 106 | self.p_dropout = p_dropout 107 | self.window_size = window_size 108 | self.heads_share = heads_share 109 | self.block_length = block_length 110 | self.proximal_bias = proximal_bias 111 | self.proximal_init = 
proximal_init 112 | self.attn = None 113 | 114 | self.k_channels = channels // n_heads 115 | self.conv_q = nn.Conv1d(channels, channels, 1) 116 | self.conv_k = nn.Conv1d(channels, channels, 1) 117 | self.conv_v = nn.Conv1d(channels, channels, 1) 118 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 119 | self.drop = nn.Dropout(p_dropout) 120 | 121 | if window_size is not None: 122 | n_heads_rel = 1 if heads_share else n_heads 123 | rel_stddev = self.k_channels**-0.5 124 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 125 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 126 | 127 | nn.init.xavier_uniform_(self.conv_q.weight) 128 | nn.init.xavier_uniform_(self.conv_k.weight) 129 | nn.init.xavier_uniform_(self.conv_v.weight) 130 | if proximal_init: 131 | with torch.no_grad(): 132 | self.conv_k.weight.copy_(self.conv_q.weight) 133 | self.conv_k.bias.copy_(self.conv_q.bias) 134 | 135 | def forward(self, x, c, attn_mask=None): 136 | q = self.conv_q(x) 137 | k = self.conv_k(c) 138 | v = self.conv_v(c) 139 | 140 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 141 | 142 | x = self.conv_o(x) 143 | return x 144 | 145 | def attention(self, query, key, value, mask=None): 146 | # reshape [b, d, t] -> [b, n_h, t, d_k] 147 | b, d, t_s, t_t = (*key.size(), query.size(2)) 148 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 149 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 150 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 151 | 152 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 153 | if self.window_size is not None: 154 | assert t_s == t_t, "Relative attention is only available for self-attention." 155 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 156 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 157 | scores_local = self._relative_position_to_absolute_position(rel_logits) 158 | scores = scores + scores_local 159 | if self.proximal_bias: 160 | assert t_s == t_t, "Proximal bias is only available for self-attention." 161 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 162 | if mask is not None: 163 | scores = scores.masked_fill(mask == 0, -1e4) 164 | if self.block_length is not None: 165 | assert t_s == t_t, "Local attention is only available for self-attention." 
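        # Restrict attention to a band around the diagonal: triu(-block_length).tril(block_length)
        # keeps only positions with |i - j| <= block_length; everything outside the band is
        # filled with -1e4 before the softmax, giving local (banded) self-attention.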
166 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 167 | scores = scores.masked_fill(block_mask == 0, -1e4) 168 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 169 | p_attn = self.drop(p_attn) 170 | output = torch.matmul(p_attn, value) 171 | if self.window_size is not None: 172 | relative_weights = self._absolute_position_to_relative_position(p_attn) 173 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 174 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 175 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 176 | return output, p_attn 177 | 178 | def _matmul_with_relative_values(self, x, y): 179 | """ 180 | x: [b, h, l, m] 181 | y: [h or 1, m, d] 182 | ret: [b, h, l, d] 183 | """ 184 | ret = torch.matmul(x, y.unsqueeze(0)) 185 | return ret 186 | 187 | def _matmul_with_relative_keys(self, x, y): 188 | """ 189 | x: [b, h, l, d] 190 | y: [h or 1, m, d] 191 | ret: [b, h, l, m] 192 | """ 193 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 194 | return ret 195 | 196 | def _get_relative_embeddings(self, relative_embeddings, length): 197 | max_relative_position = 2 * self.window_size + 1 198 | # Pad first before slice to avoid using cond ops. 199 | pad_length = max(length - (self.window_size + 1), 0) 200 | slice_start_position = max((self.window_size + 1) - length, 0) 201 | slice_end_position = slice_start_position + 2 * length - 1 202 | if pad_length > 0: 203 | padded_relative_embeddings = F.pad( 204 | relative_embeddings, 205 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 206 | else: 207 | padded_relative_embeddings = relative_embeddings 208 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 209 | return used_relative_embeddings 210 | 211 | def _relative_position_to_absolute_position(self, x): 212 | """ 213 | x: [b, h, l, 2*l-1] 214 | ret: [b, h, l, l] 215 | """ 216 | batch, heads, length, _ = x.size() 217 | # Concat columns of pad to shift from relative to absolute indexing. 218 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 219 | 220 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 221 | x_flat = x.view([batch, heads, length * 2 * length]) 222 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 223 | 224 | # Reshape and slice out the padded elements. 225 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 226 | return x_final 227 | 228 | def _absolute_position_to_relative_position(self, x): 229 | """ 230 | x: [b, h, l, l] 231 | ret: [b, h, l, 2*l-1] 232 | """ 233 | batch, heads, length, _ = x.size() 234 | # padd along column 235 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 236 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 237 | # add 0's in the beginning that will skew the elements after reshape 238 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 239 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 240 | return x_final 241 | 242 | def _attention_bias_proximal(self, length): 243 | """Bias for self-attention to encourage attention to close positions. 244 | Args: 245 | length: an integer scalar. 
246 | Returns: 247 | a Tensor with shape [1, 1, length, length] 248 | """ 249 | r = torch.arange(length, dtype=torch.float32) 250 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 251 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 252 | 253 | 254 | class FFN(nn.Module): 255 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 256 | super().__init__() 257 | self.in_channels = in_channels 258 | self.out_channels = out_channels 259 | self.filter_channels = filter_channels 260 | self.kernel_size = kernel_size 261 | self.p_dropout = p_dropout 262 | self.activation = activation 263 | self.causal = causal 264 | 265 | if causal: 266 | self.padding = self._causal_padding 267 | else: 268 | self.padding = self._same_padding 269 | 270 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 271 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 272 | self.drop = nn.Dropout(p_dropout) 273 | 274 | def forward(self, x, x_mask): 275 | x = self.conv_1(self.padding(x * x_mask)) 276 | if self.activation == "gelu": 277 | x = x * torch.sigmoid(1.702 * x) 278 | else: 279 | x = torch.relu(x) 280 | x = self.drop(x) 281 | x = self.conv_2(self.padding(x * x_mask)) 282 | return x * x_mask 283 | 284 | def _causal_padding(self, x): 285 | if self.kernel_size == 1: 286 | return x 287 | pad_l = self.kernel_size - 1 288 | pad_r = 0 289 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 290 | x = F.pad(x, commons.convert_pad_shape(padding)) 291 | return x 292 | 293 | def _same_padding(self, x): 294 | if self.kernel_size == 1: 295 | return x 296 | pad_l = (self.kernel_size - 1) // 2 297 | pad_r = self.kernel_size // 2 298 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 299 | x = F.pad(x, commons.convert_pad_shape(padding)) 300 | return x 301 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 9 | from torch.nn.utils import weight_norm, remove_weight_norm 10 | 11 | import commons 12 | from commons import init_weights, get_padding 13 | from transforms import piecewise_rational_quadratic_transform 14 | 15 | 16 | LRELU_SLOPE = 0.1 17 | 18 | 19 | class LayerNorm(nn.Module): 20 | def __init__(self, channels, eps=1e-5): 21 | super().__init__() 22 | self.channels = channels 23 | self.eps = eps 24 | 25 | self.gamma = nn.Parameter(torch.ones(channels)) 26 | self.beta = nn.Parameter(torch.zeros(channels)) 27 | 28 | def forward(self, x): 29 | x = x.transpose(1, -1) 30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 31 | return x.transpose(1, -1) 32 | 33 | 34 | class ConvReluNorm(nn.Module): 35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.hidden_channels = hidden_channels 39 | self.out_channels = out_channels 40 | self.kernel_size = kernel_size 41 | self.n_layers = n_layers 42 | self.p_dropout = p_dropout 43 | assert n_layers > 1, "Number of layers should be larger than 0." 
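    # Stack of n_layers Conv1d -> LayerNorm -> ReLU -> Dropout blocks. The final 1x1
    # projection is zero-initialised, so at initialisation the module reduces to the
    # identity residual x_org * x_mask (see forward()).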
44 | 45 | self.conv_layers = nn.ModuleList() 46 | self.norm_layers = nn.ModuleList() 47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 48 | self.norm_layers.append(LayerNorm(hidden_channels)) 49 | self.relu_drop = nn.Sequential( 50 | nn.ReLU(), 51 | nn.Dropout(p_dropout)) 52 | for _ in range(n_layers-1): 53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 54 | self.norm_layers.append(LayerNorm(hidden_channels)) 55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 56 | self.proj.weight.data.zero_() 57 | self.proj.bias.data.zero_() 58 | 59 | def forward(self, x, x_mask): 60 | x_org = x 61 | for i in range(self.n_layers): 62 | x = self.conv_layers[i](x * x_mask) 63 | x = self.norm_layers[i](x) 64 | x = self.relu_drop(x) 65 | x = x_org + self.proj(x) 66 | return x * x_mask 67 | 68 | 69 | class DDSConv(nn.Module): 70 | """ 71 | Dialted and Depth-Separable Convolution 72 | """ 73 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 74 | super().__init__() 75 | self.channels = channels 76 | self.kernel_size = kernel_size 77 | self.n_layers = n_layers 78 | self.p_dropout = p_dropout 79 | 80 | self.drop = nn.Dropout(p_dropout) 81 | self.convs_sep = nn.ModuleList() 82 | self.convs_1x1 = nn.ModuleList() 83 | self.norms_1 = nn.ModuleList() 84 | self.norms_2 = nn.ModuleList() 85 | for i in range(n_layers): 86 | dilation = kernel_size ** i 87 | padding = (kernel_size * dilation - dilation) // 2 88 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 89 | groups=channels, dilation=dilation, padding=padding 90 | )) 91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 92 | self.norms_1.append(LayerNorm(channels)) 93 | self.norms_2.append(LayerNorm(channels)) 94 | 95 | def forward(self, x, x_mask, g=None): 96 | if g is not None: 97 | x = x + g 98 | for i in range(self.n_layers): 99 | y = self.convs_sep[i](x * x_mask) 100 | y = self.norms_1[i](y) 101 | y = F.gelu(y) 102 | y = self.convs_1x1[i](y) 103 | y = self.norms_2[i](y) 104 | y = F.gelu(y) 105 | y = self.drop(y) 106 | x = x + y 107 | return x * x_mask 108 | 109 | 110 | class WN(torch.nn.Module): 111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 112 | super(WN, self).__init__() 113 | assert(kernel_size % 2 == 1) 114 | self.hidden_channels =hidden_channels 115 | self.kernel_size = kernel_size, 116 | self.dilation_rate = dilation_rate 117 | self.n_layers = n_layers 118 | self.gin_channels = gin_channels 119 | self.p_dropout = p_dropout 120 | 121 | self.in_layers = torch.nn.ModuleList() 122 | self.res_skip_layers = torch.nn.ModuleList() 123 | self.drop = nn.Dropout(p_dropout) 124 | 125 | if gin_channels != 0: 126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 128 | 129 | for i in range(n_layers): 130 | dilation = dilation_rate ** i 131 | padding = int((kernel_size * dilation - dilation) / 2) 132 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 133 | dilation=dilation, padding=padding) 134 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 135 | self.in_layers.append(in_layer) 136 | 137 | # last one is not necessary 138 | if i < n_layers - 1: 139 | res_skip_channels = 2 * hidden_channels 140 | else: 141 | res_skip_channels = hidden_channels 142 | 143 | res_skip_layer = 
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 144 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 145 | self.res_skip_layers.append(res_skip_layer) 146 | 147 | def forward(self, x, x_mask, g=None, **kwargs): 148 | output = torch.zeros_like(x) 149 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 150 | 151 | if g is not None: 152 | g = self.cond_layer(g) 153 | 154 | for i in range(self.n_layers): 155 | x_in = self.in_layers[i](x) 156 | if g is not None: 157 | cond_offset = i * 2 * self.hidden_channels 158 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 159 | else: 160 | g_l = torch.zeros_like(x_in) 161 | 162 | acts = commons.fused_add_tanh_sigmoid_multiply( 163 | x_in, 164 | g_l, 165 | n_channels_tensor) 166 | acts = self.drop(acts) 167 | 168 | res_skip_acts = self.res_skip_layers[i](acts) 169 | if i < self.n_layers - 1: 170 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 171 | x = (x + res_acts) * x_mask 172 | output = output + res_skip_acts[:,self.hidden_channels:,:] 173 | else: 174 | output = output + res_skip_acts 175 | return output * x_mask 176 | 177 | def remove_weight_norm(self): 178 | if self.gin_channels != 0: 179 | torch.nn.utils.remove_weight_norm(self.cond_layer) 180 | for l in self.in_layers: 181 | torch.nn.utils.remove_weight_norm(l) 182 | for l in self.res_skip_layers: 183 | torch.nn.utils.remove_weight_norm(l) 184 | 185 | 186 | class ResBlock1(torch.nn.Module): 187 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 188 | super(ResBlock1, self).__init__() 189 | self.convs1 = nn.ModuleList([ 190 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 191 | padding=get_padding(kernel_size, dilation[0]))), 192 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 193 | padding=get_padding(kernel_size, dilation[1]))), 194 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 195 | padding=get_padding(kernel_size, dilation[2]))) 196 | ]) 197 | self.convs1.apply(init_weights) 198 | 199 | self.convs2 = nn.ModuleList([ 200 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 201 | padding=get_padding(kernel_size, 1))), 202 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 203 | padding=get_padding(kernel_size, 1))), 204 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 205 | padding=get_padding(kernel_size, 1))) 206 | ]) 207 | self.convs2.apply(init_weights) 208 | 209 | def forward(self, x, x_mask=None): 210 | for c1, c2 in zip(self.convs1, self.convs2): 211 | xt = F.leaky_relu(x, LRELU_SLOPE) 212 | if x_mask is not None: 213 | xt = xt * x_mask 214 | xt = c1(xt) 215 | xt = F.leaky_relu(xt, LRELU_SLOPE) 216 | if x_mask is not None: 217 | xt = xt * x_mask 218 | xt = c2(xt) 219 | x = xt + x 220 | if x_mask is not None: 221 | x = x * x_mask 222 | return x 223 | 224 | def remove_weight_norm(self): 225 | for l in self.convs1: 226 | remove_weight_norm(l) 227 | for l in self.convs2: 228 | remove_weight_norm(l) 229 | 230 | 231 | class ResBlock2(torch.nn.Module): 232 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 233 | super(ResBlock2, self).__init__() 234 | self.convs = nn.ModuleList([ 235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 236 | padding=get_padding(kernel_size, dilation[0]))), 237 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 238 | padding=get_padding(kernel_size, dilation[1]))) 239 | ]) 
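    # Lighter residual-block variant (used when Generator below is configured with
    # resblock != '1'): a single dilated conv per branch, instead of the
    # dilated-conv + dilation-1-conv pairs used in ResBlock1 above.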
240 | self.convs.apply(init_weights) 241 | 242 | def forward(self, x, x_mask=None): 243 | for c in self.convs: 244 | xt = F.leaky_relu(x, LRELU_SLOPE) 245 | if x_mask is not None: 246 | xt = xt * x_mask 247 | xt = c(xt) 248 | x = xt + x 249 | if x_mask is not None: 250 | x = x * x_mask 251 | return x 252 | 253 | def remove_weight_norm(self): 254 | for l in self.convs: 255 | remove_weight_norm(l) 256 | 257 | 258 | class Log(nn.Module): 259 | def forward(self, x, x_mask, reverse=False, **kwargs): 260 | if not reverse: 261 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 262 | logdet = torch.sum(-y, [1, 2]) 263 | return y, logdet 264 | else: 265 | x = torch.exp(x) * x_mask 266 | return x 267 | 268 | 269 | class Flip(nn.Module): 270 | def forward(self, x, *args, reverse=False, **kwargs): 271 | x = torch.flip(x, [1]) 272 | if not reverse: 273 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 274 | return x, logdet 275 | else: 276 | return x 277 | 278 | 279 | class ElementwiseAffine(nn.Module): 280 | def __init__(self, channels): 281 | super().__init__() 282 | self.channels = channels 283 | self.m = nn.Parameter(torch.zeros(channels,1)) 284 | self.logs = nn.Parameter(torch.zeros(channels,1)) 285 | 286 | def forward(self, x, x_mask, reverse=False, **kwargs): 287 | if not reverse: 288 | y = self.m + torch.exp(self.logs) * x 289 | y = y * x_mask 290 | logdet = torch.sum(self.logs * x_mask, [1,2]) 291 | return y, logdet 292 | else: 293 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 294 | return x 295 | 296 | 297 | class ResidualCouplingLayer(nn.Module): 298 | def __init__(self, 299 | channels, 300 | hidden_channels, 301 | kernel_size, 302 | dilation_rate, 303 | n_layers, 304 | p_dropout=0, 305 | gin_channels=0, 306 | mean_only=False): 307 | assert channels % 2 == 0, "channels should be divisible by 2" 308 | super().__init__() 309 | self.channels = channels 310 | self.hidden_channels = hidden_channels 311 | self.kernel_size = kernel_size 312 | self.dilation_rate = dilation_rate 313 | self.n_layers = n_layers 314 | self.half_channels = channels // 2 315 | self.mean_only = mean_only 316 | 317 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 318 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 319 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 320 | self.post.weight.data.zero_() 321 | self.post.bias.data.zero_() 322 | 323 | def forward(self, x, x_mask, g=None, reverse=False): 324 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 325 | h = self.pre(x0) * x_mask 326 | h = self.enc(h, x_mask, g=g) 327 | stats = self.post(h) * x_mask 328 | if not self.mean_only: 329 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 330 | else: 331 | m = stats 332 | logs = torch.zeros_like(m) 333 | 334 | if not reverse: 335 | x1 = m + x1 * torch.exp(logs) * x_mask 336 | x = torch.cat([x0, x1], 1) 337 | logdet = torch.sum(logs, [1,2]) 338 | return x, logdet 339 | else: 340 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 341 | x = torch.cat([x0, x1], 1) 342 | return x 343 | 344 | 345 | class ConvFlow(nn.Module): 346 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 347 | super().__init__() 348 | self.in_channels = in_channels 349 | self.filter_channels = filter_channels 350 | self.kernel_size = kernel_size 351 | self.n_layers = n_layers 352 | self.num_bins = num_bins 353 | self.tail_bound = tail_bound 354 | 
self.half_channels = in_channels // 2 355 | 356 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 357 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 358 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 359 | self.proj.weight.data.zero_() 360 | self.proj.bias.data.zero_() 361 | 362 | def forward(self, x, x_mask, g=None, reverse=False): 363 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 364 | h = self.pre(x0) 365 | h = self.convs(h, x_mask, g=g) 366 | h = self.proj(h) * x_mask 367 | 368 | b, c, t = x0.shape 369 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 370 | 371 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 372 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 373 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 374 | 375 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 376 | unnormalized_widths, 377 | unnormalized_heights, 378 | unnormalized_derivatives, 379 | inverse=reverse, 380 | tails='linear', 381 | tail_bound=self.tail_bound 382 | ) 383 | 384 | x = torch.cat([x0, x1], 1) * x_mask 385 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 386 | if not reverse: 387 | return x, logdet 388 | else: 389 | return x 390 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | import commons 8 | import modules 9 | import attentions 10 | import monotonic_align 11 | 12 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 13 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 14 | from commons import init_weights, get_padding 15 | 16 | class StochasticDurationPredictor(nn.Module): 17 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 18 | super().__init__() 19 | filter_channels = in_channels # it needs to be removed from future version. 
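    # Flow-based duration model: log-durations are modelled by a stack of ConvFlow/Flip
    # steps conditioned on the text encoding. The post_* modules form a variational
    # posterior used only in training (reverse=False returns an NLL bound; reverse=True
    # samples log-durations).
    #
    # Illustrative use, mirroring SynthesizerTrn.infer further below (sizes assumed):
    #   dp = StochasticDurationPredictor(192, 192, 3, 0.5, n_flows=4)
    #   logw = dp(x, x_mask, reverse=True, noise_scale=0.8)  # x: [b, 192, t_text]
    #   w = torch.exp(logw) * x_mask                         # per-token durations in frames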
20 | self.in_channels = in_channels 21 | self.filter_channels = filter_channels 22 | self.kernel_size = kernel_size 23 | self.p_dropout = p_dropout 24 | self.n_flows = n_flows 25 | self.gin_channels = gin_channels 26 | 27 | self.log_flow = modules.Log() 28 | self.flows = nn.ModuleList() 29 | self.flows.append(modules.ElementwiseAffine(2)) 30 | for i in range(n_flows): 31 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 32 | self.flows.append(modules.Flip()) 33 | 34 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 35 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 36 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 37 | self.post_flows = nn.ModuleList() 38 | self.post_flows.append(modules.ElementwiseAffine(2)) 39 | for i in range(4): 40 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 41 | self.post_flows.append(modules.Flip()) 42 | 43 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 44 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 45 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 46 | if gin_channels != 0: 47 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 48 | 49 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 50 | x = torch.detach(x) 51 | x = self.pre(x) 52 | if g is not None: 53 | g = torch.detach(g) 54 | x = x + self.cond(g) 55 | x = self.convs(x, x_mask) 56 | x = self.proj(x) * x_mask 57 | 58 | if not reverse: 59 | flows = self.flows 60 | assert w is not None 61 | 62 | logdet_tot_q = 0 63 | h_w = self.post_pre(w) 64 | h_w = self.post_convs(h_w, x_mask) 65 | h_w = self.post_proj(h_w) * x_mask 66 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 67 | z_q = e_q 68 | for flow in self.post_flows: 69 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 70 | logdet_tot_q += logdet_q 71 | z_u, z1 = torch.split(z_q, [1, 1], 1) 72 | u = torch.sigmoid(z_u) * x_mask 73 | z0 = (w - u) * x_mask 74 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 75 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 76 | 77 | logdet_tot = 0 78 | z0, logdet = self.log_flow(z0, x_mask) 79 | logdet_tot += logdet 80 | z = torch.cat([z0, z1], 1) 81 | for flow in flows: 82 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 83 | logdet_tot = logdet_tot + logdet 84 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 85 | return nll + logq # [b] 86 | else: 87 | flows = list(reversed(self.flows)) 88 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 89 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 90 | for flow in flows: 91 | z = flow(z, x_mask, g=x, reverse=reverse) 92 | z0, z1 = torch.split(z, [1, 1], 1) 93 | logw = z0 94 | return logw 95 | 96 | 97 | class DurationPredictor(nn.Module): 98 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 99 | super().__init__() 100 | 101 | self.in_channels = in_channels 102 | self.filter_channels = filter_channels 103 | self.kernel_size = kernel_size 104 | self.p_dropout = p_dropout 105 | self.gin_channels = gin_channels 106 | 107 | self.drop = nn.Dropout(p_dropout) 108 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 109 | self.norm_1 = 
modules.LayerNorm(filter_channels) 110 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 111 | self.norm_2 = modules.LayerNorm(filter_channels) 112 | self.proj = nn.Conv1d(filter_channels, 1, 1) 113 | 114 | if gin_channels != 0: 115 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 116 | 117 | def forward(self, x, x_mask, g=None): 118 | x = torch.detach(x) 119 | if g is not None: 120 | g = torch.detach(g) 121 | x = x + self.cond(g) 122 | x = self.conv_1(x * x_mask) 123 | x = torch.relu(x) 124 | x = self.norm_1(x) 125 | x = self.drop(x) 126 | x = self.conv_2(x * x_mask) 127 | x = torch.relu(x) 128 | x = self.norm_2(x) 129 | x = self.drop(x) 130 | x = self.proj(x * x_mask) 131 | return x * x_mask 132 | 133 | 134 | class TextEncoder(nn.Module): 135 | def __init__(self, 136 | n_vocab, 137 | out_channels, 138 | hidden_channels, 139 | filter_channels, 140 | n_heads, 141 | n_layers, 142 | kernel_size, 143 | p_dropout): 144 | super().__init__() 145 | self.n_vocab = n_vocab 146 | self.out_channels = out_channels 147 | self.hidden_channels = hidden_channels 148 | self.filter_channels = filter_channels 149 | self.n_heads = n_heads 150 | self.n_layers = n_layers 151 | self.kernel_size = kernel_size 152 | self.p_dropout = p_dropout 153 | 154 | self.emb = nn.Embedding(n_vocab, hidden_channels) 155 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 156 | 157 | self.encoder = attentions.Encoder( 158 | hidden_channels, 159 | filter_channels, 160 | n_heads, 161 | n_layers, 162 | kernel_size, 163 | p_dropout) 164 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 165 | 166 | def forward(self, x, x_lengths): 167 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 168 | x = torch.transpose(x, 1, -1) # [b, h, t] 169 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 170 | 171 | x = self.encoder(x * x_mask, x_mask) 172 | stats = self.proj(x) * x_mask 173 | 174 | m, logs = torch.split(stats, self.out_channels, dim=1) 175 | return x, m, logs, x_mask 176 | 177 | 178 | class ResidualCouplingBlock(nn.Module): 179 | def __init__(self, 180 | channels, 181 | hidden_channels, 182 | kernel_size, 183 | dilation_rate, 184 | n_layers, 185 | n_flows=4, 186 | gin_channels=0): 187 | super().__init__() 188 | self.channels = channels 189 | self.hidden_channels = hidden_channels 190 | self.kernel_size = kernel_size 191 | self.dilation_rate = dilation_rate 192 | self.n_layers = n_layers 193 | self.n_flows = n_flows 194 | self.gin_channels = gin_channels 195 | 196 | self.flows = nn.ModuleList() 197 | for i in range(n_flows): 198 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 199 | self.flows.append(modules.Flip()) 200 | 201 | def forward(self, x, x_mask, g=None, reverse=False): 202 | if not reverse: 203 | for flow in self.flows: 204 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 205 | else: 206 | for flow in reversed(self.flows): 207 | x = flow(x, x_mask, g=g, reverse=reverse) 208 | return x 209 | 210 | 211 | class PosteriorEncoder(nn.Module): 212 | def __init__(self, 213 | in_channels, 214 | out_channels, 215 | hidden_channels, 216 | kernel_size, 217 | dilation_rate, 218 | n_layers, 219 | gin_channels=0): 220 | super().__init__() 221 | self.in_channels = in_channels 222 | self.out_channels = out_channels 223 | self.hidden_channels = hidden_channels 224 | self.kernel_size = kernel_size 225 | 
self.dilation_rate = dilation_rate 226 | self.n_layers = n_layers 227 | self.gin_channels = gin_channels 228 | 229 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 230 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 231 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 232 | 233 | def forward(self, x, x_lengths, g=None): 234 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 235 | x = self.pre(x) * x_mask 236 | x = self.enc(x, x_mask, g=g) 237 | stats = self.proj(x) * x_mask 238 | m, logs = torch.split(stats, self.out_channels, dim=1) 239 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 240 | return z, m, logs, x_mask 241 | 242 | 243 | class Generator(torch.nn.Module): 244 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 245 | super(Generator, self).__init__() 246 | self.num_kernels = len(resblock_kernel_sizes) 247 | self.num_upsamples = len(upsample_rates) 248 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 249 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 250 | 251 | self.ups = nn.ModuleList() 252 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 253 | self.ups.append(weight_norm( 254 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 255 | k, u, padding=(k-u)//2))) 256 | 257 | self.resblocks = nn.ModuleList() 258 | for i in range(len(self.ups)): 259 | ch = upsample_initial_channel//(2**(i+1)) 260 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 261 | self.resblocks.append(resblock(ch, k, d)) 262 | 263 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 264 | self.ups.apply(init_weights) 265 | 266 | if gin_channels != 0: 267 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 268 | 269 | def forward(self, x, g=None): 270 | x = self.conv_pre(x) 271 | if g is not None: 272 | x = x + self.cond(g) 273 | 274 | for i in range(self.num_upsamples): 275 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 276 | x = self.ups[i](x) 277 | xs = None 278 | for j in range(self.num_kernels): 279 | if xs is None: 280 | xs = self.resblocks[i*self.num_kernels+j](x) 281 | else: 282 | xs += self.resblocks[i*self.num_kernels+j](x) 283 | x = xs / self.num_kernels 284 | x = F.leaky_relu(x) 285 | x = self.conv_post(x) 286 | x = torch.tanh(x) 287 | 288 | return x 289 | 290 | def remove_weight_norm(self): 291 | print('Removing weight norm...') 292 | for l in self.ups: 293 | remove_weight_norm(l) 294 | for l in self.resblocks: 295 | l.remove_weight_norm() 296 | 297 | 298 | class DiscriminatorP(torch.nn.Module): 299 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 300 | super(DiscriminatorP, self).__init__() 301 | self.period = period 302 | self.use_spectral_norm = use_spectral_norm 303 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 304 | self.convs = nn.ModuleList([ 305 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 306 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 307 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 308 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), 
padding=(get_padding(kernel_size, 1), 0))), 309 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 310 | ]) 311 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 312 | 313 | def forward(self, x): 314 | fmap = [] 315 | 316 | # 1d to 2d 317 | b, c, t = x.shape 318 | if t % self.period != 0: # pad first 319 | n_pad = self.period - (t % self.period) 320 | x = F.pad(x, (0, n_pad), "reflect") 321 | t = t + n_pad 322 | x = x.view(b, c, t // self.period, self.period) 323 | 324 | for l in self.convs: 325 | x = l(x) 326 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 327 | fmap.append(x) 328 | x = self.conv_post(x) 329 | fmap.append(x) 330 | x = torch.flatten(x, 1, -1) 331 | 332 | return x, fmap 333 | 334 | 335 | class DiscriminatorS(torch.nn.Module): 336 | def __init__(self, use_spectral_norm=False): 337 | super(DiscriminatorS, self).__init__() 338 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 339 | self.convs = nn.ModuleList([ 340 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 341 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 342 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 343 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 344 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 345 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 346 | ]) 347 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 348 | 349 | def forward(self, x): 350 | fmap = [] 351 | 352 | for l in self.convs: 353 | x = l(x) 354 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 355 | fmap.append(x) 356 | x = self.conv_post(x) 357 | fmap.append(x) 358 | x = torch.flatten(x, 1, -1) 359 | 360 | return x, fmap 361 | 362 | 363 | class MultiPeriodDiscriminator(torch.nn.Module): 364 | def __init__(self, use_spectral_norm=False): 365 | super(MultiPeriodDiscriminator, self).__init__() 366 | periods = [2,3,5,7,11] 367 | 368 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 369 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 370 | self.discriminators = nn.ModuleList(discs) 371 | 372 | def forward(self, y, y_hat): 373 | y_d_rs = [] 374 | y_d_gs = [] 375 | fmap_rs = [] 376 | fmap_gs = [] 377 | for i, d in enumerate(self.discriminators): 378 | y_d_r, fmap_r = d(y) 379 | y_d_g, fmap_g = d(y_hat) 380 | y_d_rs.append(y_d_r) 381 | y_d_gs.append(y_d_g) 382 | fmap_rs.append(fmap_r) 383 | fmap_gs.append(fmap_g) 384 | 385 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 386 | 387 | 388 | 389 | class SynthesizerTrn(nn.Module): 390 | """ 391 | Synthesizer for Training 392 | """ 393 | 394 | def __init__(self, 395 | n_vocab, 396 | spec_channels, 397 | segment_size, 398 | inter_channels, 399 | hidden_channels, 400 | filter_channels, 401 | n_heads, 402 | n_layers, 403 | kernel_size, 404 | p_dropout, 405 | resblock, 406 | resblock_kernel_sizes, 407 | resblock_dilation_sizes, 408 | upsample_rates, 409 | upsample_initial_channel, 410 | upsample_kernel_sizes, 411 | n_speakers=0, 412 | gin_channels=0, 413 | use_sdp=True, 414 | **kwargs): 415 | 416 | super().__init__() 417 | self.n_vocab = n_vocab 418 | self.spec_channels = spec_channels 419 | self.inter_channels = inter_channels 420 | self.hidden_channels = hidden_channels 421 | self.filter_channels = filter_channels 422 | self.n_heads = n_heads 423 | self.n_layers = n_layers 424 | self.kernel_size = kernel_size 425 | self.p_dropout = p_dropout 426 | self.resblock = resblock 427 | self.resblock_kernel_sizes = resblock_kernel_sizes 428 | 
self.resblock_dilation_sizes = resblock_dilation_sizes 429 | self.upsample_rates = upsample_rates 430 | self.upsample_initial_channel = upsample_initial_channel 431 | self.upsample_kernel_sizes = upsample_kernel_sizes 432 | self.segment_size = segment_size 433 | self.n_speakers = n_speakers 434 | self.gin_channels = gin_channels 435 | 436 | self.use_sdp = use_sdp 437 | 438 | self.enc_p = TextEncoder(n_vocab, 439 | inter_channels, 440 | hidden_channels, 441 | filter_channels, 442 | n_heads, 443 | n_layers, 444 | kernel_size, 445 | p_dropout) 446 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) 447 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) 448 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 449 | 450 | if use_sdp: 451 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 452 | else: 453 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 454 | 455 | if n_speakers > 1: 456 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 457 | 458 | def forward(self, x, x_lengths, y, y_lengths, sid=None): 459 | 460 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 461 | if self.n_speakers > 0: 462 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 463 | else: 464 | g = None 465 | 466 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) 467 | z_p = self.flow(z, y_mask, g=g) 468 | 469 | with torch.no_grad(): 470 | # negative cross-entropy 471 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] 472 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] 473 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 474 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 475 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] 476 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 477 | 478 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 479 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 480 | 481 | w = attn.sum(2) 482 | if self.use_sdp: 483 | l_length = self.dp(x, x_mask, w, g=g) 484 | l_length = l_length / torch.sum(x_mask) 485 | else: 486 | logw_ = torch.log(w + 1e-6) * x_mask 487 | logw = self.dp(x, x_mask, g=g) 488 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging 489 | 490 | # expand prior 491 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) 492 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) 493 | 494 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) 495 | o = self.dec(z_slice, g=g) 496 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) 497 | 498 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): 499 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 500 | if self.n_speakers > 0: 501 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 502 | else: 503 | g = None 504 | 505 | if self.use_sdp: 506 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) 507 | else: 508 
| logw = self.dp(x, x_mask, g=g) 509 | w = torch.exp(logw) * x_mask * length_scale 510 | w_ceil = torch.ceil(w) 511 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 512 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 513 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 514 | attn = commons.generate_path(w_ceil, attn_mask) 515 | 516 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 517 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 518 | 519 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 520 | z = self.flow(z_p, y_mask, g=g, reverse=True) 521 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 522 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 523 | 524 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): 525 | assert self.n_speakers > 0, "n_speakers have to be larger than 0." 526 | g_src = self.emb_g(sid_src).unsqueeze(-1) 527 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) 528 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) 529 | z_p = self.flow(z, y_mask, g=g_src) 530 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 531 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 532 | return o_hat, y_mask, (z, z_p, z_hat) 533 | 534 | -------------------------------------------------------------------------------- /wavfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Module to read / write wav files using NumPy arrays 3 | 4 | Functions 5 | --------- 6 | `read`: Return the sample rate (in samples/sec) and data from a WAV file. 7 | 8 | `write`: Write a NumPy array as a WAV file. 9 | 10 | """ 11 | import io 12 | import struct 13 | import sys 14 | import warnings 15 | from enum import IntEnum 16 | 17 | import numpy 18 | 19 | __all__ = ["WavFileWarning", "read", "write"] 20 | 21 | 22 | class WavFileWarning(UserWarning): 23 | pass 24 | 25 | 26 | class WAVE_FORMAT(IntEnum): 27 | """ 28 | WAVE form wFormatTag IDs 29 | 30 | Complete list is in mmreg.h in Windows 10 SDK. 
ALAC and OPUS are the 31 | newest additions, in v10.0.14393 2016-07 32 | """ 33 | 34 | UNKNOWN = 0x0000 35 | PCM = 0x0001 36 | ADPCM = 0x0002 37 | IEEE_FLOAT = 0x0003 38 | VSELP = 0x0004 39 | IBM_CVSD = 0x0005 40 | ALAW = 0x0006 41 | MULAW = 0x0007 42 | DTS = 0x0008 43 | DRM = 0x0009 44 | WMAVOICE9 = 0x000A 45 | WMAVOICE10 = 0x000B 46 | OKI_ADPCM = 0x0010 47 | DVI_ADPCM = 0x0011 48 | IMA_ADPCM = 0x0011 # Duplicate 49 | MEDIASPACE_ADPCM = 0x0012 50 | SIERRA_ADPCM = 0x0013 51 | G723_ADPCM = 0x0014 52 | DIGISTD = 0x0015 53 | DIGIFIX = 0x0016 54 | DIALOGIC_OKI_ADPCM = 0x0017 55 | MEDIAVISION_ADPCM = 0x0018 56 | CU_CODEC = 0x0019 57 | HP_DYN_VOICE = 0x001A 58 | YAMAHA_ADPCM = 0x0020 59 | SONARC = 0x0021 60 | DSPGROUP_TRUESPEECH = 0x0022 61 | ECHOSC1 = 0x0023 62 | AUDIOFILE_AF36 = 0x0024 63 | APTX = 0x0025 64 | AUDIOFILE_AF10 = 0x0026 65 | PROSODY_1612 = 0x0027 66 | LRC = 0x0028 67 | DOLBY_AC2 = 0x0030 68 | GSM610 = 0x0031 69 | MSNAUDIO = 0x0032 70 | ANTEX_ADPCME = 0x0033 71 | CONTROL_RES_VQLPC = 0x0034 72 | DIGIREAL = 0x0035 73 | DIGIADPCM = 0x0036 74 | CONTROL_RES_CR10 = 0x0037 75 | NMS_VBXADPCM = 0x0038 76 | CS_IMAADPCM = 0x0039 77 | ECHOSC3 = 0x003A 78 | ROCKWELL_ADPCM = 0x003B 79 | ROCKWELL_DIGITALK = 0x003C 80 | XEBEC = 0x003D 81 | G721_ADPCM = 0x0040 82 | G728_CELP = 0x0041 83 | MSG723 = 0x0042 84 | INTEL_G723_1 = 0x0043 85 | INTEL_G729 = 0x0044 86 | SHARP_G726 = 0x0045 87 | MPEG = 0x0050 88 | RT24 = 0x0052 89 | PAC = 0x0053 90 | MPEGLAYER3 = 0x0055 91 | LUCENT_G723 = 0x0059 92 | CIRRUS = 0x0060 93 | ESPCM = 0x0061 94 | VOXWARE = 0x0062 95 | CANOPUS_ATRAC = 0x0063 96 | G726_ADPCM = 0x0064 97 | G722_ADPCM = 0x0065 98 | DSAT = 0x0066 99 | DSAT_DISPLAY = 0x0067 100 | VOXWARE_BYTE_ALIGNED = 0x0069 101 | VOXWARE_AC8 = 0x0070 102 | VOXWARE_AC10 = 0x0071 103 | VOXWARE_AC16 = 0x0072 104 | VOXWARE_AC20 = 0x0073 105 | VOXWARE_RT24 = 0x0074 106 | VOXWARE_RT29 = 0x0075 107 | VOXWARE_RT29HW = 0x0076 108 | VOXWARE_VR12 = 0x0077 109 | VOXWARE_VR18 = 0x0078 110 | VOXWARE_TQ40 = 0x0079 111 | VOXWARE_SC3 = 0x007A 112 | VOXWARE_SC3_1 = 0x007B 113 | SOFTSOUND = 0x0080 114 | VOXWARE_TQ60 = 0x0081 115 | MSRT24 = 0x0082 116 | G729A = 0x0083 117 | MVI_MVI2 = 0x0084 118 | DF_G726 = 0x0085 119 | DF_GSM610 = 0x0086 120 | ISIAUDIO = 0x0088 121 | ONLIVE = 0x0089 122 | MULTITUDE_FT_SX20 = 0x008A 123 | INFOCOM_ITS_G721_ADPCM = 0x008B 124 | CONVEDIA_G729 = 0x008C 125 | CONGRUENCY = 0x008D 126 | SBC24 = 0x0091 127 | DOLBY_AC3_SPDIF = 0x0092 128 | MEDIASONIC_G723 = 0x0093 129 | PROSODY_8KBPS = 0x0094 130 | ZYXEL_ADPCM = 0x0097 131 | PHILIPS_LPCBB = 0x0098 132 | PACKED = 0x0099 133 | MALDEN_PHONYTALK = 0x00A0 134 | RACAL_RECORDER_GSM = 0x00A1 135 | RACAL_RECORDER_G720_A = 0x00A2 136 | RACAL_RECORDER_G723_1 = 0x00A3 137 | RACAL_RECORDER_TETRA_ACELP = 0x00A4 138 | NEC_AAC = 0x00B0 139 | RAW_AAC1 = 0x00FF 140 | RHETOREX_ADPCM = 0x0100 141 | IRAT = 0x0101 142 | VIVO_G723 = 0x0111 143 | VIVO_SIREN = 0x0112 144 | PHILIPS_CELP = 0x0120 145 | PHILIPS_GRUNDIG = 0x0121 146 | DIGITAL_G723 = 0x0123 147 | SANYO_LD_ADPCM = 0x0125 148 | SIPROLAB_ACEPLNET = 0x0130 149 | SIPROLAB_ACELP4800 = 0x0131 150 | SIPROLAB_ACELP8V3 = 0x0132 151 | SIPROLAB_G729 = 0x0133 152 | SIPROLAB_G729A = 0x0134 153 | SIPROLAB_KELVIN = 0x0135 154 | VOICEAGE_AMR = 0x0136 155 | G726ADPCM = 0x0140 156 | DICTAPHONE_CELP68 = 0x0141 157 | DICTAPHONE_CELP54 = 0x0142 158 | QUALCOMM_PUREVOICE = 0x0150 159 | QUALCOMM_HALFRATE = 0x0151 160 | TUBGSM = 0x0155 161 | MSAUDIO1 = 0x0160 162 | WMAUDIO2 = 0x0161 163 | WMAUDIO3 = 0x0162 164 | WMAUDIO_LOSSLESS = 0x0163 165 | WMASPDIF 
= 0x0164 166 | UNISYS_NAP_ADPCM = 0x0170 167 | UNISYS_NAP_ULAW = 0x0171 168 | UNISYS_NAP_ALAW = 0x0172 169 | UNISYS_NAP_16K = 0x0173 170 | SYCOM_ACM_SYC008 = 0x0174 171 | SYCOM_ACM_SYC701_G726L = 0x0175 172 | SYCOM_ACM_SYC701_CELP54 = 0x0176 173 | SYCOM_ACM_SYC701_CELP68 = 0x0177 174 | KNOWLEDGE_ADVENTURE_ADPCM = 0x0178 175 | FRAUNHOFER_IIS_MPEG2_AAC = 0x0180 176 | DTS_DS = 0x0190 177 | CREATIVE_ADPCM = 0x0200 178 | CREATIVE_FASTSPEECH8 = 0x0202 179 | CREATIVE_FASTSPEECH10 = 0x0203 180 | UHER_ADPCM = 0x0210 181 | ULEAD_DV_AUDIO = 0x0215 182 | ULEAD_DV_AUDIO_1 = 0x0216 183 | QUARTERDECK = 0x0220 184 | ILINK_VC = 0x0230 185 | RAW_SPORT = 0x0240 186 | ESST_AC3 = 0x0241 187 | GENERIC_PASSTHRU = 0x0249 188 | IPI_HSX = 0x0250 189 | IPI_RPELP = 0x0251 190 | CS2 = 0x0260 191 | SONY_SCX = 0x0270 192 | SONY_SCY = 0x0271 193 | SONY_ATRAC3 = 0x0272 194 | SONY_SPC = 0x0273 195 | TELUM_AUDIO = 0x0280 196 | TELUM_IA_AUDIO = 0x0281 197 | NORCOM_VOICE_SYSTEMS_ADPCM = 0x0285 198 | FM_TOWNS_SND = 0x0300 199 | MICRONAS = 0x0350 200 | MICRONAS_CELP833 = 0x0351 201 | BTV_DIGITAL = 0x0400 202 | INTEL_MUSIC_CODER = 0x0401 203 | INDEO_AUDIO = 0x0402 204 | QDESIGN_MUSIC = 0x0450 205 | ON2_VP7_AUDIO = 0x0500 206 | ON2_VP6_AUDIO = 0x0501 207 | VME_VMPCM = 0x0680 208 | TPC = 0x0681 209 | LIGHTWAVE_LOSSLESS = 0x08AE 210 | OLIGSM = 0x1000 211 | OLIADPCM = 0x1001 212 | OLICELP = 0x1002 213 | OLISBC = 0x1003 214 | OLIOPR = 0x1004 215 | LH_CODEC = 0x1100 216 | LH_CODEC_CELP = 0x1101 217 | LH_CODEC_SBC8 = 0x1102 218 | LH_CODEC_SBC12 = 0x1103 219 | LH_CODEC_SBC16 = 0x1104 220 | NORRIS = 0x1400 221 | ISIAUDIO_2 = 0x1401 222 | SOUNDSPACE_MUSICOMPRESS = 0x1500 223 | MPEG_ADTS_AAC = 0x1600 224 | MPEG_RAW_AAC = 0x1601 225 | MPEG_LOAS = 0x1602 226 | NOKIA_MPEG_ADTS_AAC = 0x1608 227 | NOKIA_MPEG_RAW_AAC = 0x1609 228 | VODAFONE_MPEG_ADTS_AAC = 0x160A 229 | VODAFONE_MPEG_RAW_AAC = 0x160B 230 | MPEG_HEAAC = 0x1610 231 | VOXWARE_RT24_SPEECH = 0x181C 232 | SONICFOUNDRY_LOSSLESS = 0x1971 233 | INNINGS_TELECOM_ADPCM = 0x1979 234 | LUCENT_SX8300P = 0x1C07 235 | LUCENT_SX5363S = 0x1C0C 236 | CUSEEME = 0x1F03 237 | NTCSOFT_ALF2CM_ACM = 0x1FC4 238 | DVM = 0x2000 239 | DTS2 = 0x2001 240 | MAKEAVIS = 0x3313 241 | DIVIO_MPEG4_AAC = 0x4143 242 | NOKIA_ADAPTIVE_MULTIRATE = 0x4201 243 | DIVIO_G726 = 0x4243 244 | LEAD_SPEECH = 0x434C 245 | LEAD_VORBIS = 0x564C 246 | WAVPACK_AUDIO = 0x5756 247 | OGG_VORBIS_MODE_1 = 0x674F 248 | OGG_VORBIS_MODE_2 = 0x6750 249 | OGG_VORBIS_MODE_3 = 0x6751 250 | OGG_VORBIS_MODE_1_PLUS = 0x676F 251 | OGG_VORBIS_MODE_2_PLUS = 0x6770 252 | OGG_VORBIS_MODE_3_PLUS = 0x6771 253 | ALAC = 0x6C61 254 | _3COM_NBX = 0x7000 # Can't have leading digit 255 | OPUS = 0x704F 256 | FAAD_AAC = 0x706D 257 | AMR_NB = 0x7361 258 | AMR_WB = 0x7362 259 | AMR_WP = 0x7363 260 | GSM_AMR_CBR = 0x7A21 261 | GSM_AMR_VBR_SID = 0x7A22 262 | COMVERSE_INFOSYS_G723_1 = 0xA100 263 | COMVERSE_INFOSYS_AVQSBC = 0xA101 264 | COMVERSE_INFOSYS_SBC = 0xA102 265 | SYMBOL_G729_A = 0xA103 266 | VOICEAGE_AMR_WB = 0xA104 267 | INGENIENT_G726 = 0xA105 268 | MPEG4_AAC = 0xA106 269 | ENCORE_G726 = 0xA107 270 | ZOLL_ASAO = 0xA108 271 | SPEEX_VOICE = 0xA109 272 | VIANIX_MASC = 0xA10A 273 | WM9_SPECTRUM_ANALYZER = 0xA10B 274 | WMF_SPECTRUM_ANAYZER = 0xA10C 275 | GSM_610 = 0xA10D 276 | GSM_620 = 0xA10E 277 | GSM_660 = 0xA10F 278 | GSM_690 = 0xA110 279 | GSM_ADAPTIVE_MULTIRATE_WB = 0xA111 280 | POLYCOM_G722 = 0xA112 281 | POLYCOM_G728 = 0xA113 282 | POLYCOM_G729_A = 0xA114 283 | POLYCOM_SIREN = 0xA115 284 | GLOBAL_IP_ILBC = 0xA116 285 | RADIOTIME_TIME_SHIFT_RADIO = 0xA117 
286 | NICE_ACA = 0xA118 287 | NICE_ADPCM = 0xA119 288 | VOCORD_G721 = 0xA11A 289 | VOCORD_G726 = 0xA11B 290 | VOCORD_G722_1 = 0xA11C 291 | VOCORD_G728 = 0xA11D 292 | VOCORD_G729 = 0xA11E 293 | VOCORD_G729_A = 0xA11F 294 | VOCORD_G723_1 = 0xA120 295 | VOCORD_LBC = 0xA121 296 | NICE_G728 = 0xA122 297 | FRACE_TELECOM_G729 = 0xA123 298 | CODIAN = 0xA124 299 | FLAC = 0xF1AC 300 | EXTENSIBLE = 0xFFFE 301 | DEVELOPMENT = 0xFFFF 302 | 303 | 304 | KNOWN_WAVE_FORMATS = {WAVE_FORMAT.PCM, WAVE_FORMAT.IEEE_FLOAT} 305 | 306 | 307 | def _raise_bad_format(format_tag): 308 | try: 309 | format_name = WAVE_FORMAT(format_tag).name 310 | except ValueError: 311 | format_name = f"{format_tag:#06x}" 312 | raise ValueError( 313 | f"Unknown wave file format: {format_name}. Supported " 314 | "formats: " + ", ".join(x.name for x in KNOWN_WAVE_FORMATS) 315 | ) 316 | 317 | 318 | def _read_fmt_chunk(fid, is_big_endian): 319 | """ 320 | Returns 321 | ------- 322 | size : int 323 | size of format subchunk in bytes (minus 8 for "fmt " and itself) 324 | format_tag : int 325 | PCM, float, or compressed format 326 | channels : int 327 | number of channels 328 | fs : int 329 | sampling frequency in samples per second 330 | bytes_per_second : int 331 | overall byte rate for the file 332 | block_align : int 333 | bytes per sample, including all channels 334 | bit_depth : int 335 | bits per sample 336 | 337 | Notes 338 | ----- 339 | Assumes file pointer is immediately after the 'fmt ' id 340 | """ 341 | if is_big_endian: 342 | fmt = ">" 343 | else: 344 | fmt = "<" 345 | 346 | size = struct.unpack(fmt + "I", fid.read(4))[0] 347 | 348 | if size < 16: 349 | raise ValueError("Binary structure of wave file is not compliant") 350 | 351 | res = struct.unpack(fmt + "HHIIHH", fid.read(16)) 352 | bytes_read = 16 353 | 354 | format_tag, channels, fs, bytes_per_second, block_align, bit_depth = res 355 | 356 | if format_tag == WAVE_FORMAT.EXTENSIBLE and size >= (16 + 2): 357 | ext_chunk_size = struct.unpack(fmt + "H", fid.read(2))[0] 358 | bytes_read += 2 359 | if ext_chunk_size >= 22: 360 | extensible_chunk_data = fid.read(22) 361 | bytes_read += 22 362 | raw_guid = extensible_chunk_data[2 + 4 : 2 + 4 + 16] 363 | # GUID template {XXXXXXXX-0000-0010-8000-00AA00389B71} (RFC-2361) 364 | # MS GUID byte order: first three groups are native byte order, 365 | # rest is Big Endian 366 | if is_big_endian: 367 | tail = b"\x00\x00\x00\x10\x80\x00\x00\xAA\x00\x38\x9B\x71" 368 | else: 369 | tail = b"\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71" 370 | if raw_guid.endswith(tail): 371 | format_tag = struct.unpack(fmt + "I", raw_guid[:4])[0] 372 | else: 373 | raise ValueError("Binary structure of wave file is not compliant") 374 | 375 | if format_tag not in KNOWN_WAVE_FORMATS: 376 | _raise_bad_format(format_tag) 377 | 378 | # move file pointer to next chunk 379 | if size > bytes_read: 380 | fid.read(size - bytes_read) 381 | 382 | # fmt should always be 16, 18 or 40, but handle it just in case 383 | _handle_pad_byte(fid, size) 384 | 385 | return (size, format_tag, channels, fs, bytes_per_second, block_align, bit_depth) 386 | 387 | 388 | def _read_data_chunk( 389 | fid, format_tag, channels, bit_depth, is_big_endian, block_align, mmap=False 390 | ): 391 | """ 392 | Notes 393 | ----- 394 | Assumes file pointer is immediately after the 'data' id 395 | 396 | It's possible to not use all available bits in a container, or to store 397 | samples in a container bigger than necessary, so bytes_per_sample uses 398 | the actual reported container size (nBlockAlign / 
nChannels). Real-world 399 | examples: 400 | 401 | Adobe Audition's "24-bit packed int (type 1, 20-bit)" 402 | 403 | nChannels = 2, nBlockAlign = 6, wBitsPerSample = 20 404 | 405 | http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Samples/AFsp/M1F1-int12-AFsp.wav 406 | is: 407 | 408 | nChannels = 2, nBlockAlign = 4, wBitsPerSample = 12 409 | 410 | http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/multichaudP.pdf 411 | gives an example of: 412 | 413 | nChannels = 2, nBlockAlign = 8, wBitsPerSample = 20 414 | """ 415 | if is_big_endian: 416 | fmt = ">" 417 | else: 418 | fmt = "<" 419 | 420 | # Size of the data subchunk in bytes 421 | size = struct.unpack(fmt + "I", fid.read(4))[0] 422 | 423 | # Number of bytes per sample (sample container size) 424 | bytes_per_sample = block_align // channels 425 | n_samples = size // bytes_per_sample 426 | 427 | if format_tag == WAVE_FORMAT.PCM: 428 | if 1 <= bit_depth <= 8: 429 | dtype = "u1" # WAV of 8-bit integer or less are unsigned 430 | elif bytes_per_sample in {3, 5, 6, 7}: 431 | # No compatible dtype. Load as raw bytes for reshaping later. 432 | dtype = "V1" 433 | elif bit_depth <= 64: 434 | # Remaining bit depths can map directly to signed numpy dtypes 435 | dtype = f"{fmt}i{bytes_per_sample}" 436 | else: 437 | raise ValueError( 438 | "Unsupported bit depth: the WAV file " 439 | f"has {bit_depth}-bit integer data." 440 | ) 441 | elif format_tag == WAVE_FORMAT.IEEE_FLOAT: 442 | if bit_depth in {32, 64}: 443 | dtype = f"{fmt}f{bytes_per_sample}" 444 | else: 445 | raise ValueError( 446 | "Unsupported bit depth: the WAV file " 447 | f"has {bit_depth}-bit floating-point data." 448 | ) 449 | else: 450 | _raise_bad_format(format_tag) 451 | 452 | start = fid.tell() 453 | if not mmap: 454 | try: 455 | count = size if dtype == "V1" else n_samples 456 | data = numpy.fromfile(fid, dtype=dtype, count=count) 457 | except io.UnsupportedOperation: # not a C-like file 458 | fid.seek(start, 0) # just in case it seeked, though it shouldn't 459 | data = numpy.frombuffer(fid.read(size), dtype=dtype) 460 | 461 | if dtype == "V1": 462 | # Rearrange raw bytes into smallest compatible numpy dtype 463 | dt = f"{fmt}i4" if bytes_per_sample == 3 else f"{fmt}i8" 464 | a = numpy.zeros( 465 | (len(data) // bytes_per_sample, numpy.dtype(dt).itemsize), dtype="V1" 466 | ) 467 | if is_big_endian: 468 | a[:, :bytes_per_sample] = data.reshape((-1, bytes_per_sample)) 469 | else: 470 | a[:, -bytes_per_sample:] = data.reshape((-1, bytes_per_sample)) 471 | data = a.view(dt).reshape(a.shape[:-1]) 472 | else: 473 | if bytes_per_sample in {1, 2, 4, 8}: 474 | start = fid.tell() 475 | data = numpy.memmap( 476 | fid, dtype=dtype, mode="c", offset=start, shape=(n_samples,) 477 | ) 478 | fid.seek(start + size) 479 | else: 480 | raise ValueError( 481 | "mmap=True not compatible with " 482 | f"{bytes_per_sample}-byte container size." 483 | ) 484 | 485 | _handle_pad_byte(fid, size) 486 | 487 | if channels > 1: 488 | data = data.reshape(-1, channels) 489 | return data 490 | 491 | 492 | def _skip_unknown_chunk(fid, is_big_endian): 493 | if is_big_endian: 494 | fmt = ">I" 495 | else: 496 | fmt = ">> from os.path import dirname, join as pjoin 613 | >>> from scipy.io import wavfile 614 | >>> import scipy.io 615 | 616 | Get the filename for an example .wav file from the tests/data directory. 617 | 618 | >>> data_dir = pjoin(dirname(scipy.io.__file__), 'tests', 'data') 619 | >>> wav_fname = pjoin(data_dir, 'test-44100Hz-2ch-32bit-float-be.wav') 620 | 621 | Load the .wav file contents. 
622 | 
623 |     >>> samplerate, data = wavfile.read(wav_fname)
624 |     >>> print(f"number of channels = {data.shape[1]}")
625 |     number of channels = 2
626 |     >>> length = data.shape[0] / samplerate
627 |     >>> print(f"length = {length}s")
628 |     length = 0.01s
629 | 
630 |     Plot the waveform.
631 | 
632 |     >>> import matplotlib.pyplot as plt
633 |     >>> import numpy as np
634 |     >>> time = np.linspace(0., length, data.shape[0])
635 |     >>> plt.plot(time, data[:, 0], label="Left channel")
636 |     >>> plt.plot(time, data[:, 1], label="Right channel")
637 |     >>> plt.legend()
638 |     >>> plt.xlabel("Time [s]")
639 |     >>> plt.ylabel("Amplitude")
640 |     >>> plt.show()
641 | 
642 |     """
643 |     if hasattr(filename, "read"):
644 |         fid = filename
645 |         mmap = False
646 |     else:
647 |         # pylint: disable=consider-using-with
648 |         fid = open(filename, "rb")
649 | 
650 |     try:
651 |         file_size, is_big_endian = _read_riff_chunk(fid)
652 |         fmt_chunk_received = False
653 |         data_chunk_received = False
654 |         while fid.tell() < file_size:
655 |             # read the next chunk
656 |             chunk_id = fid.read(4)
657 | 
658 |             if not chunk_id:
659 |                 if data_chunk_received:
660 |                     # End of file but data successfully read
661 |                     warnings.warn(
662 |                         f"Reached EOF prematurely; finished at {fid.tell()} bytes, "
663 |                         f"expected {file_size} bytes from header.",
664 |                         WavFileWarning,
665 |                         stacklevel=2,
666 |                     )
667 |                     break
668 | 
669 |                 raise ValueError("Unexpected end of file.")
670 |             if len(chunk_id) < 4:
671 |                 msg = f"Incomplete chunk ID: {repr(chunk_id)}"
672 |                 # If we have the data, ignore the broken chunk
673 |                 if fmt_chunk_received and data_chunk_received:
674 |                     warnings.warn(msg + ", ignoring it.", WavFileWarning, stacklevel=2)
675 |                 else:
676 |                     raise ValueError(msg)
677 | 
678 |             if chunk_id == b"fmt ":
679 |                 fmt_chunk_received = True
680 |                 fmt_chunk = _read_fmt_chunk(fid, is_big_endian)
681 |                 format_tag, channels, fs = fmt_chunk[1:4]
682 |                 bit_depth = fmt_chunk[6]
683 |                 block_align = fmt_chunk[5]
684 |             elif chunk_id == b"fact":
685 |                 _skip_unknown_chunk(fid, is_big_endian)
686 |             elif chunk_id == b"data":
687 |                 data_chunk_received = True
688 |                 if not fmt_chunk_received:
689 |                     raise ValueError("No fmt chunk before data")
690 |                 data = _read_data_chunk(
691 |                     fid,
692 |                     format_tag,
693 |                     channels,
694 |                     bit_depth,
695 |                     is_big_endian,
696 |                     block_align,
697 |                     mmap,
698 |                 )
699 |             elif chunk_id == b"LIST":
700 |                 # Someday this could be handled properly but for now skip it
701 |                 _skip_unknown_chunk(fid, is_big_endian)
702 |             elif chunk_id in {b"JUNK", b"Fake"}:
703 |                 # Skip alignment chunks without warning
704 |                 _skip_unknown_chunk(fid, is_big_endian)
705 |             else:
706 |                 warnings.warn(
707 |                     "Chunk (non-data) not understood, skipping it.",
708 |                     WavFileWarning,
709 |                     stacklevel=2,
710 |                 )
711 |                 _skip_unknown_chunk(fid, is_big_endian)
712 |     finally:
713 |         if not hasattr(filename, "read"):
714 |             fid.close()
715 |         else:
716 |             fid.seek(0)
717 | 
718 |     return fs, data
719 | 
720 | 
721 | def write(filename, rate, data):
722 |     """
723 |     Write a NumPy array as a WAV file.
724 | 
725 |     Parameters
726 |     ----------
727 |     filename : string or open file handle
728 |         Output wav file.
729 |     rate : int
730 |         The sample rate (in samples/sec).
731 |     data : ndarray
732 |         A 1-D or 2-D NumPy array of either integer or float data-type.
733 | 
734 |     Notes
735 |     -----
736 |     * Writes a simple uncompressed WAV file.
737 |     * To write multiple-channels, use a 2-D array of shape
738 |       (Nsamples, Nchannels).
739 |     * The bits-per-sample and PCM/float will be determined by the data-type.
740 | 
741 |     Common data types: [1]_
742 | 
743 |     =====================  ===========  ===========  =============
744 |          WAV format            Min          Max       NumPy dtype
745 |     =====================  ===========  ===========  =============
746 |     32-bit floating-point  -1.0         +1.0         float32
747 |     32-bit PCM             -2147483648  +2147483647  int32
748 |     16-bit PCM             -32768       +32767       int16
749 |     8-bit PCM              0            255          uint8
750 |     =====================  ===========  ===========  =============
751 | 
752 |     Note that 8-bit PCM is unsigned.
753 | 
754 |     References
755 |     ----------
756 |     .. [1] IBM Corporation and Microsoft Corporation, "Multimedia Programming
757 |        Interface and Data Specifications 1.0", section "Data Format of the
758 |        Samples", August 1991
759 |        http://www.tactilemedia.com/info/MCI_Control_Info.html
760 | 
761 |     Examples
762 |     --------
763 |     Create a 100Hz sine wave, sampled at 44100Hz.
764 |     Write to 16-bit PCM, Mono.
765 | 
766 |     >>> from scipy.io.wavfile import write
767 |     >>> import numpy as np
768 |     >>> samplerate = 44100; fs = 100
769 |     >>> t = np.linspace(0., 1., samplerate)
770 |     >>> amplitude = np.iinfo(np.int16).max
771 |     >>> data = amplitude * np.sin(2. * np.pi * fs * t)
772 |     >>> write("example.wav", samplerate, data.astype(np.int16))
773 | 
774 |     """
775 |     if hasattr(filename, "write"):
776 |         fid = filename
777 |     else:
778 |         # pylint: disable=consider-using-with
779 |         fid = open(filename, "wb")
780 | 
781 |     fs = rate
782 | 
783 |     try:
784 |         dkind = data.dtype.kind
785 |         if not (
786 |             dkind == "i" or dkind == "f" or (dkind == "u" and data.dtype.itemsize == 1)
787 |         ):
788 |             raise ValueError(f"Unsupported data type '{data.dtype}'")
789 | 
790 |         header_data = b""
791 | 
792 |         header_data += b"RIFF"
793 |         header_data += b"\x00\x00\x00\x00"
794 |         header_data += b"WAVE"
795 | 
796 |         # fmt chunk
797 |         header_data += b"fmt "
798 |         if dkind == "f":
799 |             format_tag = WAVE_FORMAT.IEEE_FLOAT
800 |         else:
801 |             format_tag = WAVE_FORMAT.PCM
802 |         if data.ndim == 1:
803 |             channels = 1
804 |         else:
805 |             channels = data.shape[1]
806 |         bit_depth = data.dtype.itemsize * 8
807 |         bytes_per_second = fs * (bit_depth // 8) * channels
808 |         block_align = channels * (bit_depth // 8)
809 | 
810 |         fmt_chunk_data = struct.pack(
811 |             "<HHIIHH",
812 |             format_tag,
813 |             channels,
814 |             fs,
815 |             bytes_per_second,
816 |             block_align,
817 |             bit_depth,
818 |         )
819 |         if not (dkind == "i" or dkind == "u"):
820 |             # add cbSize field for non-PCM files
821 |             fmt_chunk_data += b"\x00\x00"
822 | 
823 |         header_data += struct.pack("<I", len(fmt_chunk_data))
824 |         header_data += fmt_chunk_data
825 | 
826 |         # fact chunk (non-PCM files)
827 |         if not (dkind == "i" or dkind == "u"):
828 |             header_data += b"fact"
829 |             header_data += struct.pack("<II", 4, data.shape[0])
830 | 
831 |         # check data size (needs to be immediately before the data chunk)
832 |         if ((len(header_data) - 4 - 4) + (4 + 4 + data.nbytes)) > 0xFFFFFFFF:
833 |             raise ValueError("Data exceeds wave file size limit")
834 | 
835 |         fid.write(header_data)
836 | 
837 |         # data chunk
838 |         fid.write(b"data")
839 |         fid.write(struct.pack("<I", data.nbytes))
840 |         if data.dtype.byteorder == ">" or (
841 |             data.dtype.byteorder == "=" and sys.byteorder == "big"
842 |         ):
843 |             data = data.byteswap()
844 |         _array_tofile(fid, data)
845 | 
846 |         # Determine file size and place it in correct
847 |         # position at start of the file.
848 |         size = fid.tell()
849 |         fid.seek(4)
850 |         fid.write(struct.pack("<I", size - 8))
851 | 
852 |     finally:
853 |         if not hasattr(filename, "write"):
854 |             fid.close()
855 |         else:
856 |             fid.seek(0)
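857 | 
858 | 
859 | # ---------------------------------------------------------------------------
860 | # Usage sketch (not part of the upstream SciPy module; added for illustration
861 | # only): round-trips a short 440 Hz sine tone through write() and read()
862 | # defined above. Assumes the module-level "import numpy" used throughout this
863 | # file; the output path "sine_demo.wav", the 22050 Hz rate, and the tone
864 | # parameters are arbitrary choices for the demo.
865 | # ---------------------------------------------------------------------------
866 | if __name__ == "__main__":
867 |     demo_rate = 22050
868 |     t = numpy.linspace(0.0, 1.0, demo_rate, endpoint=False)
869 |     tone = 0.5 * numpy.iinfo(numpy.int16).max * numpy.sin(2.0 * numpy.pi * 440.0 * t)
870 |     write("sine_demo.wav", demo_rate, tone.astype(numpy.int16))
871 | 
872 |     rate, data = read("sine_demo.wav")
873 |     print(f"rate={rate} Hz, samples={data.shape[0]}, dtype={data.dtype}")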