├── packages.txt
├── requirements.txt
├── km_phonemizer.npz
├── cog.yaml
├── .github
│   └── workflows
│       └── build_cog.yml
├── LICENSE
├── losses.py
├── config.json
├── khmer_phonemizer.py
├── .gitattributes
├── README.md
├── infer.py
├── export_onnx.py
├── infer_onnx.py
├── app.py
├── mel_processing.py
├── .gitignore
├── commons.py
├── utils.py
├── g2p.py
├── transforms.py
├── attentions.py
├── modules.py
├── models.py
└── wavfile.py
/packages.txt:
--------------------------------------------------------------------------------
1 | libsndfile1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.25.2
2 | librosa
3 | scipy
4 | torch
5 | torchvision
6 | Unidecode
7 | gradio
8 | khmernormalizer
9 | monotonic-align
--------------------------------------------------------------------------------
/km_phonemizer.npz:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b97935ae20ae5533baa2096a48d0728bf90291a488fceccbf7e700018bc38ffe
3 | size 34352402
4 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | gpu: true
3 | system_packages:
4 | - libsndfile1
5 | - ffmpeg
6 | python_packages:
7 | - numpy==1.25.2
8 | - librosa
9 | - scipy
10 | - torch==2.1.1
11 | - Unidecode
12 | - gradio
13 | - khmernormalizer
14 | - monotonic-align
--------------------------------------------------------------------------------
/.github/workflows/build_cog.yml:
--------------------------------------------------------------------------------
1 | name: Build Cog
2 |
3 | on:
4 | workflow_dispatch:
5 | push:
6 | branches:
7 | - main
8 |
9 | jobs:
10 | build:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Free Disk Space (Ubuntu)
14 | uses: jlumbroso/free-disk-space@main
15 | with:
16 | tool-cache: false
17 | android: false
18 | dotnet: false
19 | haskell: false
20 | large-packages: true
21 | docker-images: true
22 | swap-storage: true
23 |
24 | - name: Check out code
25 | uses: actions/checkout@v3
26 |
27 | - name: Setup Cog
28 | uses: replicate/setup-cog@v1
29 |
30 | - name: download weight
31 | run: curl -L https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth -o G_60000.pth
32 |
33 | - name: Build
34 | run: |
35 | cog build
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Seanghay Yath
4 | Copyright (c) 2021 Jaehyeon Kim
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def feature_loss(fmap_r, fmap_g):
4 | loss = 0
5 | for dr, dg in zip(fmap_r, fmap_g):
6 | for rl, gl in zip(dr, dg):
7 | rl = rl.float().detach()
8 | gl = gl.float()
9 | loss += torch.mean(torch.abs(rl - gl))
10 |
11 | return loss * 2
12 |
13 |
14 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
15 | loss = 0
16 | r_losses = []
17 | g_losses = []
18 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
19 | dr = dr.float()
20 | dg = dg.float()
21 | r_loss = torch.mean((1-dr)**2)
22 | g_loss = torch.mean(dg**2)
23 | loss += (r_loss + g_loss)
24 | r_losses.append(r_loss.item())
25 | g_losses.append(g_loss.item())
26 |
27 | return loss, r_losses, g_losses
28 |
29 |
30 | def generator_loss(disc_outputs):
31 | loss = 0
32 | gen_losses = []
33 | for dg in disc_outputs:
34 | dg = dg.float()
35 | l = torch.mean((1-dg)**2)
36 | gen_losses.append(l)
37 | loss += l
38 |
39 | return loss, gen_losses
40 |
41 |
42 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
43 | """
44 | z_p, logs_q: [b, h, t_t]
45 | m_p, logs_p: [b, h, t_t]
46 | """
47 | z_p = z_p.float()
48 | logs_q = logs_q.float()
49 | m_p = m_p.float()
50 | logs_p = logs_p.float()
51 | z_mask = z_mask.float()
52 |
53 | kl = logs_p - logs_q - 0.5
54 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
55 | kl = torch.sum(kl * z_mask)
56 | l = kl / torch.sum(z_mask)
57 | return l
58 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "train": {
3 | "log_interval": 200,
4 | "eval_interval": 1000,
5 | "seed": 1234,
6 | "epochs": 20000,
7 | "learning_rate": 2e-4,
8 | "betas": [0.8, 0.99],
9 | "eps": 1e-9,
10 | "batch_size": 32,
11 | "fp16_run": true,
12 | "lr_decay": 0.999875,
13 | "segment_size": 8192,
14 | "init_lr_ratio": 1,
15 | "warmup_epochs": 0,
16 | "c_mel": 45,
17 | "c_kl": 1.0
18 | },
19 | "data": {
20 | "training_files":"filelists/kheng_text_train_filelist.txt.cleaned",
21 | "validation_files":"filelists/kheng_text_test_filelist.txt.cleaned",
22 | "text_cleaners":["english_cleaners2"],
23 | "max_wav_value": 32768.0,
24 | "sampling_rate": 22050,
25 | "filter_length": 1024,
26 | "hop_length": 256,
27 | "win_length": 1024,
28 | "n_mel_channels": 80,
29 | "mel_fmin": 0.0,
30 | "mel_fmax": null,
31 | "add_blank": true,
32 | "n_speakers": 0,
33 | "cleaned_text": true
34 | },
35 | "model": {
36 | "inter_channels": 192,
37 | "hidden_channels": 192,
38 | "filter_channels": 768,
39 | "n_heads": 2,
40 | "n_layers": 6,
41 | "kernel_size": 3,
42 | "p_dropout": 0.1,
43 | "resblock": "1",
44 | "resblock_kernel_sizes": [3,7,11],
45 | "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
46 | "upsample_rates": [8,8,2,2],
47 | "upsample_initial_channel": 512,
48 | "upsample_kernel_sizes": [16,16,4,4],
49 | "n_layers_q": 3,
50 | "use_spectral_norm": false
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/khmer_phonemizer.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Khmer Phonemizer - A Free, Standalone, and Open-Source Khmer Grapheme-to-Phoneme Converter.
3 | """
4 | import os
5 | import csv
6 | from g2p import PhonetisaurusGraph
7 |
8 | def _read_lexicon_file(file):
9 | lexicon = {}
10 | with open(file) as infile:
11 | for line in csv.reader(infile, delimiter="\t"):
12 | word, phonemes = line
13 | word, phonemes = word.strip(), phonemes.strip().split()
14 | lexicon[word] = phonemes
15 | return lexicon
16 |
17 | _graph_file = os.path.join(os.path.dirname(__file__), "km_phonemizer.npz")
18 | _lexicon_file = os.path.join(os.path.dirname(__file__), "km_lexicon.tsv")
19 | _lexicon_dict = _read_lexicon_file(_lexicon_file)
20 | _graph = PhonetisaurusGraph.load(_graph_file, preload=False)
21 |
22 | def _phoneticize(word: str, beam: int, min_beam: int, beam_scale: float):
23 | results = _graph.g2p_one(word, beam=beam, min_beam=min_beam, beam_scale=beam_scale)
24 | results = list(results)
25 | if len(results) == 0:
26 | return None
27 | return results[0]
28 |
29 |
30 | def phonemize_single(
31 | word,
32 | beam: int = 500,
33 | min_beam: int = 100,
34 | beam_scale: float = 0.6,
35 | use_lexicon: bool = True,
36 | ):
37 | r"""
38 | Phonemize a single word. The word must match [a-zA-Z\u1780-\u17dd]+
39 | """
40 | if word is None:
41 | return None
42 | word = word.lower()
43 | if use_lexicon and word in _lexicon_dict:
44 | return _lexicon_dict[word]
45 | return _phoneticize(word, beam=beam, min_beam=min_beam, beam_scale=beam_scale)
46 |
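47 | 
48 | if __name__ == "__main__":
49 |     # Minimal usage sketch (not part of the library API): print the phoneme
50 |     # guess for one sample word. Assumes km_phonemizer.npz and km_lexicon.tsv
51 |     # sit next to this file; the actual phonemes depend on the lexicon/FST.
52 |     print(phonemize_single("ខ្មែរ"))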
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KLEA
2 |
3 | An open-source Khmer word-to-speech model. It synthesizes a single word at a time, not full sentences.
4 |
5 |
6 |
7 |
8 |
9 | ### 1. Setup
10 |
11 | ```shell
12 | pip install -r requirements.txt
13 | ```
14 |
15 | ### 2. Download Checkpoint
16 |
17 | [G_60000.pth](https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth)
18 |
19 | ```shell
20 | wget https://huggingface.co/spaces/seanghay/KLEA/resolve/main/G_60000.pth
21 | ```
22 |
23 | Place the checkpoint in the current directory.
24 |
25 | ### 3. Inference
26 |
27 | ```shell
28 | python infer.py "មនុស្សខ្មែរ"
29 | ```
30 |
31 | This will output a file called `audio.wav` in the current directory. The output audio sample rate is 22.05 kHz.
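
Under the hood, `infer.py` first converts the word into IPA-style phonemes with `khmer_phonemizer.phonemize_single`, joins them with spaces, and feeds the resulting phoneme string to the VITS synthesizer. A minimal sketch of the phonemization step (it returns a list of phoneme strings whose exact values depend on the bundled lexicon/FST):

```python
from khmer_phonemizer import phonemize_single

# Returns a list of phoneme strings, or None if the word cannot be phonemized.
phonemes = phonemize_single("មនុស្សខ្មែរ")
print(phonemes)
```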
32 |
33 | ### Gradio
34 |
35 | ```shell
36 | python app.py
37 | ```
38 |
39 |
40 | ### Colab
41 |
42 |
43 |
44 |
45 |
46 | ### Dataset
47 |
48 | This model was trained on the kheng.info dataset. You can find it at http://kheng.info or at https://hf.co/datasets/seanghay/khmer_kheng_info_speech.
49 |
50 | ## Reference
51 |
52 | - [VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://github.com/jaywalnut310/vits)
53 | - [kheng.info](https://kheng.info/about/) is an online audio dictionary for the Khmer language with over 3000 recordings. Kheng.info is backed by multiple dictionaries and a large text corpus, and supports search in English and Khmer with search results ordered by word frequency.
54 |
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | from models import SynthesizerTrn
2 | from scipy.io.wavfile import write
3 | from khmer_phonemizer import phonemize_single
4 | import utils
5 | import commons
6 | import torch
7 | import sys
8 |
9 | _pad = '_'
10 | _punctuation = '. '
11 | _letters_ipa = 'acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː'
12 |
13 | # Export all symbols:
14 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa)
15 |
16 |
17 | # Special symbol ids
18 | SPACE_ID = symbols.index(" ")
19 |
20 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
21 |
22 | def text_to_sequence(text):
23 | sequence = []
24 | for symbol in text:
25 | symbol_id = _symbol_to_id[symbol]
26 | sequence += [symbol_id]
27 | return sequence
28 |
29 |
30 | def get_text(text, hps):
31 | text_norm = text_to_sequence(text)
32 | if hps.data.add_blank:
33 | text_norm = commons.intersperse(text_norm, 0)
34 | text_norm = torch.LongTensor(text_norm)
35 | return text_norm
36 |
37 |
38 | hps = utils.get_hparams_from_file("config.json")
39 | net_g = SynthesizerTrn(
40 | len(symbols),
41 | hps.data.filter_length // 2 + 1,
42 | hps.train.segment_size // hps.data.hop_length,
43 | **hps.model
44 | )
45 |
46 | _ = net_g.eval()
47 | _ = utils.load_checkpoint("G_60000.pth", net_g, None)
48 |
49 | text = " ".join(phonemize_single(sys.argv[1]) + ["."])
50 | stn_tst = get_text(text, hps)
51 |
52 | with torch.no_grad():
53 | x_tst = stn_tst.unsqueeze(0)
54 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
55 | audio = (
56 | net_g.infer(
57 | x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
58 | )[0][0, 0]
59 | .data.cpu()
60 | .float()
61 | .numpy()
62 | )
63 | write("audio.wav", rate=hps.data.sampling_rate, data=audio)
64 | print("saved audio.wav")
--------------------------------------------------------------------------------
/export_onnx.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import torch
3 | import utils
4 | from models import SynthesizerTrn
5 |
6 | _pad = "_"
7 | _punctuation = ". "
8 | _letters_ipa = "acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː"
9 |
10 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa)
11 |
12 | hps = utils.get_hparams_from_file("config.json")
13 | net_g = SynthesizerTrn(
14 | len(symbols),
15 | hps.data.filter_length // 2 + 1,
16 | hps.train.segment_size // hps.data.hop_length,
17 | **hps.model
18 | )
19 |
20 | ckpt = torch.load("./G_60000.pth", map_location="cpu")
21 | net_g.load_state_dict(ckpt["model"])
22 | net_g.eval()
23 | net_g.dec.remove_weight_norm()
24 |
25 |
26 | def infer_forward(text, text_lengths, scales, sid=None):
27 | noise_scale = scales[0]
28 | length_scale = scales[1]
29 | noise_scale_w = scales[2]
30 | audio = net_g.infer(
31 | text,
32 | text_lengths,
33 | noise_scale=noise_scale,
34 | length_scale=length_scale,
35 | noise_scale_w=noise_scale_w,
36 | sid=sid,
37 | )[0].unsqueeze(1)
38 | return audio
39 |
40 |
41 | net_g.forward = infer_forward
42 |
43 | dummy_input_length = 50
44 |
45 | num_symbols = len(symbols)
46 | sequences = torch.randint(
47 | low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
48 | )
49 | sequence_lengths = torch.LongTensor([sequences.size(1)])
50 |
51 | # noise_scale, length_scale, noise_scale_w
52 | scales = torch.FloatTensor([0.667, 1.0, 0.8])
53 | dummy_input = (sequences, sequence_lengths, scales, None)
54 |
55 | torch.onnx.export(
56 | model=net_g,
57 | args=dummy_input,
58 | f=str("output.onnx"),
59 | verbose=False,
60 | opset_version=15,
61 | input_names=["input", "input_lengths", "scales", "sid"],
62 | output_names=["output"],
63 | dynamic_axes={
64 | "input": {0: "batch_size", 1: "phonemes"},
65 | "input_lengths": {0: "batch_size"},
66 | "output": {0: "batch_size", 1: "time"},
67 | },
68 | )
69 |
--------------------------------------------------------------------------------
/infer_onnx.py:
--------------------------------------------------------------------------------
1 | import onnxruntime
2 | import numpy as np
3 | from wavfile import write as write_wav
4 | from utils import get_hparams_from_file
5 | from commons import intersperse
6 | from khmer_phonemizer import phonemize_single
7 |
8 | def audio_float_to_int16(
9 | audio: np.ndarray, max_wav_value: float = 32767.0
10 | ) -> np.ndarray:
11 | """Normalize audio and convert to int16 range"""
12 | audio_norm = audio * (max_wav_value / max(0.01, np.max(np.abs(audio))))
13 | audio_norm = np.clip(audio_norm, -max_wav_value, max_wav_value)
14 | audio_norm = audio_norm.astype("int16")
15 | return audio_norm
16 |
17 | symbols = [
18 | "_",
19 | ".",
20 | " ",
21 | "a",
22 | "c",
23 | "e",
24 | "f",
25 | "h",
26 | "i",
27 | "j",
28 | "k",
29 | "l",
30 | "m",
31 | "n",
32 | "o",
33 | "p",
34 | "r",
35 | "s",
36 | "t",
37 | "u",
38 | "w",
39 | "z",
40 | "ĕ",
41 | "ŋ",
42 | "ŏ",
43 | "ŭ",
44 | "ɑ",
45 | "ɓ",
46 | "ɔ",
47 | "ɗ",
48 | "ə",
49 | "ɛ",
50 | "ɡ",
51 | "ɨ",
52 | "ɲ",
53 | "ʋ",
54 | "ʔ",
55 | "ʰ",
56 | "ː",
57 | ]
58 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
59 |
60 |
61 | def text_to_sequence(text):
62 | sequence = []
63 | for symbol in text:
64 | symbol_id = symbol_to_id[symbol]
65 | sequence += [symbol_id]
66 | return sequence
67 |
68 |
69 | def get_text(text, hps):
70 | text_norm = text_to_sequence(text)
71 | if hps.data.add_blank:
72 | text_norm = intersperse(text_norm, 0)
73 | return text_norm
74 |
75 |
76 | def infer():
77 | session_options = onnxruntime.SessionOptions()
78 | providers = ["CPUExecutionProvider"]
79 | model = onnxruntime.InferenceSession(
80 | "./output.onnx", sess_options=session_options, providers=providers
81 | )
82 |
83 | hps = get_hparams_from_file("config.json")
84 | text = " ".join(phonemize_single("ទិញបាយ") + ["."])
85 | stn_tst = get_text(text, hps)
86 |
87 | text = np.expand_dims(np.array(stn_tst, dtype=np.int64), 0)
88 | text_lengths = np.array([text.shape[1]], dtype=np.int64)
89 | scales = np.array(
90 | [0.667, 1, 0.8],
91 | dtype=np.float32,
92 | )
93 | sample_rate = 22050
94 | sid = None
95 | audio = model.run(
96 | None,
97 | {
98 | "input": text,
99 | "input_lengths": text_lengths,
100 | "scales": scales,
101 | "sid": sid,
102 | },
103 | )[0].squeeze((0, 1))
104 | audio = audio_float_to_int16(audio.squeeze())
105 | write_wav("audio.wav", sample_rate, audio)
106 |
107 |
108 | if __name__ == "__main__":
109 | infer()
110 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import gradio as gr
3 | from models import SynthesizerTrn
4 | from khmer_phonemizer import phonemize_single
5 | import utils
6 | import commons
7 | import torch
8 | import khmernormalizer
9 |
10 | _pad = "_"
11 | _punctuation = ". "
12 | _letters_ipa = "acefhijklmnoprstuwzĕŋŏŭɑɓɔɗəɛɡɨɲʋʔʰː"
13 |
14 | # Export all symbols:
15 | symbols = [_pad] + list(_punctuation) + list(_letters_ipa)
16 |
17 | # Special symbol ids
18 | SPACE_ID = symbols.index(" ")
19 |
20 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
21 |
22 |
23 | def text_to_sequence(text):
24 | sequence = []
25 | for symbol in text:
26 | symbol_id = _symbol_to_id[symbol]
27 | sequence += [symbol_id]
28 | return sequence
29 |
30 |
31 | def get_text(text, hps):
32 | text_norm = text_to_sequence(text)
33 |
34 | if hps.data.add_blank:
35 | text_norm = commons.intersperse(text_norm, 0)
36 | text_norm = torch.LongTensor(text_norm)
37 | return text_norm
38 |
39 |
40 | hps = utils.get_hparams_from_file("config.json")
41 | net_g = SynthesizerTrn(
42 | len(symbols),
43 | hps.data.filter_length // 2 + 1,
44 | hps.train.segment_size // hps.data.hop_length,
45 | **hps.model
46 | )
47 |
48 | _ = net_g.eval()
49 | _ = utils.load_checkpoint("G_60000.pth", net_g, None)
50 |
51 | def generate_voice(text):
52 | text = khmernormalizer.normalize(text)
53 | text = " ".join(phonemize_single(text) + ["."])
54 | stn_tst = get_text(text, hps)
55 | with torch.no_grad():
56 | x_tst = stn_tst.unsqueeze(0)
57 | x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
58 | audio = (
59 | net_g.infer(
60 | x_tst,
61 | x_tst_lengths,
62 | noise_scale=0.667,
63 | noise_scale_w=0.8,
64 | length_scale=1,
65 | )[0][0, 0]
66 | .data.cpu()
67 | .float()
68 | .numpy()
69 | )
70 |
71 | return (hps.data.sampling_rate, audio)
72 |
73 |
74 | with gr.Blocks(
75 | title="Khmer Word to Speech",
76 | theme=gr.themes.Default(
77 | font=[gr.themes.GoogleFont("Noto Sans Khmer"), "Arial", "sans-serif"]
78 | ),
79 | ) as blocks:
80 | gr.Markdown("# Khmer Word to Speech")
81 |
82 | input_text = gr.Text(label="ពាក្យខ្លី", lines=1)
83 | examples = gr.Examples(examples=["មនុស្សជាតិ", "ភ្នំព្រះ"], inputs=[input_text])
84 | run_button = gr.Button(value="បង្កើត")
85 |
86 | out_audio = gr.Audio(
87 | label="សំឡេងដែលបានបង្កើត",
88 | type="numpy",
89 | )
90 |
91 | inputs = [input_text]
92 | outputs = [out_audio]
93 |
94 | run_button.click(
95 | fn=generate_voice,
96 | inputs=inputs,
97 | outputs=outputs,
98 | queue=True,
99 | )
100 |
101 |
102 | blocks.queue(concurrency_count=1).launch(debug=True)
103 |
--------------------------------------------------------------------------------
/mel_processing.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data
3 | from librosa.filters import mel as librosa_mel_fn
4 |
5 | MAX_WAV_VALUE = 32768.0
6 |
7 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
8 | """
9 | PARAMS
10 | ------
11 | C: compression factor
12 | """
13 | return torch.log(torch.clamp(x, min=clip_val) * C)
14 |
15 |
16 | def dynamic_range_decompression_torch(x, C=1):
17 | """
18 | PARAMS
19 | ------
20 | C: compression factor used to compress
21 | """
22 | return torch.exp(x) / C
23 |
24 |
25 | def spectral_normalize_torch(magnitudes):
26 | output = dynamic_range_compression_torch(magnitudes)
27 | return output
28 |
29 |
30 | def spectral_de_normalize_torch(magnitudes):
31 | output = dynamic_range_decompression_torch(magnitudes)
32 | return output
33 |
34 |
35 | mel_basis = {}
36 | hann_window = {}
37 |
38 |
39 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
40 | if torch.min(y) < -1.:
41 | print('min value is ', torch.min(y))
42 | if torch.max(y) > 1.:
43 | print('max value is ', torch.max(y))
44 |
45 | global hann_window
46 | dtype_device = str(y.dtype) + '_' + str(y.device)
47 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
48 | if wnsize_dtype_device not in hann_window:
49 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
50 |
51 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
52 | y = y.squeeze(1)
53 |
54 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
55 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
56 |
57 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
58 | return spec
59 |
60 |
61 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
62 | global mel_basis
63 | dtype_device = str(spec.dtype) + '_' + str(spec.device)
64 | fmax_dtype_device = str(fmax) + '_' + dtype_device
65 | if fmax_dtype_device not in mel_basis:
66 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
67 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
68 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
69 | spec = spectral_normalize_torch(spec)
70 | return spec
71 |
72 |
73 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
74 | if torch.min(y) < -1.:
75 | print('min value is ', torch.min(y))
76 | if torch.max(y) > 1.:
77 | print('max value is ', torch.max(y))
78 |
79 | global mel_basis, hann_window
80 | dtype_device = str(y.dtype) + '_' + str(y.device)
81 | fmax_dtype_device = str(fmax) + '_' + dtype_device
82 | wnsize_dtype_device = str(win_size) + '_' + dtype_device
83 | if fmax_dtype_device not in mel_basis:
84 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
85 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
86 | if wnsize_dtype_device not in hann_window:
87 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
88 |
89 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
90 | y = y.squeeze(1)
91 |
92 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
93 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
94 |
95 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
96 |
97 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
98 | spec = spectral_normalize_torch(spec)
99 |
100 | return spec
101 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # General
2 | .DS_Store
3 | .AppleDouble
4 | .LSOverride
5 |
6 | # Icon must end with two \r
7 | Icon
8 |
9 |
10 | # Thumbnails
11 | ._*
12 |
13 | # Files that might appear in the root of a volume
14 | .DocumentRevisions-V100
15 | .fseventsd
16 | .Spotlight-V100
17 | .TemporaryItems
18 | .Trashes
19 | .VolumeIcon.icns
20 | .com.apple.timemachine.donotpresent
21 |
22 | # Directories potentially created on remote AFP share
23 | .AppleDB
24 | .AppleDesktop
25 | Network Trash Folder
26 | Temporary Items
27 | .apdisk
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | share/python-wheels/
52 | *.egg-info/
53 | .installed.cfg
54 | *.egg
55 | MANIFEST
56 |
57 | # PyInstaller
58 | # Usually these files are written by a python script from a template
59 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
60 | *.manifest
61 | *.spec
62 |
63 | # Installer logs
64 | pip-log.txt
65 | pip-delete-this-directory.txt
66 |
67 | # Unit test / coverage reports
68 | htmlcov/
69 | .tox/
70 | .nox/
71 | .coverage
72 | .coverage.*
73 | .cache
74 | nosetests.xml
75 | coverage.xml
76 | *.cover
77 | *.py,cover
78 | .hypothesis/
79 | .pytest_cache/
80 | cover/
81 |
82 | # Translations
83 | *.mo
84 | *.pot
85 |
86 | # Django stuff:
87 | *.log
88 | local_settings.py
89 | db.sqlite3
90 | db.sqlite3-journal
91 |
92 | # Flask stuff:
93 | instance/
94 | .webassets-cache
95 |
96 | # Scrapy stuff:
97 | .scrapy
98 |
99 | # Sphinx documentation
100 | docs/_build/
101 |
102 | # PyBuilder
103 | .pybuilder/
104 | target/
105 |
106 | # Jupyter Notebook
107 | .ipynb_checkpoints
108 |
109 | # IPython
110 | profile_default/
111 | ipython_config.py
112 |
113 | # pyenv
114 | # For a library or package, you might want to ignore these files since the code is
115 | # intended to run in multiple environments; otherwise, check them in:
116 | # .python-version
117 |
118 | # pipenv
119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
122 | # install all needed dependencies.
123 | #Pipfile.lock
124 |
125 | # poetry
126 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
127 | # This is especially recommended for binary packages to ensure reproducibility, and is more
128 | # commonly ignored for libraries.
129 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
130 | #poetry.lock
131 |
132 | # pdm
133 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
134 | #pdm.lock
135 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
136 | # in version control.
137 | # https://pdm.fming.dev/#use-with-ide
138 | .pdm.toml
139 |
140 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
141 | __pypackages__/
142 |
143 | # Celery stuff
144 | celerybeat-schedule
145 | celerybeat.pid
146 |
147 | # SageMath parsed files
148 | *.sage.py
149 |
150 | # Environments
151 | .env
152 | .venv
153 | env/
154 | venv/
155 | ENV/
156 | env.bak/
157 | venv.bak/
158 |
159 | # Spyder project settings
160 | .spyderproject
161 | .spyproject
162 |
163 | # Rope project settings
164 | .ropeproject
165 |
166 | # mkdocs documentation
167 | /site
168 |
169 | # mypy
170 | .mypy_cache/
171 | .dmypy.json
172 | dmypy.json
173 |
174 | # Pyre type checker
175 | .pyre/
176 |
177 | # pytype static type analyzer
178 | .pytype/
179 |
180 | # Cython debug symbols
181 | cython_debug/
182 |
183 | # PyCharm
184 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
185 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
186 | # and can be added to the global gitignore or merged into this file. For a more nuclear
187 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
188 | #.idea/
189 | *.wav
190 | *.pth
191 | *.onnx
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.nn import functional as F
4 |
5 | def init_weights(m, mean=0.0, std=0.01):
6 | classname = m.__class__.__name__
7 | if classname.find("Conv") != -1:
8 | m.weight.data.normal_(mean, std)
9 |
10 |
11 | def get_padding(kernel_size, dilation=1):
12 | return int((kernel_size*dilation - dilation)/2)
13 |
14 |
15 | def convert_pad_shape(pad_shape):
16 | l = pad_shape[::-1]
17 | pad_shape = [item for sublist in l for item in sublist]
18 | return pad_shape
19 |
20 |
21 | def intersperse(lst, item):
22 | result = [item] * (len(lst) * 2 + 1)
23 | result[1::2] = lst
24 | return result
25 |
26 |
27 | def kl_divergence(m_p, logs_p, m_q, logs_q):
28 | """KL(P||Q)"""
29 | kl = (logs_q - logs_p) - 0.5
30 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. * logs_q)
31 | return kl
32 |
33 |
34 | def rand_gumbel(shape):
35 | """Sample from the Gumbel distribution, protect from overflows."""
36 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
37 | return -torch.log(-torch.log(uniform_samples))
38 |
39 |
40 | def rand_gumbel_like(x):
41 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
42 | return g
43 |
44 |
45 | def slice_segments(x, ids_str, segment_size=4):
46 | ret = torch.zeros_like(x[:, :, :segment_size])
47 | for i in range(x.size(0)):
48 | idx_str = ids_str[i]
49 | idx_end = idx_str + segment_size
50 | ret[i] = x[i, :, idx_str:idx_end]
51 | return ret
52 |
53 |
54 | def rand_slice_segments(x, x_lengths=None, segment_size=4):
55 | b, d, t = x.size()
56 | if x_lengths is None:
57 | x_lengths = t
58 | ids_str_max = x_lengths - segment_size + 1
59 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
60 | ret = slice_segments(x, ids_str, segment_size)
61 | return ret, ids_str
62 |
63 |
64 | def get_timing_signal_1d(
65 | length, channels, min_timescale=1.0, max_timescale=1.0e4):
66 | position = torch.arange(length, dtype=torch.float)
67 | num_timescales = channels // 2
68 | log_timescale_increment = (
69 | math.log(float(max_timescale) / float(min_timescale)) /
70 | (num_timescales - 1))
71 | inv_timescales = min_timescale * torch.exp(
72 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
73 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
74 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
75 | signal = F.pad(signal, [0, 0, 0, channels % 2])
76 | signal = signal.view(1, channels, length)
77 | return signal
78 |
79 |
80 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
81 | b, channels, length = x.size()
82 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
83 | return x + signal.to(dtype=x.dtype, device=x.device)
84 |
85 |
86 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
87 | b, channels, length = x.size()
88 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
89 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
90 |
91 |
92 | def subsequent_mask(length):
93 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
94 | return mask
95 |
96 |
97 | @torch.jit.script
98 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
99 | n_channels_int = n_channels[0]
100 | in_act = input_a + input_b
101 | t_act = torch.tanh(in_act[:, :n_channels_int, :])
102 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
103 | acts = t_act * s_act
104 | return acts
105 |
106 |
107 | def convert_pad_shape(pad_shape):
108 | l = pad_shape[::-1]
109 | pad_shape = [item for sublist in l for item in sublist]
110 | return pad_shape
111 |
112 |
113 | def shift_1d(x):
114 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
115 | return x
116 |
117 |
118 | def sequence_mask(length, max_length=None):
119 | if max_length is None:
120 | max_length = length.max()
121 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
122 | return x.unsqueeze(0) < length.unsqueeze(1)
123 |
124 |
125 | def generate_path(duration, mask):
126 | """
127 | duration: [b, 1, t_x]
128 | mask: [b, 1, t_y, t_x]
129 | """
130 | device = duration.device
131 |
132 | b, _, t_y, t_x = mask.shape
133 | cum_duration = torch.cumsum(duration, -1)
134 |
135 | cum_duration_flat = cum_duration.view(b * t_x)
136 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
137 | path = path.view(b, t_x, t_y)
138 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
139 | path = path.unsqueeze(1).transpose(2,3) * mask
140 | return path
141 |
142 |
143 | def clip_grad_value_(parameters, clip_value, norm_type=2):
144 | if isinstance(parameters, torch.Tensor):
145 | parameters = [parameters]
146 | parameters = list(filter(lambda p: p.grad is not None, parameters))
147 | norm_type = float(norm_type)
148 | if clip_value is not None:
149 | clip_value = float(clip_value)
150 |
151 | total_norm = 0
152 | for p in parameters:
153 | param_norm = p.grad.data.norm(norm_type)
154 | total_norm += param_norm.item() ** norm_type
155 | if clip_value is not None:
156 | p.grad.data.clamp_(min=-clip_value, max=clip_value)
157 | total_norm = total_norm ** (1. / norm_type)
158 | return total_norm
159 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import sys
4 | import argparse
5 | import logging
6 | import json
7 | import subprocess
8 | import numpy as np
9 | from scipy.io.wavfile import read
10 | import torch
11 |
12 | MATPLOTLIB_FLAG = False
13 |
14 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15 | logger = logging
16 |
17 |
18 | def load_checkpoint(checkpoint_path, model, optimizer=None):
19 | assert os.path.isfile(checkpoint_path)
20 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21 | iteration = checkpoint_dict['iteration']
22 | learning_rate = checkpoint_dict['learning_rate']
23 | if optimizer is not None:
24 | optimizer.load_state_dict(checkpoint_dict['optimizer'])
25 | saved_state_dict = checkpoint_dict['model']
26 | if hasattr(model, 'module'):
27 | state_dict = model.module.state_dict()
28 | else:
29 | state_dict = model.state_dict()
30 | new_state_dict= {}
31 | for k, v in state_dict.items():
32 | try:
33 | new_state_dict[k] = saved_state_dict[k]
34 | except:
35 | logger.info("%s is not in the checkpoint" % k)
36 | new_state_dict[k] = v
37 | if hasattr(model, 'module'):
38 | model.module.load_state_dict(new_state_dict)
39 | else:
40 | model.load_state_dict(new_state_dict)
41 | logger.info("Loaded checkpoint '{}' (iteration {})" .format(
42 | checkpoint_path, iteration))
43 | return model, optimizer, learning_rate, iteration
44 |
45 |
46 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
47 | logger.info("Saving model and optimizer state at iteration {} to {}".format(
48 | iteration, checkpoint_path))
49 | if hasattr(model, 'module'):
50 | state_dict = model.module.state_dict()
51 | else:
52 | state_dict = model.state_dict()
53 | torch.save({'model': state_dict,
54 | 'iteration': iteration,
55 | 'optimizer': optimizer.state_dict(),
56 | 'learning_rate': learning_rate}, checkpoint_path)
57 |
58 |
59 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
60 | for k, v in scalars.items():
61 | writer.add_scalar(k, v, global_step)
62 | for k, v in histograms.items():
63 | writer.add_histogram(k, v, global_step)
64 | for k, v in images.items():
65 | writer.add_image(k, v, global_step, dataformats='HWC')
66 | for k, v in audios.items():
67 | writer.add_audio(k, v, global_step, audio_sampling_rate)
68 |
69 |
70 | def latest_checkpoint_path(dir_path, regex="G_*.pth"):
71 | f_list = glob.glob(os.path.join(dir_path, regex))
72 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
73 | x = f_list[-1]
74 | print(x)
75 | return x
76 |
77 | def load_wav_to_torch(full_path):
78 | sampling_rate, data = read(full_path)
79 | return torch.FloatTensor(data.astype(np.float32)), sampling_rate
80 |
81 |
82 | def load_filepaths_and_text(filename, split="|"):
83 | with open(filename, encoding='utf-8') as f:
84 | filepaths_and_text = [line.strip().split(split) for line in f]
85 | return filepaths_and_text
86 |
87 |
88 | def get_hparams(init=True):
89 | parser = argparse.ArgumentParser()
90 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
91 | help='JSON file for configuration')
92 | parser.add_argument('-m', '--model', type=str, required=True,
93 | help='Model name')
94 |
95 | args = parser.parse_args()
96 | model_dir = os.path.join("./logs", args.model)
97 |
98 | if not os.path.exists(model_dir):
99 | os.makedirs(model_dir)
100 |
101 | config_path = args.config
102 | config_save_path = os.path.join(model_dir, "config.json")
103 | if init:
104 | with open(config_path, "r") as f:
105 | data = f.read()
106 | with open(config_save_path, "w") as f:
107 | f.write(data)
108 | else:
109 | with open(config_save_path, "r") as f:
110 | data = f.read()
111 | config = json.loads(data)
112 |
113 | hparams = HParams(**config)
114 | hparams.model_dir = model_dir
115 | return hparams
116 |
117 |
118 | def get_hparams_from_dir(model_dir):
119 | config_save_path = os.path.join(model_dir, "config.json")
120 | with open(config_save_path, "r") as f:
121 | data = f.read()
122 | config = json.loads(data)
123 |
124 | hparams =HParams(**config)
125 | hparams.model_dir = model_dir
126 | return hparams
127 |
128 |
129 | def get_hparams_from_file(config_path):
130 | with open(config_path, "r") as f:
131 | data = f.read()
132 | config = json.loads(data)
133 |
134 | hparams =HParams(**config)
135 | return hparams
136 |
137 |
138 | def check_git_hash(model_dir):
139 | source_dir = os.path.dirname(os.path.realpath(__file__))
140 | if not os.path.exists(os.path.join(source_dir, ".git")):
141 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
142 | source_dir
143 | ))
144 | return
145 |
146 | cur_hash = subprocess.getoutput("git rev-parse HEAD")
147 |
148 | path = os.path.join(model_dir, "githash")
149 | if os.path.exists(path):
150 | saved_hash = open(path).read()
151 | if saved_hash != cur_hash:
152 | logger.warn("git hash values are different. {}(saved) != {}(current)".format(
153 | saved_hash[:8], cur_hash[:8]))
154 | else:
155 | open(path, "w").write(cur_hash)
156 |
157 |
158 | def get_logger(model_dir, filename="train.log"):
159 | global logger
160 | logger = logging.getLogger(os.path.basename(model_dir))
161 | logger.setLevel(logging.DEBUG)
162 |
163 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
164 | if not os.path.exists(model_dir):
165 | os.makedirs(model_dir)
166 | h = logging.FileHandler(os.path.join(model_dir, filename))
167 | h.setLevel(logging.DEBUG)
168 | h.setFormatter(formatter)
169 | logger.addHandler(h)
170 | return logger
171 |
172 |
173 | class HParams():
174 | def __init__(self, **kwargs):
175 | for k, v in kwargs.items():
176 | if type(v) == dict:
177 | v = HParams(**v)
178 | self[k] = v
179 |
180 | def keys(self):
181 | return self.__dict__.keys()
182 |
183 | def items(self):
184 | return self.__dict__.items()
185 |
186 | def values(self):
187 | return self.__dict__.values()
188 |
189 | def __len__(self):
190 | return len(self.__dict__)
191 |
192 | def __getitem__(self, key):
193 | return getattr(self, key)
194 |
195 | def __setitem__(self, key, value):
196 | return setattr(self, key, value)
197 |
198 | def __contains__(self, key):
199 | return key in self.__dict__
200 |
201 | def __repr__(self):
202 | return self.__dict__.__repr__()
203 |
--------------------------------------------------------------------------------
/g2p.py:
--------------------------------------------------------------------------------
1 | """
2 | Guess word pronunciations using a Phonetisaurus FST
3 |
4 | See bin/fst2npz.py to convert an FST to a numpy graph.
5 |
6 | Reference:
7 | https://github.com/rhasspy/gruut/blob/master/gruut/g2p_phonetisaurus.py
8 | """
9 | import typing
10 | from collections import defaultdict
11 | from pathlib import Path
12 | import numpy as np
13 |
14 | NUMPY_GRAPH = typing.Dict[str, np.ndarray]
15 | _NOT_FINAL = object()
16 |
17 | class PhonetisaurusGraph:
18 | """Graph of numpy arrays that represents a Phonetisaurus FST
19 |
20 | Also contains shared cache of edges and final state probabilities.
21 | These caches are necessary to ensure that the .npz file stays small and fast
22 | to load.
23 | """
24 |
25 | def __init__(self, graph: NUMPY_GRAPH, preload: bool = False):
26 | self.graph = graph
27 |
28 | self.start_node = int(self.graph["start_node"].item())
29 |
30 | # edge_index -> (from_node, to_node, ilabel, olabel)
31 | self.edges = self.graph["edges"]
32 | self.edge_probs = self.graph["edge_probs"]
33 |
34 | # int -> [str]
35 | self.symbols = []
36 | for symbol_str in self.graph["symbols"]:
37 | symbol_list = symbol_str.replace("_", "").split("|")
38 | self.symbols.append((len(symbol_list), symbol_list))
39 |
40 | # nodes that are accepting states
41 | self.final_nodes = self.graph["final_nodes"]
42 |
43 | # node -> probability
44 | self.final_probs = self.graph["final_probs"]
45 |
46 | # Cache
47 | self.preloaded = preload
48 | self.out_edges: typing.Dict[int, typing.List[int]] = defaultdict(list)
49 | self.final_node_probs: typing.Dict[int, typing.Any] = {}
50 |
51 | if preload:
52 | # Load out edges
53 | for edge_idx, (from_node, *_) in enumerate(self.edges):
54 | self.out_edges[from_node].append(edge_idx)
55 |
56 | # Load final probabilities
57 | self.final_node_probs.update(zip(self.final_nodes, self.final_probs))
58 |
59 | @staticmethod
60 | def load(graph_path: typing.Union[str, Path], **kwargs) -> "PhonetisaurusGraph":
61 | """Load .npz file with numpy graph"""
62 | np_graph = np.load(graph_path, allow_pickle=True)
63 | return PhonetisaurusGraph(np_graph, **kwargs)
64 |
65 | def g2p_one(
66 | self,
67 | word: typing.Union[str, typing.Sequence[str]],
68 | eps: str = "",
69 | beam: int = 5000,
70 | min_beam: int = 100,
71 | beam_scale: float = 0.6,
72 | grapheme_separator: str = "",
73 | max_guesses: int = 1,
74 | ) -> typing.Iterable[typing.Tuple[typing.Sequence[str], typing.Sequence[str]]]:
75 | """Guess phonemes for word"""
76 | current_beam = beam
77 | graphemes: typing.Sequence[str] = []
78 |
79 | if isinstance(word, str):
80 | word = word.strip()
81 |
82 | if grapheme_separator:
83 | graphemes = word.split(grapheme_separator)
84 | else:
85 | graphemes = list(word)
86 | else:
87 | graphemes = word
88 |
89 | if not graphemes:
90 | return []
91 |
92 |         # (prob, node, graphemes, phonemes, is_final)
93 | q: typing.List[
94 | typing.Tuple[
95 | float,
96 | typing.Optional[int],
97 | typing.Sequence[str],
98 | typing.List[str],
99 | bool,
100 | ]
101 | ] = [(0.0, self.start_node, graphemes, [], False)]
102 |
103 | q_next: typing.List[
104 | typing.Tuple[
105 | float,
106 | typing.Optional[int],
107 | typing.Sequence[str],
108 | typing.List[str],
109 | bool,
110 | ]
111 | ] = []
112 |
113 | # (prob, phonemes)
114 | best_heap: typing.List[typing.Tuple[float, typing.Sequence[str]]] = []
115 |
116 | # Avoid duplicate guesses
117 | guessed_phonemes: typing.Set[typing.Tuple[str, ...]] = set()
118 |
119 | while q:
120 | done_with_word = False
121 | q_next = []
122 |
123 | for prob, node, next_graphemes, output, is_final in q:
124 | if is_final:
125 | # Complete guess
126 | phonemes = tuple(output)
127 | if phonemes not in guessed_phonemes:
128 | best_heap.append((prob, phonemes))
129 | guessed_phonemes.add(phonemes)
130 |
131 | if len(best_heap) >= max_guesses:
132 | done_with_word = True
133 | break
134 |
135 | continue
136 |
137 | assert node is not None
138 |
139 | if not next_graphemes:
140 | if self.preloaded:
141 | final_prob = self.final_node_probs.get(node, _NOT_FINAL)
142 | else:
143 | final_prob = self.final_node_probs.get(node)
144 | if final_prob is None:
145 | final_idx = int(np.searchsorted(self.final_nodes, node))
146 | if self.final_nodes[final_idx] == node:
147 | # Cache
148 | final_prob = float(self.final_probs[final_idx])
149 | self.final_node_probs[node] = final_prob
150 | else:
151 | # Not a final state
152 | final_prob = _NOT_FINAL
153 | self.final_node_probs[node] = final_prob
154 |
155 | if final_prob != _NOT_FINAL:
156 | final_prob = typing.cast(float, final_prob)
157 | q_next.append((prob + final_prob, None, [], output, True))
158 |
159 | len_next_graphemes = len(next_graphemes)
160 | if self.preloaded:
161 | # Was pre-loaded in __init__
162 | edge_idxs = self.out_edges[node]
163 | else:
164 | # Build cache during search
165 | maybe_edge_idxs = self.out_edges.get(node)
166 | if maybe_edge_idxs is None:
167 | edge_idx = int(np.searchsorted(self.edges[:, 0], node))
168 | edge_idxs = []
169 | while self.edges[edge_idx][0] == node:
170 | edge_idxs.append(edge_idx)
171 | edge_idx += 1
172 |
173 | # Cache
174 | self.out_edges[node] = edge_idxs
175 | else:
176 | edge_idxs = maybe_edge_idxs
177 |
178 | for edge_idx in edge_idxs:
179 | _, to_node, ilabel_idx, olabel_idx = self.edges[edge_idx]
180 | out_prob = self.edge_probs[edge_idx]
181 |
182 | len_igraphemes, igraphemes = self.symbols[ilabel_idx]
183 |
184 | if len_igraphemes > len_next_graphemes:
185 | continue
186 |
187 | if igraphemes == [eps]:
188 | item = (prob + out_prob, to_node, next_graphemes, output, False)
189 | q_next.append(item)
190 | else:
191 | sub_graphemes = next_graphemes[:len_igraphemes]
192 | if igraphemes == sub_graphemes:
193 | _, olabel = self.symbols[olabel_idx]
194 | item = (
195 | prob + out_prob,
196 | to_node,
197 | next_graphemes[len(sub_graphemes) :],
198 | output + olabel,
199 | False,
200 | )
201 | q_next.append(item)
202 |
203 | if done_with_word:
204 | break
205 |
206 | q_next = sorted(q_next, key=lambda item: item[0])[:current_beam]
207 | q = q_next
208 |
209 | current_beam = max(min_beam, (int(current_beam * beam_scale)))
210 |
211 | # Yield guesses
212 | if best_heap:
213 | for _, guess_phonemes in sorted(best_heap, key=lambda item: item[0])[
214 | :max_guesses
215 | ]:
216 | yield [p for p in guess_phonemes if p]
217 | else:
218 | # No guesses
219 | yield []
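220 | 
221 | 
222 | if __name__ == "__main__":
223 |     # Minimal usage sketch (not part of the library API): load the FST graph
224 |     # shipped with this repo and print the best phoneme guess for a sample
225 |     # word. Beam settings mirror khmer_phonemizer.py; the actual output
226 |     # depends on km_phonemizer.npz.
227 |     graph = PhonetisaurusGraph.load("km_phonemizer.npz", preload=False)
228 |     for phonemes in graph.g2p_one("ខ្មែរ", beam=500, min_beam=100, beam_scale=0.6):
229 |         print(phonemes)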
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 | import numpy as np
5 |
6 |
7 | DEFAULT_MIN_BIN_WIDTH = 1e-3
8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3
9 | DEFAULT_MIN_DERIVATIVE = 1e-3
10 |
11 |
12 | def piecewise_rational_quadratic_transform(inputs,
13 | unnormalized_widths,
14 | unnormalized_heights,
15 | unnormalized_derivatives,
16 | inverse=False,
17 | tails=None,
18 | tail_bound=1.,
19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
21 | min_derivative=DEFAULT_MIN_DERIVATIVE):
22 |
23 | if tails is None:
24 | spline_fn = rational_quadratic_spline
25 | spline_kwargs = {}
26 | else:
27 | spline_fn = unconstrained_rational_quadratic_spline
28 | spline_kwargs = {
29 | 'tails': tails,
30 | 'tail_bound': tail_bound
31 | }
32 |
33 | outputs, logabsdet = spline_fn(
34 | inputs=inputs,
35 | unnormalized_widths=unnormalized_widths,
36 | unnormalized_heights=unnormalized_heights,
37 | unnormalized_derivatives=unnormalized_derivatives,
38 | inverse=inverse,
39 | min_bin_width=min_bin_width,
40 | min_bin_height=min_bin_height,
41 | min_derivative=min_derivative,
42 | **spline_kwargs
43 | )
44 | return outputs, logabsdet
45 |
46 |
47 | def searchsorted(bin_locations, inputs, eps=1e-6):
48 | bin_locations[..., -1] += eps
49 | return torch.sum(
50 | inputs[..., None] >= bin_locations,
51 | dim=-1
52 | ) - 1
53 |
54 |
55 | def unconstrained_rational_quadratic_spline(inputs,
56 | unnormalized_widths,
57 | unnormalized_heights,
58 | unnormalized_derivatives,
59 | inverse=False,
60 | tails='linear',
61 | tail_bound=1.,
62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
64 | min_derivative=DEFAULT_MIN_DERIVATIVE):
65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
66 | outside_interval_mask = ~inside_interval_mask
67 |
68 | outputs = torch.zeros_like(inputs)
69 | logabsdet = torch.zeros_like(inputs)
70 |
71 | if tails == 'linear':
72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
73 | constant = np.log(np.exp(1 - min_derivative) - 1)
74 | unnormalized_derivatives[..., 0] = constant
75 | unnormalized_derivatives[..., -1] = constant
76 |
77 | outputs[outside_interval_mask] = inputs[outside_interval_mask]
78 | logabsdet[outside_interval_mask] = 0
79 | else:
80 | raise RuntimeError('{} tails are not implemented.'.format(tails))
81 |
82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
83 | inputs=inputs[inside_interval_mask],
84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87 | inverse=inverse,
88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound,
89 | min_bin_width=min_bin_width,
90 | min_bin_height=min_bin_height,
91 | min_derivative=min_derivative
92 | )
93 |
94 | return outputs, logabsdet
95 |
96 | def rational_quadratic_spline(inputs,
97 | unnormalized_widths,
98 | unnormalized_heights,
99 | unnormalized_derivatives,
100 | inverse=False,
101 | left=0., right=1., bottom=0., top=1.,
102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH,
103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
104 | min_derivative=DEFAULT_MIN_DERIVATIVE):
105 | if torch.min(inputs) < left or torch.max(inputs) > right:
106 | raise ValueError('Input to a transform is not within its domain')
107 |
108 | num_bins = unnormalized_widths.shape[-1]
109 |
110 | if min_bin_width * num_bins > 1.0:
111 | raise ValueError('Minimal bin width too large for the number of bins')
112 | if min_bin_height * num_bins > 1.0:
113 | raise ValueError('Minimal bin height too large for the number of bins')
114 |
115 | widths = F.softmax(unnormalized_widths, dim=-1)
116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
117 | cumwidths = torch.cumsum(widths, dim=-1)
118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
119 | cumwidths = (right - left) * cumwidths + left
120 | cumwidths[..., 0] = left
121 | cumwidths[..., -1] = right
122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1]
123 |
124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives)
125 |
126 | heights = F.softmax(unnormalized_heights, dim=-1)
127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
128 | cumheights = torch.cumsum(heights, dim=-1)
129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
130 | cumheights = (top - bottom) * cumheights + bottom
131 | cumheights[..., 0] = bottom
132 | cumheights[..., -1] = top
133 | heights = cumheights[..., 1:] - cumheights[..., :-1]
134 |
135 | if inverse:
136 | bin_idx = searchsorted(cumheights, inputs)[..., None]
137 | else:
138 | bin_idx = searchsorted(cumwidths, inputs)[..., None]
139 |
140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
142 |
143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
144 | delta = heights / widths
145 | input_delta = delta.gather(-1, bin_idx)[..., 0]
146 |
147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
149 |
150 | input_heights = heights.gather(-1, bin_idx)[..., 0]
151 |
152 | if inverse:
153 | a = (((inputs - input_cumheights) * (input_derivatives
154 | + input_derivatives_plus_one
155 | - 2 * input_delta)
156 | + input_heights * (input_delta - input_derivatives)))
157 | b = (input_heights * input_derivatives
158 | - (inputs - input_cumheights) * (input_derivatives
159 | + input_derivatives_plus_one
160 | - 2 * input_delta))
161 | c = - input_delta * (inputs - input_cumheights)
162 |
163 | discriminant = b.pow(2) - 4 * a * c
164 | assert (discriminant >= 0).all()
165 |
166 | root = (2 * c) / (-b - torch.sqrt(discriminant))
167 | outputs = root * input_bin_widths + input_cumwidths
168 |
169 | theta_one_minus_theta = root * (1 - root)
170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
171 | * theta_one_minus_theta)
172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
173 | + 2 * input_delta * theta_one_minus_theta
174 | + input_derivatives * (1 - root).pow(2))
175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
176 |
177 | return outputs, -logabsdet
178 | else:
179 | theta = (inputs - input_cumwidths) / input_bin_widths
180 | theta_one_minus_theta = theta * (1 - theta)
181 |
182 | numerator = input_heights * (input_delta * theta.pow(2)
183 | + input_derivatives * theta_one_minus_theta)
184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
185 | * theta_one_minus_theta)
186 | outputs = input_cumheights + numerator / denominator
187 |
188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
189 | + 2 * input_delta * theta_one_minus_theta
190 | + input_derivatives * (1 - theta).pow(2))
191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
192 |
193 | return outputs, logabsdet
194 |
--------------------------------------------------------------------------------
/attentions.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | import commons
6 | import modules
7 | from modules import LayerNorm
8 |
9 |
10 | class Encoder(nn.Module):
11 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
12 | super().__init__()
13 | self.hidden_channels = hidden_channels
14 | self.filter_channels = filter_channels
15 | self.n_heads = n_heads
16 | self.n_layers = n_layers
17 | self.kernel_size = kernel_size
18 | self.p_dropout = p_dropout
19 | self.window_size = window_size
20 |
21 | self.drop = nn.Dropout(p_dropout)
22 | self.attn_layers = nn.ModuleList()
23 | self.norm_layers_1 = nn.ModuleList()
24 | self.ffn_layers = nn.ModuleList()
25 | self.norm_layers_2 = nn.ModuleList()
26 | for i in range(self.n_layers):
27 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
28 | self.norm_layers_1.append(LayerNorm(hidden_channels))
29 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
30 | self.norm_layers_2.append(LayerNorm(hidden_channels))
31 |
32 | def forward(self, x, x_mask):
33 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
34 | x = x * x_mask
35 | for i in range(self.n_layers):
36 | y = self.attn_layers[i](x, x, attn_mask)
37 | y = self.drop(y)
38 | x = self.norm_layers_1[i](x + y)
39 |
40 | y = self.ffn_layers[i](x, x_mask)
41 | y = self.drop(y)
42 | x = self.norm_layers_2[i](x + y)
43 | x = x * x_mask
44 | return x
45 |
46 |
47 | class Decoder(nn.Module):
48 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
49 | super().__init__()
50 | self.hidden_channels = hidden_channels
51 | self.filter_channels = filter_channels
52 | self.n_heads = n_heads
53 | self.n_layers = n_layers
54 | self.kernel_size = kernel_size
55 | self.p_dropout = p_dropout
56 | self.proximal_bias = proximal_bias
57 | self.proximal_init = proximal_init
58 |
59 | self.drop = nn.Dropout(p_dropout)
60 | self.self_attn_layers = nn.ModuleList()
61 | self.norm_layers_0 = nn.ModuleList()
62 | self.encdec_attn_layers = nn.ModuleList()
63 | self.norm_layers_1 = nn.ModuleList()
64 | self.ffn_layers = nn.ModuleList()
65 | self.norm_layers_2 = nn.ModuleList()
66 | for i in range(self.n_layers):
67 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init))
68 | self.norm_layers_0.append(LayerNorm(hidden_channels))
69 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
70 | self.norm_layers_1.append(LayerNorm(hidden_channels))
71 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
72 | self.norm_layers_2.append(LayerNorm(hidden_channels))
73 |
74 | def forward(self, x, x_mask, h, h_mask):
75 | """
76 | x: decoder input
77 | h: encoder output
78 | """
79 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
80 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
81 | x = x * x_mask
82 | for i in range(self.n_layers):
83 | y = self.self_attn_layers[i](x, x, self_attn_mask)
84 | y = self.drop(y)
85 | x = self.norm_layers_0[i](x + y)
86 |
87 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
88 | y = self.drop(y)
89 | x = self.norm_layers_1[i](x + y)
90 |
91 | y = self.ffn_layers[i](x, x_mask)
92 | y = self.drop(y)
93 | x = self.norm_layers_2[i](x + y)
94 | x = x * x_mask
95 | return x
96 |
97 |
98 | class MultiHeadAttention(nn.Module):
99 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
100 | super().__init__()
101 | assert channels % n_heads == 0
102 |
103 | self.channels = channels
104 | self.out_channels = out_channels
105 | self.n_heads = n_heads
106 | self.p_dropout = p_dropout
107 | self.window_size = window_size
108 | self.heads_share = heads_share
109 | self.block_length = block_length
110 | self.proximal_bias = proximal_bias
111 | self.proximal_init = proximal_init
112 | self.attn = None
113 |
114 | self.k_channels = channels // n_heads
115 | self.conv_q = nn.Conv1d(channels, channels, 1)
116 | self.conv_k = nn.Conv1d(channels, channels, 1)
117 | self.conv_v = nn.Conv1d(channels, channels, 1)
118 | self.conv_o = nn.Conv1d(channels, out_channels, 1)
119 | self.drop = nn.Dropout(p_dropout)
120 |
121 | if window_size is not None:
122 | n_heads_rel = 1 if heads_share else n_heads
123 | rel_stddev = self.k_channels**-0.5
124 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
125 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
126 |
127 | nn.init.xavier_uniform_(self.conv_q.weight)
128 | nn.init.xavier_uniform_(self.conv_k.weight)
129 | nn.init.xavier_uniform_(self.conv_v.weight)
130 | if proximal_init:
131 | with torch.no_grad():
132 | self.conv_k.weight.copy_(self.conv_q.weight)
133 | self.conv_k.bias.copy_(self.conv_q.bias)
134 |
135 | def forward(self, x, c, attn_mask=None):
136 | q = self.conv_q(x)
137 | k = self.conv_k(c)
138 | v = self.conv_v(c)
139 |
140 | x, self.attn = self.attention(q, k, v, mask=attn_mask)
141 |
142 | x = self.conv_o(x)
143 | return x
144 |
145 | def attention(self, query, key, value, mask=None):
146 | # reshape [b, d, t] -> [b, n_h, t, d_k]
147 | b, d, t_s, t_t = (*key.size(), query.size(2))
148 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
149 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
150 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
151 |
152 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
153 | if self.window_size is not None:
154 | assert t_s == t_t, "Relative attention is only available for self-attention."
155 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
156 |       rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
157 | scores_local = self._relative_position_to_absolute_position(rel_logits)
158 | scores = scores + scores_local
159 | if self.proximal_bias:
160 | assert t_s == t_t, "Proximal bias is only available for self-attention."
161 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
162 | if mask is not None:
163 | scores = scores.masked_fill(mask == 0, -1e4)
164 | if self.block_length is not None:
165 | assert t_s == t_t, "Local attention is only available for self-attention."
166 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
167 | scores = scores.masked_fill(block_mask == 0, -1e4)
168 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
169 | p_attn = self.drop(p_attn)
170 | output = torch.matmul(p_attn, value)
171 | if self.window_size is not None:
172 | relative_weights = self._absolute_position_to_relative_position(p_attn)
173 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
174 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
175 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
176 | return output, p_attn
177 |
178 | def _matmul_with_relative_values(self, x, y):
179 | """
180 | x: [b, h, l, m]
181 | y: [h or 1, m, d]
182 | ret: [b, h, l, d]
183 | """
184 | ret = torch.matmul(x, y.unsqueeze(0))
185 | return ret
186 |
187 | def _matmul_with_relative_keys(self, x, y):
188 | """
189 | x: [b, h, l, d]
190 | y: [h or 1, m, d]
191 | ret: [b, h, l, m]
192 | """
193 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
194 | return ret
195 |
196 | def _get_relative_embeddings(self, relative_embeddings, length):
197 | max_relative_position = 2 * self.window_size + 1
198 | # Pad first before slice to avoid using cond ops.
199 | pad_length = max(length - (self.window_size + 1), 0)
200 | slice_start_position = max((self.window_size + 1) - length, 0)
201 | slice_end_position = slice_start_position + 2 * length - 1
202 | if pad_length > 0:
203 | padded_relative_embeddings = F.pad(
204 | relative_embeddings,
205 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
206 | else:
207 | padded_relative_embeddings = relative_embeddings
208 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position]
209 | return used_relative_embeddings
210 |
211 | def _relative_position_to_absolute_position(self, x):
212 | """
213 | x: [b, h, l, 2*l-1]
214 | ret: [b, h, l, l]
215 | """
216 | batch, heads, length, _ = x.size()
217 | # Concat columns of pad to shift from relative to absolute indexing.
218 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]]))
219 |
220 |     # Concat extra elements so that the flattened tensor adds up to shape (len+1, 2*len-1).
221 | x_flat = x.view([batch, heads, length * 2 * length])
222 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]]))
223 |
224 | # Reshape and slice out the padded elements.
225 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:]
226 | return x_final
227 |
228 | def _absolute_position_to_relative_position(self, x):
229 | """
230 | x: [b, h, l, l]
231 | ret: [b, h, l, 2*l-1]
232 | """
233 | batch, heads, length, _ = x.size()
234 |     # pad along the column dimension
235 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]]))
236 | x_flat = x.view([batch, heads, length**2 + length*(length -1)])
237 | # add 0's in the beginning that will skew the elements after reshape
238 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
239 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:]
240 | return x_final
241 |
242 | def _attention_bias_proximal(self, length):
243 | """Bias for self-attention to encourage attention to close positions.
244 | Args:
245 | length: an integer scalar.
246 | Returns:
247 | a Tensor with shape [1, 1, length, length]
248 | """
249 | r = torch.arange(length, dtype=torch.float32)
250 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
251 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
252 |
253 |
254 | class FFN(nn.Module):
255 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False):
256 | super().__init__()
257 | self.in_channels = in_channels
258 | self.out_channels = out_channels
259 | self.filter_channels = filter_channels
260 | self.kernel_size = kernel_size
261 | self.p_dropout = p_dropout
262 | self.activation = activation
263 | self.causal = causal
264 |
265 | if causal:
266 | self.padding = self._causal_padding
267 | else:
268 | self.padding = self._same_padding
269 |
270 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
271 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
272 | self.drop = nn.Dropout(p_dropout)
273 |
274 | def forward(self, x, x_mask):
275 | x = self.conv_1(self.padding(x * x_mask))
276 | if self.activation == "gelu":
277 | x = x * torch.sigmoid(1.702 * x)
278 | else:
279 | x = torch.relu(x)
280 | x = self.drop(x)
281 | x = self.conv_2(self.padding(x * x_mask))
282 | return x * x_mask
283 |
284 | def _causal_padding(self, x):
285 | if self.kernel_size == 1:
286 | return x
287 | pad_l = self.kernel_size - 1
288 | pad_r = 0
289 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
290 | x = F.pad(x, commons.convert_pad_shape(padding))
291 | return x
292 |
293 | def _same_padding(self, x):
294 | if self.kernel_size == 1:
295 | return x
296 | pad_l = (self.kernel_size - 1) // 2
297 | pad_r = self.kernel_size // 2
298 | padding = [[0, 0], [0, 0], [pad_l, pad_r]]
299 | x = F.pad(x, commons.convert_pad_shape(padding))
300 | return x
301 |
--------------------------------------------------------------------------------
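For orientation, the Encoder above is the transformer used by TextEncoder in models.py: windowed relative-position self-attention plus a convolutional FFN, with every intermediate re-masked so padded frames stay zero. A small usage sketch follows; the hyperparameters are illustrative VITS-style values rather than the ones in config.json, and `commons.sequence_mask` is assumed to behave as it is used in models.py:

import torch
import commons
from attentions import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2,
              n_layers=6, kernel_size=3, p_dropout=0.1)
enc.eval()  # disable dropout for a deterministic check

x = torch.randn(2, 192, 50)            # [batch, hidden_channels, time]
x_lengths = torch.tensor([50, 35])     # the second item is padded past frame 35
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)

with torch.no_grad():
    y = enc(x * x_mask, x_mask)        # [2, 192, 50]; padded frames of the second item stay zero
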
/modules.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 |
8 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
9 | from torch.nn.utils import weight_norm, remove_weight_norm
10 |
11 | import commons
12 | from commons import init_weights, get_padding
13 | from transforms import piecewise_rational_quadratic_transform
14 |
15 |
16 | LRELU_SLOPE = 0.1
17 |
18 |
19 | class LayerNorm(nn.Module):
20 | def __init__(self, channels, eps=1e-5):
21 | super().__init__()
22 | self.channels = channels
23 | self.eps = eps
24 |
25 | self.gamma = nn.Parameter(torch.ones(channels))
26 | self.beta = nn.Parameter(torch.zeros(channels))
27 |
28 | def forward(self, x):
29 | x = x.transpose(1, -1)
30 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
31 | return x.transpose(1, -1)
32 |
33 |
34 | class ConvReluNorm(nn.Module):
35 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
36 | super().__init__()
37 | self.in_channels = in_channels
38 | self.hidden_channels = hidden_channels
39 | self.out_channels = out_channels
40 | self.kernel_size = kernel_size
41 | self.n_layers = n_layers
42 | self.p_dropout = p_dropout
43 |     assert n_layers > 1, "Number of layers should be larger than 1."
44 |
45 | self.conv_layers = nn.ModuleList()
46 | self.norm_layers = nn.ModuleList()
47 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2))
48 | self.norm_layers.append(LayerNorm(hidden_channels))
49 | self.relu_drop = nn.Sequential(
50 | nn.ReLU(),
51 | nn.Dropout(p_dropout))
52 | for _ in range(n_layers-1):
53 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2))
54 | self.norm_layers.append(LayerNorm(hidden_channels))
55 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
56 | self.proj.weight.data.zero_()
57 | self.proj.bias.data.zero_()
58 |
59 | def forward(self, x, x_mask):
60 | x_org = x
61 | for i in range(self.n_layers):
62 | x = self.conv_layers[i](x * x_mask)
63 | x = self.norm_layers[i](x)
64 | x = self.relu_drop(x)
65 | x = x_org + self.proj(x)
66 | return x * x_mask
67 |
68 |
69 | class DDSConv(nn.Module):
70 | """
71 |   Dilated and Depth-Separable Convolution
72 | """
73 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.):
74 | super().__init__()
75 | self.channels = channels
76 | self.kernel_size = kernel_size
77 | self.n_layers = n_layers
78 | self.p_dropout = p_dropout
79 |
80 | self.drop = nn.Dropout(p_dropout)
81 | self.convs_sep = nn.ModuleList()
82 | self.convs_1x1 = nn.ModuleList()
83 | self.norms_1 = nn.ModuleList()
84 | self.norms_2 = nn.ModuleList()
85 | for i in range(n_layers):
86 | dilation = kernel_size ** i
87 | padding = (kernel_size * dilation - dilation) // 2
88 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size,
89 | groups=channels, dilation=dilation, padding=padding
90 | ))
91 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
92 | self.norms_1.append(LayerNorm(channels))
93 | self.norms_2.append(LayerNorm(channels))
94 |
95 | def forward(self, x, x_mask, g=None):
96 | if g is not None:
97 | x = x + g
98 | for i in range(self.n_layers):
99 | y = self.convs_sep[i](x * x_mask)
100 | y = self.norms_1[i](y)
101 | y = F.gelu(y)
102 | y = self.convs_1x1[i](y)
103 | y = self.norms_2[i](y)
104 | y = F.gelu(y)
105 | y = self.drop(y)
106 | x = x + y
107 | return x * x_mask
108 |
109 |
110 | class WN(torch.nn.Module):
111 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
112 | super(WN, self).__init__()
113 | assert(kernel_size % 2 == 1)
114 |     self.hidden_channels = hidden_channels
115 |     self.kernel_size = kernel_size
116 | self.dilation_rate = dilation_rate
117 | self.n_layers = n_layers
118 | self.gin_channels = gin_channels
119 | self.p_dropout = p_dropout
120 |
121 | self.in_layers = torch.nn.ModuleList()
122 | self.res_skip_layers = torch.nn.ModuleList()
123 | self.drop = nn.Dropout(p_dropout)
124 |
125 | if gin_channels != 0:
126 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1)
127 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
128 |
129 | for i in range(n_layers):
130 | dilation = dilation_rate ** i
131 | padding = int((kernel_size * dilation - dilation) / 2)
132 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size,
133 | dilation=dilation, padding=padding)
134 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
135 | self.in_layers.append(in_layer)
136 |
137 | # last one is not necessary
138 | if i < n_layers - 1:
139 | res_skip_channels = 2 * hidden_channels
140 | else:
141 | res_skip_channels = hidden_channels
142 |
143 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
144 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
145 | self.res_skip_layers.append(res_skip_layer)
146 |
147 | def forward(self, x, x_mask, g=None, **kwargs):
148 | output = torch.zeros_like(x)
149 | n_channels_tensor = torch.IntTensor([self.hidden_channels])
150 |
151 | if g is not None:
152 | g = self.cond_layer(g)
153 |
154 | for i in range(self.n_layers):
155 | x_in = self.in_layers[i](x)
156 | if g is not None:
157 | cond_offset = i * 2 * self.hidden_channels
158 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
159 | else:
160 | g_l = torch.zeros_like(x_in)
161 |
162 | acts = commons.fused_add_tanh_sigmoid_multiply(
163 | x_in,
164 | g_l,
165 | n_channels_tensor)
166 | acts = self.drop(acts)
167 |
168 | res_skip_acts = self.res_skip_layers[i](acts)
169 | if i < self.n_layers - 1:
170 | res_acts = res_skip_acts[:,:self.hidden_channels,:]
171 | x = (x + res_acts) * x_mask
172 | output = output + res_skip_acts[:,self.hidden_channels:,:]
173 | else:
174 | output = output + res_skip_acts
175 | return output * x_mask
176 |
177 | def remove_weight_norm(self):
178 | if self.gin_channels != 0:
179 | torch.nn.utils.remove_weight_norm(self.cond_layer)
180 | for l in self.in_layers:
181 | torch.nn.utils.remove_weight_norm(l)
182 | for l in self.res_skip_layers:
183 | torch.nn.utils.remove_weight_norm(l)
184 |
185 |
186 | class ResBlock1(torch.nn.Module):
187 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
188 | super(ResBlock1, self).__init__()
189 | self.convs1 = nn.ModuleList([
190 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
191 | padding=get_padding(kernel_size, dilation[0]))),
192 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
193 | padding=get_padding(kernel_size, dilation[1]))),
194 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
195 | padding=get_padding(kernel_size, dilation[2])))
196 | ])
197 | self.convs1.apply(init_weights)
198 |
199 | self.convs2 = nn.ModuleList([
200 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
201 | padding=get_padding(kernel_size, 1))),
202 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
203 | padding=get_padding(kernel_size, 1))),
204 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
205 | padding=get_padding(kernel_size, 1)))
206 | ])
207 | self.convs2.apply(init_weights)
208 |
209 | def forward(self, x, x_mask=None):
210 | for c1, c2 in zip(self.convs1, self.convs2):
211 | xt = F.leaky_relu(x, LRELU_SLOPE)
212 | if x_mask is not None:
213 | xt = xt * x_mask
214 | xt = c1(xt)
215 | xt = F.leaky_relu(xt, LRELU_SLOPE)
216 | if x_mask is not None:
217 | xt = xt * x_mask
218 | xt = c2(xt)
219 | x = xt + x
220 | if x_mask is not None:
221 | x = x * x_mask
222 | return x
223 |
224 | def remove_weight_norm(self):
225 | for l in self.convs1:
226 | remove_weight_norm(l)
227 | for l in self.convs2:
228 | remove_weight_norm(l)
229 |
230 |
231 | class ResBlock2(torch.nn.Module):
232 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
233 | super(ResBlock2, self).__init__()
234 | self.convs = nn.ModuleList([
235 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
236 | padding=get_padding(kernel_size, dilation[0]))),
237 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
238 | padding=get_padding(kernel_size, dilation[1])))
239 | ])
240 | self.convs.apply(init_weights)
241 |
242 | def forward(self, x, x_mask=None):
243 | for c in self.convs:
244 | xt = F.leaky_relu(x, LRELU_SLOPE)
245 | if x_mask is not None:
246 | xt = xt * x_mask
247 | xt = c(xt)
248 | x = xt + x
249 | if x_mask is not None:
250 | x = x * x_mask
251 | return x
252 |
253 | def remove_weight_norm(self):
254 | for l in self.convs:
255 | remove_weight_norm(l)
256 |
257 |
258 | class Log(nn.Module):
259 | def forward(self, x, x_mask, reverse=False, **kwargs):
260 | if not reverse:
261 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
262 | logdet = torch.sum(-y, [1, 2])
263 | return y, logdet
264 | else:
265 | x = torch.exp(x) * x_mask
266 | return x
267 |
268 |
269 | class Flip(nn.Module):
270 | def forward(self, x, *args, reverse=False, **kwargs):
271 | x = torch.flip(x, [1])
272 | if not reverse:
273 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
274 | return x, logdet
275 | else:
276 | return x
277 |
278 |
279 | class ElementwiseAffine(nn.Module):
280 | def __init__(self, channels):
281 | super().__init__()
282 | self.channels = channels
283 | self.m = nn.Parameter(torch.zeros(channels,1))
284 | self.logs = nn.Parameter(torch.zeros(channels,1))
285 |
286 | def forward(self, x, x_mask, reverse=False, **kwargs):
287 | if not reverse:
288 | y = self.m + torch.exp(self.logs) * x
289 | y = y * x_mask
290 | logdet = torch.sum(self.logs * x_mask, [1,2])
291 | return y, logdet
292 | else:
293 | x = (x - self.m) * torch.exp(-self.logs) * x_mask
294 | return x
295 |
296 |
297 | class ResidualCouplingLayer(nn.Module):
298 | def __init__(self,
299 | channels,
300 | hidden_channels,
301 | kernel_size,
302 | dilation_rate,
303 | n_layers,
304 | p_dropout=0,
305 | gin_channels=0,
306 | mean_only=False):
307 | assert channels % 2 == 0, "channels should be divisible by 2"
308 | super().__init__()
309 | self.channels = channels
310 | self.hidden_channels = hidden_channels
311 | self.kernel_size = kernel_size
312 | self.dilation_rate = dilation_rate
313 | self.n_layers = n_layers
314 | self.half_channels = channels // 2
315 | self.mean_only = mean_only
316 |
317 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
318 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
319 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
320 | self.post.weight.data.zero_()
321 | self.post.bias.data.zero_()
322 |
323 | def forward(self, x, x_mask, g=None, reverse=False):
324 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
325 | h = self.pre(x0) * x_mask
326 | h = self.enc(h, x_mask, g=g)
327 | stats = self.post(h) * x_mask
328 | if not self.mean_only:
329 | m, logs = torch.split(stats, [self.half_channels]*2, 1)
330 | else:
331 | m = stats
332 | logs = torch.zeros_like(m)
333 |
334 | if not reverse:
335 | x1 = m + x1 * torch.exp(logs) * x_mask
336 | x = torch.cat([x0, x1], 1)
337 | logdet = torch.sum(logs, [1,2])
338 | return x, logdet
339 | else:
340 | x1 = (x1 - m) * torch.exp(-logs) * x_mask
341 | x = torch.cat([x0, x1], 1)
342 | return x
343 |
344 |
345 | class ConvFlow(nn.Module):
346 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0):
347 | super().__init__()
348 | self.in_channels = in_channels
349 | self.filter_channels = filter_channels
350 | self.kernel_size = kernel_size
351 | self.n_layers = n_layers
352 | self.num_bins = num_bins
353 | self.tail_bound = tail_bound
354 | self.half_channels = in_channels // 2
355 |
356 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
357 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.)
358 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
359 | self.proj.weight.data.zero_()
360 | self.proj.bias.data.zero_()
361 |
362 | def forward(self, x, x_mask, g=None, reverse=False):
363 | x0, x1 = torch.split(x, [self.half_channels]*2, 1)
364 | h = self.pre(x0)
365 | h = self.convs(h, x_mask, g=g)
366 | h = self.proj(h) * x_mask
367 |
368 | b, c, t = x0.shape
369 |     h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, c*(3*num_bins-1), t] -> [b, c, t, 3*num_bins-1]
370 |
371 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels)
372 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels)
373 | unnormalized_derivatives = h[..., 2 * self.num_bins:]
374 |
375 | x1, logabsdet = piecewise_rational_quadratic_transform(x1,
376 | unnormalized_widths,
377 | unnormalized_heights,
378 | unnormalized_derivatives,
379 | inverse=reverse,
380 | tails='linear',
381 | tail_bound=self.tail_bound
382 | )
383 |
384 | x = torch.cat([x0, x1], 1) * x_mask
385 | logdet = torch.sum(logabsdet * x_mask, [1,2])
386 | if not reverse:
387 | return x, logdet
388 | else:
389 | return x
390 |
--------------------------------------------------------------------------------
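The coupling layers above are what make the flow invertible: ResidualCouplingLayer leaves x0 untouched, predicts a shift (and optionally a scale) for x1 from it, and can therefore be undone exactly by the reverse pass. A minimal invertibility check, mirroring how ResidualCouplingBlock constructs the layer (mean_only=True); the channel counts and lengths are illustrative only:

import torch
from modules import ResidualCouplingLayer

layer = ResidualCouplingLayer(channels=4, hidden_channels=8, kernel_size=5,
                              dilation_rate=1, n_layers=2, mean_only=True)
layer.eval()

x = torch.randn(3, 4, 20)              # [batch, channels, time]
x_mask = torch.ones(3, 1, 20)          # no padding in this toy example

with torch.no_grad():
    y, logdet = layer(x, x_mask)               # forward: output and log-determinant
    x_rec = layer(y, x_mask, reverse=True)     # reverse: undoes the coupling exactly

# Note: self.post is zero-initialized, so an untrained layer starts as the identity on x1;
# the round trip holds for trained weights as well, since x0 is never modified.
assert torch.allclose(x, x_rec, atol=1e-5)
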
/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | import commons
8 | import modules
9 | import attentions
10 | import monotonic_align
11 |
12 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
13 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14 | from commons import init_weights, get_padding
15 |
16 | class StochasticDurationPredictor(nn.Module):
17 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0):
18 | super().__init__()
19 |     filter_channels = in_channels  # this override should be removed in a future version.
20 | self.in_channels = in_channels
21 | self.filter_channels = filter_channels
22 | self.kernel_size = kernel_size
23 | self.p_dropout = p_dropout
24 | self.n_flows = n_flows
25 | self.gin_channels = gin_channels
26 |
27 | self.log_flow = modules.Log()
28 | self.flows = nn.ModuleList()
29 | self.flows.append(modules.ElementwiseAffine(2))
30 | for i in range(n_flows):
31 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
32 | self.flows.append(modules.Flip())
33 |
34 | self.post_pre = nn.Conv1d(1, filter_channels, 1)
35 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
36 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
37 | self.post_flows = nn.ModuleList()
38 | self.post_flows.append(modules.ElementwiseAffine(2))
39 | for i in range(4):
40 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
41 | self.post_flows.append(modules.Flip())
42 |
43 | self.pre = nn.Conv1d(in_channels, filter_channels, 1)
44 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
45 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
46 | if gin_channels != 0:
47 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
48 |
49 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
50 | x = torch.detach(x)
51 | x = self.pre(x)
52 | if g is not None:
53 | g = torch.detach(g)
54 | x = x + self.cond(g)
55 | x = self.convs(x, x_mask)
56 | x = self.proj(x) * x_mask
57 |
58 | if not reverse:
59 | flows = self.flows
60 | assert w is not None
61 |
62 | logdet_tot_q = 0
63 | h_w = self.post_pre(w)
64 | h_w = self.post_convs(h_w, x_mask)
65 | h_w = self.post_proj(h_w) * x_mask
66 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
67 | z_q = e_q
68 | for flow in self.post_flows:
69 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
70 | logdet_tot_q += logdet_q
71 | z_u, z1 = torch.split(z_q, [1, 1], 1)
72 | u = torch.sigmoid(z_u) * x_mask
73 | z0 = (w - u) * x_mask
74 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
75 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q
76 |
77 | logdet_tot = 0
78 | z0, logdet = self.log_flow(z0, x_mask)
79 | logdet_tot += logdet
80 | z = torch.cat([z0, z1], 1)
81 | for flow in flows:
82 | z, logdet = flow(z, x_mask, g=x, reverse=reverse)
83 | logdet_tot = logdet_tot + logdet
84 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
85 | return nll + logq # [b]
86 | else:
87 | flows = list(reversed(self.flows))
88 |       flows = flows[:-2] + [flows[-1]]  # remove a useless flow (keep the trailing ElementwiseAffine)
89 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
90 | for flow in flows:
91 | z = flow(z, x_mask, g=x, reverse=reverse)
92 | z0, z1 = torch.split(z, [1, 1], 1)
93 | logw = z0
94 | return logw
95 |
96 |
97 | class DurationPredictor(nn.Module):
98 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
99 | super().__init__()
100 |
101 | self.in_channels = in_channels
102 | self.filter_channels = filter_channels
103 | self.kernel_size = kernel_size
104 | self.p_dropout = p_dropout
105 | self.gin_channels = gin_channels
106 |
107 | self.drop = nn.Dropout(p_dropout)
108 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
109 | self.norm_1 = modules.LayerNorm(filter_channels)
110 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
111 | self.norm_2 = modules.LayerNorm(filter_channels)
112 | self.proj = nn.Conv1d(filter_channels, 1, 1)
113 |
114 | if gin_channels != 0:
115 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
116 |
117 | def forward(self, x, x_mask, g=None):
118 | x = torch.detach(x)
119 | if g is not None:
120 | g = torch.detach(g)
121 | x = x + self.cond(g)
122 | x = self.conv_1(x * x_mask)
123 | x = torch.relu(x)
124 | x = self.norm_1(x)
125 | x = self.drop(x)
126 | x = self.conv_2(x * x_mask)
127 | x = torch.relu(x)
128 | x = self.norm_2(x)
129 | x = self.drop(x)
130 | x = self.proj(x * x_mask)
131 | return x * x_mask
132 |
133 |
134 | class TextEncoder(nn.Module):
135 | def __init__(self,
136 | n_vocab,
137 | out_channels,
138 | hidden_channels,
139 | filter_channels,
140 | n_heads,
141 | n_layers,
142 | kernel_size,
143 | p_dropout):
144 | super().__init__()
145 | self.n_vocab = n_vocab
146 | self.out_channels = out_channels
147 | self.hidden_channels = hidden_channels
148 | self.filter_channels = filter_channels
149 | self.n_heads = n_heads
150 | self.n_layers = n_layers
151 | self.kernel_size = kernel_size
152 | self.p_dropout = p_dropout
153 |
154 | self.emb = nn.Embedding(n_vocab, hidden_channels)
155 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
156 |
157 | self.encoder = attentions.Encoder(
158 | hidden_channels,
159 | filter_channels,
160 | n_heads,
161 | n_layers,
162 | kernel_size,
163 | p_dropout)
164 |     self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
165 |
166 | def forward(self, x, x_lengths):
167 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
168 | x = torch.transpose(x, 1, -1) # [b, h, t]
169 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
170 |
171 | x = self.encoder(x * x_mask, x_mask)
172 | stats = self.proj(x) * x_mask
173 |
174 | m, logs = torch.split(stats, self.out_channels, dim=1)
175 | return x, m, logs, x_mask
176 |
177 |
178 | class ResidualCouplingBlock(nn.Module):
179 | def __init__(self,
180 | channels,
181 | hidden_channels,
182 | kernel_size,
183 | dilation_rate,
184 | n_layers,
185 | n_flows=4,
186 | gin_channels=0):
187 | super().__init__()
188 | self.channels = channels
189 | self.hidden_channels = hidden_channels
190 | self.kernel_size = kernel_size
191 | self.dilation_rate = dilation_rate
192 | self.n_layers = n_layers
193 | self.n_flows = n_flows
194 | self.gin_channels = gin_channels
195 |
196 | self.flows = nn.ModuleList()
197 | for i in range(n_flows):
198 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
199 | self.flows.append(modules.Flip())
200 |
201 | def forward(self, x, x_mask, g=None, reverse=False):
202 | if not reverse:
203 | for flow in self.flows:
204 | x, _ = flow(x, x_mask, g=g, reverse=reverse)
205 | else:
206 | for flow in reversed(self.flows):
207 | x = flow(x, x_mask, g=g, reverse=reverse)
208 | return x
209 |
210 |
211 | class PosteriorEncoder(nn.Module):
212 | def __init__(self,
213 | in_channels,
214 | out_channels,
215 | hidden_channels,
216 | kernel_size,
217 | dilation_rate,
218 | n_layers,
219 | gin_channels=0):
220 | super().__init__()
221 | self.in_channels = in_channels
222 | self.out_channels = out_channels
223 | self.hidden_channels = hidden_channels
224 | self.kernel_size = kernel_size
225 | self.dilation_rate = dilation_rate
226 | self.n_layers = n_layers
227 | self.gin_channels = gin_channels
228 |
229 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
230 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
231 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
232 |
233 | def forward(self, x, x_lengths, g=None):
234 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
235 | x = self.pre(x) * x_mask
236 | x = self.enc(x, x_mask, g=g)
237 | stats = self.proj(x) * x_mask
238 | m, logs = torch.split(stats, self.out_channels, dim=1)
239 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
240 | return z, m, logs, x_mask
241 |
242 |
243 | class Generator(torch.nn.Module):
244 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
245 | super(Generator, self).__init__()
246 | self.num_kernels = len(resblock_kernel_sizes)
247 | self.num_upsamples = len(upsample_rates)
248 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
249 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2
250 |
251 | self.ups = nn.ModuleList()
252 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
253 | self.ups.append(weight_norm(
254 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
255 | k, u, padding=(k-u)//2)))
256 |
257 | self.resblocks = nn.ModuleList()
258 | for i in range(len(self.ups)):
259 | ch = upsample_initial_channel//(2**(i+1))
260 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
261 | self.resblocks.append(resblock(ch, k, d))
262 |
263 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
264 | self.ups.apply(init_weights)
265 |
266 | if gin_channels != 0:
267 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
268 |
269 | def forward(self, x, g=None):
270 | x = self.conv_pre(x)
271 | if g is not None:
272 | x = x + self.cond(g)
273 |
274 | for i in range(self.num_upsamples):
275 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
276 | x = self.ups[i](x)
277 | xs = None
278 | for j in range(self.num_kernels):
279 | if xs is None:
280 | xs = self.resblocks[i*self.num_kernels+j](x)
281 | else:
282 | xs += self.resblocks[i*self.num_kernels+j](x)
283 | x = xs / self.num_kernels
284 | x = F.leaky_relu(x)
285 | x = self.conv_post(x)
286 | x = torch.tanh(x)
287 |
288 | return x
289 |
290 | def remove_weight_norm(self):
291 | print('Removing weight norm...')
292 | for l in self.ups:
293 | remove_weight_norm(l)
294 | for l in self.resblocks:
295 | l.remove_weight_norm()
296 |
297 |
298 | class DiscriminatorP(torch.nn.Module):
299 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
300 | super(DiscriminatorP, self).__init__()
301 | self.period = period
302 | self.use_spectral_norm = use_spectral_norm
303 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
304 | self.convs = nn.ModuleList([
305 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
306 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
307 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
308 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))),
309 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))),
310 | ])
311 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
312 |
313 | def forward(self, x):
314 | fmap = []
315 |
316 | # 1d to 2d
317 | b, c, t = x.shape
318 | if t % self.period != 0: # pad first
319 | n_pad = self.period - (t % self.period)
320 | x = F.pad(x, (0, n_pad), "reflect")
321 | t = t + n_pad
322 | x = x.view(b, c, t // self.period, self.period)
323 |
324 | for l in self.convs:
325 | x = l(x)
326 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
327 | fmap.append(x)
328 | x = self.conv_post(x)
329 | fmap.append(x)
330 | x = torch.flatten(x, 1, -1)
331 |
332 | return x, fmap
333 |
334 |
335 | class DiscriminatorS(torch.nn.Module):
336 | def __init__(self, use_spectral_norm=False):
337 | super(DiscriminatorS, self).__init__()
338 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm
339 | self.convs = nn.ModuleList([
340 | norm_f(Conv1d(1, 16, 15, 1, padding=7)),
341 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
342 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
343 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
344 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
345 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
346 | ])
347 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
348 |
349 | def forward(self, x):
350 | fmap = []
351 |
352 | for l in self.convs:
353 | x = l(x)
354 | x = F.leaky_relu(x, modules.LRELU_SLOPE)
355 | fmap.append(x)
356 | x = self.conv_post(x)
357 | fmap.append(x)
358 | x = torch.flatten(x, 1, -1)
359 |
360 | return x, fmap
361 |
362 |
363 | class MultiPeriodDiscriminator(torch.nn.Module):
364 | def __init__(self, use_spectral_norm=False):
365 | super(MultiPeriodDiscriminator, self).__init__()
366 | periods = [2,3,5,7,11]
367 |
368 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
369 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
370 | self.discriminators = nn.ModuleList(discs)
371 |
372 | def forward(self, y, y_hat):
373 | y_d_rs = []
374 | y_d_gs = []
375 | fmap_rs = []
376 | fmap_gs = []
377 | for i, d in enumerate(self.discriminators):
378 | y_d_r, fmap_r = d(y)
379 | y_d_g, fmap_g = d(y_hat)
380 | y_d_rs.append(y_d_r)
381 | y_d_gs.append(y_d_g)
382 | fmap_rs.append(fmap_r)
383 | fmap_gs.append(fmap_g)
384 |
385 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
386 |
387 |
388 |
389 | class SynthesizerTrn(nn.Module):
390 | """
391 | Synthesizer for Training
392 | """
393 |
394 | def __init__(self,
395 | n_vocab,
396 | spec_channels,
397 | segment_size,
398 | inter_channels,
399 | hidden_channels,
400 | filter_channels,
401 | n_heads,
402 | n_layers,
403 | kernel_size,
404 | p_dropout,
405 | resblock,
406 | resblock_kernel_sizes,
407 | resblock_dilation_sizes,
408 | upsample_rates,
409 | upsample_initial_channel,
410 | upsample_kernel_sizes,
411 | n_speakers=0,
412 | gin_channels=0,
413 | use_sdp=True,
414 | **kwargs):
415 |
416 | super().__init__()
417 | self.n_vocab = n_vocab
418 | self.spec_channels = spec_channels
419 | self.inter_channels = inter_channels
420 | self.hidden_channels = hidden_channels
421 | self.filter_channels = filter_channels
422 | self.n_heads = n_heads
423 | self.n_layers = n_layers
424 | self.kernel_size = kernel_size
425 | self.p_dropout = p_dropout
426 | self.resblock = resblock
427 | self.resblock_kernel_sizes = resblock_kernel_sizes
428 | self.resblock_dilation_sizes = resblock_dilation_sizes
429 | self.upsample_rates = upsample_rates
430 | self.upsample_initial_channel = upsample_initial_channel
431 | self.upsample_kernel_sizes = upsample_kernel_sizes
432 | self.segment_size = segment_size
433 | self.n_speakers = n_speakers
434 | self.gin_channels = gin_channels
435 |
436 | self.use_sdp = use_sdp
437 |
438 | self.enc_p = TextEncoder(n_vocab,
439 | inter_channels,
440 | hidden_channels,
441 | filter_channels,
442 | n_heads,
443 | n_layers,
444 | kernel_size,
445 | p_dropout)
446 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
447 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
448 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
449 |
450 | if use_sdp:
451 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels)
452 | else:
453 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels)
454 |
455 | if n_speakers > 1:
456 | self.emb_g = nn.Embedding(n_speakers, gin_channels)
457 |
458 | def forward(self, x, x_lengths, y, y_lengths, sid=None):
459 |
460 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
461 | if self.n_speakers > 0:
462 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
463 | else:
464 | g = None
465 |
466 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
467 | z_p = self.flow(z, y_mask, g=g)
468 |
469 | with torch.no_grad():
470 | # negative cross-entropy
471 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
472 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s]
473 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
474 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
475 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s]
476 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
477 |
478 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
479 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
480 |
481 | w = attn.sum(2)
482 | if self.use_sdp:
483 | l_length = self.dp(x, x_mask, w, g=g)
484 | l_length = l_length / torch.sum(x_mask)
485 | else:
486 | logw_ = torch.log(w + 1e-6) * x_mask
487 | logw = self.dp(x, x_mask, g=g)
488 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging
489 |
490 | # expand prior
491 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
492 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
493 |
494 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
495 | o = self.dec(z_slice, g=g)
496 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
497 |
498 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
499 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
500 | if self.n_speakers > 0:
501 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
502 | else:
503 | g = None
504 |
505 | if self.use_sdp:
506 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
507 | else:
508 | logw = self.dp(x, x_mask, g=g)
509 | w = torch.exp(logw) * x_mask * length_scale
510 | w_ceil = torch.ceil(w)
511 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
512 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype)
513 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
514 | attn = commons.generate_path(w_ceil, attn_mask)
515 |
516 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
517 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
518 |
519 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
520 | z = self.flow(z_p, y_mask, g=g, reverse=True)
521 | o = self.dec((z * y_mask)[:,:,:max_len], g=g)
522 | return o, attn, y_mask, (z, z_p, m_p, logs_p)
523 |
524 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
525 |     assert self.n_speakers > 0, "n_speakers has to be larger than 0."
526 | g_src = self.emb_g(sid_src).unsqueeze(-1)
527 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
528 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
529 | z_p = self.flow(z, y_mask, g=g_src)
530 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
531 | o_hat = self.dec(z_hat * y_mask, g=g_tgt)
532 | return o_hat, y_mask, (z, z_p, z_hat)
533 |
534 |
--------------------------------------------------------------------------------
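End to end, inference walks TextEncoder -> duration predictor -> alignment expansion -> flow (reverse) -> Generator, which is what SynthesizerTrn.infer wraps. A smoke-test sketch with untrained weights; every hyperparameter below is a placeholder rather than the values in config.json, and no checkpoint (e.g. G_60000.pth) is loaded, so the output is noise until real weights are restored:

import torch
from models import SynthesizerTrn

net_g = SynthesizerTrn(
    n_vocab=100, spec_channels=513, segment_size=32,
    inter_channels=192, hidden_channels=192, filter_channels=768,
    n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1,
    resblock='1',
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[8, 8, 2, 2],
    upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4])
net_g.eval()

x = torch.randint(1, 100, (1, 12))     # a short sequence of placeholder phoneme ids
x_lengths = torch.tensor([12])

with torch.no_grad():
    audio, attn, y_mask, _ = net_g.infer(
        x, x_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0)

# audio has shape [1, 1, T]; each predicted frame is upsampled by 8*8*2*2 = 256 samples.
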
/wavfile.py:
--------------------------------------------------------------------------------
1 | """
2 | Module to read / write wav files using NumPy arrays
3 |
4 | Functions
5 | ---------
6 | `read`: Return the sample rate (in samples/sec) and data from a WAV file.
7 |
8 | `write`: Write a NumPy array as a WAV file.
9 |
10 | """
11 | import io
12 | import struct
13 | import sys
14 | import warnings
15 | from enum import IntEnum
16 |
17 | import numpy
18 |
19 | __all__ = ["WavFileWarning", "read", "write"]
20 |
21 |
22 | class WavFileWarning(UserWarning):
23 | pass
24 |
25 |
26 | class WAVE_FORMAT(IntEnum):
27 | """
28 | WAVE form wFormatTag IDs
29 |
30 | Complete list is in mmreg.h in Windows 10 SDK. ALAC and OPUS are the
31 | newest additions, in v10.0.14393 2016-07
32 | """
33 |
34 | UNKNOWN = 0x0000
35 | PCM = 0x0001
36 | ADPCM = 0x0002
37 | IEEE_FLOAT = 0x0003
38 | VSELP = 0x0004
39 | IBM_CVSD = 0x0005
40 | ALAW = 0x0006
41 | MULAW = 0x0007
42 | DTS = 0x0008
43 | DRM = 0x0009
44 | WMAVOICE9 = 0x000A
45 | WMAVOICE10 = 0x000B
46 | OKI_ADPCM = 0x0010
47 | DVI_ADPCM = 0x0011
48 | IMA_ADPCM = 0x0011 # Duplicate
49 | MEDIASPACE_ADPCM = 0x0012
50 | SIERRA_ADPCM = 0x0013
51 | G723_ADPCM = 0x0014
52 | DIGISTD = 0x0015
53 | DIGIFIX = 0x0016
54 | DIALOGIC_OKI_ADPCM = 0x0017
55 | MEDIAVISION_ADPCM = 0x0018
56 | CU_CODEC = 0x0019
57 | HP_DYN_VOICE = 0x001A
58 | YAMAHA_ADPCM = 0x0020
59 | SONARC = 0x0021
60 | DSPGROUP_TRUESPEECH = 0x0022
61 | ECHOSC1 = 0x0023
62 | AUDIOFILE_AF36 = 0x0024
63 | APTX = 0x0025
64 | AUDIOFILE_AF10 = 0x0026
65 | PROSODY_1612 = 0x0027
66 | LRC = 0x0028
67 | DOLBY_AC2 = 0x0030
68 | GSM610 = 0x0031
69 | MSNAUDIO = 0x0032
70 | ANTEX_ADPCME = 0x0033
71 | CONTROL_RES_VQLPC = 0x0034
72 | DIGIREAL = 0x0035
73 | DIGIADPCM = 0x0036
74 | CONTROL_RES_CR10 = 0x0037
75 | NMS_VBXADPCM = 0x0038
76 | CS_IMAADPCM = 0x0039
77 | ECHOSC3 = 0x003A
78 | ROCKWELL_ADPCM = 0x003B
79 | ROCKWELL_DIGITALK = 0x003C
80 | XEBEC = 0x003D
81 | G721_ADPCM = 0x0040
82 | G728_CELP = 0x0041
83 | MSG723 = 0x0042
84 | INTEL_G723_1 = 0x0043
85 | INTEL_G729 = 0x0044
86 | SHARP_G726 = 0x0045
87 | MPEG = 0x0050
88 | RT24 = 0x0052
89 | PAC = 0x0053
90 | MPEGLAYER3 = 0x0055
91 | LUCENT_G723 = 0x0059
92 | CIRRUS = 0x0060
93 | ESPCM = 0x0061
94 | VOXWARE = 0x0062
95 | CANOPUS_ATRAC = 0x0063
96 | G726_ADPCM = 0x0064
97 | G722_ADPCM = 0x0065
98 | DSAT = 0x0066
99 | DSAT_DISPLAY = 0x0067
100 | VOXWARE_BYTE_ALIGNED = 0x0069
101 | VOXWARE_AC8 = 0x0070
102 | VOXWARE_AC10 = 0x0071
103 | VOXWARE_AC16 = 0x0072
104 | VOXWARE_AC20 = 0x0073
105 | VOXWARE_RT24 = 0x0074
106 | VOXWARE_RT29 = 0x0075
107 | VOXWARE_RT29HW = 0x0076
108 | VOXWARE_VR12 = 0x0077
109 | VOXWARE_VR18 = 0x0078
110 | VOXWARE_TQ40 = 0x0079
111 | VOXWARE_SC3 = 0x007A
112 | VOXWARE_SC3_1 = 0x007B
113 | SOFTSOUND = 0x0080
114 | VOXWARE_TQ60 = 0x0081
115 | MSRT24 = 0x0082
116 | G729A = 0x0083
117 | MVI_MVI2 = 0x0084
118 | DF_G726 = 0x0085
119 | DF_GSM610 = 0x0086
120 | ISIAUDIO = 0x0088
121 | ONLIVE = 0x0089
122 | MULTITUDE_FT_SX20 = 0x008A
123 | INFOCOM_ITS_G721_ADPCM = 0x008B
124 | CONVEDIA_G729 = 0x008C
125 | CONGRUENCY = 0x008D
126 | SBC24 = 0x0091
127 | DOLBY_AC3_SPDIF = 0x0092
128 | MEDIASONIC_G723 = 0x0093
129 | PROSODY_8KBPS = 0x0094
130 | ZYXEL_ADPCM = 0x0097
131 | PHILIPS_LPCBB = 0x0098
132 | PACKED = 0x0099
133 | MALDEN_PHONYTALK = 0x00A0
134 | RACAL_RECORDER_GSM = 0x00A1
135 | RACAL_RECORDER_G720_A = 0x00A2
136 | RACAL_RECORDER_G723_1 = 0x00A3
137 | RACAL_RECORDER_TETRA_ACELP = 0x00A4
138 | NEC_AAC = 0x00B0
139 | RAW_AAC1 = 0x00FF
140 | RHETOREX_ADPCM = 0x0100
141 | IRAT = 0x0101
142 | VIVO_G723 = 0x0111
143 | VIVO_SIREN = 0x0112
144 | PHILIPS_CELP = 0x0120
145 | PHILIPS_GRUNDIG = 0x0121
146 | DIGITAL_G723 = 0x0123
147 | SANYO_LD_ADPCM = 0x0125
148 | SIPROLAB_ACEPLNET = 0x0130
149 | SIPROLAB_ACELP4800 = 0x0131
150 | SIPROLAB_ACELP8V3 = 0x0132
151 | SIPROLAB_G729 = 0x0133
152 | SIPROLAB_G729A = 0x0134
153 | SIPROLAB_KELVIN = 0x0135
154 | VOICEAGE_AMR = 0x0136
155 | G726ADPCM = 0x0140
156 | DICTAPHONE_CELP68 = 0x0141
157 | DICTAPHONE_CELP54 = 0x0142
158 | QUALCOMM_PUREVOICE = 0x0150
159 | QUALCOMM_HALFRATE = 0x0151
160 | TUBGSM = 0x0155
161 | MSAUDIO1 = 0x0160
162 | WMAUDIO2 = 0x0161
163 | WMAUDIO3 = 0x0162
164 | WMAUDIO_LOSSLESS = 0x0163
165 | WMASPDIF = 0x0164
166 | UNISYS_NAP_ADPCM = 0x0170
167 | UNISYS_NAP_ULAW = 0x0171
168 | UNISYS_NAP_ALAW = 0x0172
169 | UNISYS_NAP_16K = 0x0173
170 | SYCOM_ACM_SYC008 = 0x0174
171 | SYCOM_ACM_SYC701_G726L = 0x0175
172 | SYCOM_ACM_SYC701_CELP54 = 0x0176
173 | SYCOM_ACM_SYC701_CELP68 = 0x0177
174 | KNOWLEDGE_ADVENTURE_ADPCM = 0x0178
175 | FRAUNHOFER_IIS_MPEG2_AAC = 0x0180
176 | DTS_DS = 0x0190
177 | CREATIVE_ADPCM = 0x0200
178 | CREATIVE_FASTSPEECH8 = 0x0202
179 | CREATIVE_FASTSPEECH10 = 0x0203
180 | UHER_ADPCM = 0x0210
181 | ULEAD_DV_AUDIO = 0x0215
182 | ULEAD_DV_AUDIO_1 = 0x0216
183 | QUARTERDECK = 0x0220
184 | ILINK_VC = 0x0230
185 | RAW_SPORT = 0x0240
186 | ESST_AC3 = 0x0241
187 | GENERIC_PASSTHRU = 0x0249
188 | IPI_HSX = 0x0250
189 | IPI_RPELP = 0x0251
190 | CS2 = 0x0260
191 | SONY_SCX = 0x0270
192 | SONY_SCY = 0x0271
193 | SONY_ATRAC3 = 0x0272
194 | SONY_SPC = 0x0273
195 | TELUM_AUDIO = 0x0280
196 | TELUM_IA_AUDIO = 0x0281
197 | NORCOM_VOICE_SYSTEMS_ADPCM = 0x0285
198 | FM_TOWNS_SND = 0x0300
199 | MICRONAS = 0x0350
200 | MICRONAS_CELP833 = 0x0351
201 | BTV_DIGITAL = 0x0400
202 | INTEL_MUSIC_CODER = 0x0401
203 | INDEO_AUDIO = 0x0402
204 | QDESIGN_MUSIC = 0x0450
205 | ON2_VP7_AUDIO = 0x0500
206 | ON2_VP6_AUDIO = 0x0501
207 | VME_VMPCM = 0x0680
208 | TPC = 0x0681
209 | LIGHTWAVE_LOSSLESS = 0x08AE
210 | OLIGSM = 0x1000
211 | OLIADPCM = 0x1001
212 | OLICELP = 0x1002
213 | OLISBC = 0x1003
214 | OLIOPR = 0x1004
215 | LH_CODEC = 0x1100
216 | LH_CODEC_CELP = 0x1101
217 | LH_CODEC_SBC8 = 0x1102
218 | LH_CODEC_SBC12 = 0x1103
219 | LH_CODEC_SBC16 = 0x1104
220 | NORRIS = 0x1400
221 | ISIAUDIO_2 = 0x1401
222 | SOUNDSPACE_MUSICOMPRESS = 0x1500
223 | MPEG_ADTS_AAC = 0x1600
224 | MPEG_RAW_AAC = 0x1601
225 | MPEG_LOAS = 0x1602
226 | NOKIA_MPEG_ADTS_AAC = 0x1608
227 | NOKIA_MPEG_RAW_AAC = 0x1609
228 | VODAFONE_MPEG_ADTS_AAC = 0x160A
229 | VODAFONE_MPEG_RAW_AAC = 0x160B
230 | MPEG_HEAAC = 0x1610
231 | VOXWARE_RT24_SPEECH = 0x181C
232 | SONICFOUNDRY_LOSSLESS = 0x1971
233 | INNINGS_TELECOM_ADPCM = 0x1979
234 | LUCENT_SX8300P = 0x1C07
235 | LUCENT_SX5363S = 0x1C0C
236 | CUSEEME = 0x1F03
237 | NTCSOFT_ALF2CM_ACM = 0x1FC4
238 | DVM = 0x2000
239 | DTS2 = 0x2001
240 | MAKEAVIS = 0x3313
241 | DIVIO_MPEG4_AAC = 0x4143
242 | NOKIA_ADAPTIVE_MULTIRATE = 0x4201
243 | DIVIO_G726 = 0x4243
244 | LEAD_SPEECH = 0x434C
245 | LEAD_VORBIS = 0x564C
246 | WAVPACK_AUDIO = 0x5756
247 | OGG_VORBIS_MODE_1 = 0x674F
248 | OGG_VORBIS_MODE_2 = 0x6750
249 | OGG_VORBIS_MODE_3 = 0x6751
250 | OGG_VORBIS_MODE_1_PLUS = 0x676F
251 | OGG_VORBIS_MODE_2_PLUS = 0x6770
252 | OGG_VORBIS_MODE_3_PLUS = 0x6771
253 | ALAC = 0x6C61
254 | _3COM_NBX = 0x7000 # Can't have leading digit
255 | OPUS = 0x704F
256 | FAAD_AAC = 0x706D
257 | AMR_NB = 0x7361
258 | AMR_WB = 0x7362
259 | AMR_WP = 0x7363
260 | GSM_AMR_CBR = 0x7A21
261 | GSM_AMR_VBR_SID = 0x7A22
262 | COMVERSE_INFOSYS_G723_1 = 0xA100
263 | COMVERSE_INFOSYS_AVQSBC = 0xA101
264 | COMVERSE_INFOSYS_SBC = 0xA102
265 | SYMBOL_G729_A = 0xA103
266 | VOICEAGE_AMR_WB = 0xA104
267 | INGENIENT_G726 = 0xA105
268 | MPEG4_AAC = 0xA106
269 | ENCORE_G726 = 0xA107
270 | ZOLL_ASAO = 0xA108
271 | SPEEX_VOICE = 0xA109
272 | VIANIX_MASC = 0xA10A
273 | WM9_SPECTRUM_ANALYZER = 0xA10B
274 | WMF_SPECTRUM_ANAYZER = 0xA10C
275 | GSM_610 = 0xA10D
276 | GSM_620 = 0xA10E
277 | GSM_660 = 0xA10F
278 | GSM_690 = 0xA110
279 | GSM_ADAPTIVE_MULTIRATE_WB = 0xA111
280 | POLYCOM_G722 = 0xA112
281 | POLYCOM_G728 = 0xA113
282 | POLYCOM_G729_A = 0xA114
283 | POLYCOM_SIREN = 0xA115
284 | GLOBAL_IP_ILBC = 0xA116
285 | RADIOTIME_TIME_SHIFT_RADIO = 0xA117
286 | NICE_ACA = 0xA118
287 | NICE_ADPCM = 0xA119
288 | VOCORD_G721 = 0xA11A
289 | VOCORD_G726 = 0xA11B
290 | VOCORD_G722_1 = 0xA11C
291 | VOCORD_G728 = 0xA11D
292 | VOCORD_G729 = 0xA11E
293 | VOCORD_G729_A = 0xA11F
294 | VOCORD_G723_1 = 0xA120
295 | VOCORD_LBC = 0xA121
296 | NICE_G728 = 0xA122
297 | FRACE_TELECOM_G729 = 0xA123
298 | CODIAN = 0xA124
299 | FLAC = 0xF1AC
300 | EXTENSIBLE = 0xFFFE
301 | DEVELOPMENT = 0xFFFF
302 |
303 |
304 | KNOWN_WAVE_FORMATS = {WAVE_FORMAT.PCM, WAVE_FORMAT.IEEE_FLOAT}
305 |
306 |
307 | def _raise_bad_format(format_tag):
308 | try:
309 | format_name = WAVE_FORMAT(format_tag).name
310 | except ValueError:
311 | format_name = f"{format_tag:#06x}"
312 | raise ValueError(
313 | f"Unknown wave file format: {format_name}. Supported "
314 | "formats: " + ", ".join(x.name for x in KNOWN_WAVE_FORMATS)
315 | )
316 |
317 |
318 | def _read_fmt_chunk(fid, is_big_endian):
319 | """
320 | Returns
321 | -------
322 | size : int
323 | size of format subchunk in bytes (minus 8 for "fmt " and itself)
324 | format_tag : int
325 | PCM, float, or compressed format
326 | channels : int
327 | number of channels
328 | fs : int
329 | sampling frequency in samples per second
330 | bytes_per_second : int
331 | overall byte rate for the file
332 | block_align : int
333 | bytes per sample, including all channels
334 | bit_depth : int
335 | bits per sample
336 |
337 | Notes
338 | -----
339 | Assumes file pointer is immediately after the 'fmt ' id
340 | """
341 | if is_big_endian:
342 | fmt = ">"
343 | else:
344 | fmt = "<"
345 |
346 | size = struct.unpack(fmt + "I", fid.read(4))[0]
347 |
348 | if size < 16:
349 | raise ValueError("Binary structure of wave file is not compliant")
350 |
351 | res = struct.unpack(fmt + "HHIIHH", fid.read(16))
352 | bytes_read = 16
353 |
354 | format_tag, channels, fs, bytes_per_second, block_align, bit_depth = res
355 |
356 | if format_tag == WAVE_FORMAT.EXTENSIBLE and size >= (16 + 2):
357 | ext_chunk_size = struct.unpack(fmt + "H", fid.read(2))[0]
358 | bytes_read += 2
359 | if ext_chunk_size >= 22:
360 | extensible_chunk_data = fid.read(22)
361 | bytes_read += 22
362 | raw_guid = extensible_chunk_data[2 + 4 : 2 + 4 + 16]
363 | # GUID template {XXXXXXXX-0000-0010-8000-00AA00389B71} (RFC-2361)
364 | # MS GUID byte order: first three groups are native byte order,
365 | # rest is Big Endian
366 | if is_big_endian:
367 | tail = b"\x00\x00\x00\x10\x80\x00\x00\xAA\x00\x38\x9B\x71"
368 | else:
369 | tail = b"\x00\x00\x10\x00\x80\x00\x00\xAA\x00\x38\x9B\x71"
370 | if raw_guid.endswith(tail):
371 | format_tag = struct.unpack(fmt + "I", raw_guid[:4])[0]
372 | else:
373 | raise ValueError("Binary structure of wave file is not compliant")
374 |
375 | if format_tag not in KNOWN_WAVE_FORMATS:
376 | _raise_bad_format(format_tag)
377 |
378 | # move file pointer to next chunk
379 | if size > bytes_read:
380 | fid.read(size - bytes_read)
381 |
382 | # fmt should always be 16, 18 or 40, but handle it just in case
383 | _handle_pad_byte(fid, size)
384 |
385 | return (size, format_tag, channels, fs, bytes_per_second, block_align, bit_depth)
386 |
387 |
388 | def _read_data_chunk(
389 | fid, format_tag, channels, bit_depth, is_big_endian, block_align, mmap=False
390 | ):
391 | """
392 | Notes
393 | -----
394 | Assumes file pointer is immediately after the 'data' id
395 |
396 | It's possible to not use all available bits in a container, or to store
397 | samples in a container bigger than necessary, so bytes_per_sample uses
398 | the actual reported container size (nBlockAlign / nChannels). Real-world
399 | examples:
400 |
401 | Adobe Audition's "24-bit packed int (type 1, 20-bit)"
402 |
403 | nChannels = 2, nBlockAlign = 6, wBitsPerSample = 20
404 |
405 | http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Samples/AFsp/M1F1-int12-AFsp.wav
406 | is:
407 |
408 | nChannels = 2, nBlockAlign = 4, wBitsPerSample = 12
409 |
410 | http://www-mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/Docs/multichaudP.pdf
411 | gives an example of:
412 |
413 | nChannels = 2, nBlockAlign = 8, wBitsPerSample = 20
414 | """
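    # For example, the Adobe Audition case from the notes above (nChannels=2,
    # nBlockAlign=6, wBitsPerSample=20) gives bytes_per_sample = 6 // 2 = 3,
    # which has no native numpy dtype, so the samples are read as raw "V1"
    # bytes and widened to int32 further below.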
415 | if is_big_endian:
416 | fmt = ">"
417 | else:
418 | fmt = "<"
419 |
420 | # Size of the data subchunk in bytes
421 | size = struct.unpack(fmt + "I", fid.read(4))[0]
422 |
423 | # Number of bytes per sample (sample container size)
424 | bytes_per_sample = block_align // channels
425 | n_samples = size // bytes_per_sample
426 |
427 | if format_tag == WAVE_FORMAT.PCM:
428 | if 1 <= bit_depth <= 8:
429 | dtype = "u1" # WAV of 8-bit integer or less are unsigned
430 | elif bytes_per_sample in {3, 5, 6, 7}:
431 | # No compatible dtype. Load as raw bytes for reshaping later.
432 | dtype = "V1"
433 | elif bit_depth <= 64:
434 | # Remaining bit depths can map directly to signed numpy dtypes
435 | dtype = f"{fmt}i{bytes_per_sample}"
436 | else:
437 | raise ValueError(
438 | "Unsupported bit depth: the WAV file "
439 | f"has {bit_depth}-bit integer data."
440 | )
441 | elif format_tag == WAVE_FORMAT.IEEE_FLOAT:
442 | if bit_depth in {32, 64}:
443 | dtype = f"{fmt}f{bytes_per_sample}"
444 | else:
445 | raise ValueError(
446 | "Unsupported bit depth: the WAV file "
447 | f"has {bit_depth}-bit floating-point data."
448 | )
449 | else:
450 | _raise_bad_format(format_tag)
451 |
452 | start = fid.tell()
453 | if not mmap:
454 | try:
455 | count = size if dtype == "V1" else n_samples
456 | data = numpy.fromfile(fid, dtype=dtype, count=count)
457 | except io.UnsupportedOperation: # not a C-like file
458 | fid.seek(start, 0) # just in case it seeked, though it shouldn't
459 | data = numpy.frombuffer(fid.read(size), dtype=dtype)
460 |
461 | if dtype == "V1":
462 | # Rearrange raw bytes into smallest compatible numpy dtype
463 | dt = f"{fmt}i4" if bytes_per_sample == 3 else f"{fmt}i8"
464 | a = numpy.zeros(
465 | (len(data) // bytes_per_sample, numpy.dtype(dt).itemsize), dtype="V1"
466 | )
467 | if is_big_endian:
468 | a[:, :bytes_per_sample] = data.reshape((-1, bytes_per_sample))
469 | else:
470 | a[:, -bytes_per_sample:] = data.reshape((-1, bytes_per_sample))
471 | data = a.view(dt).reshape(a.shape[:-1])
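            # In either byte order the packed samples land in the high bytes of
            # the wider dtype, so e.g. a 24-bit sample 0x123456 is returned as
            # the int32 0x12345600 (scaled by 2**8, sign preserved).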
472 | else:
473 | if bytes_per_sample in {1, 2, 4, 8}:
474 | start = fid.tell()
475 | data = numpy.memmap(
476 | fid, dtype=dtype, mode="c", offset=start, shape=(n_samples,)
477 | )
478 | fid.seek(start + size)
479 | else:
480 | raise ValueError(
481 | "mmap=True not compatible with "
482 | f"{bytes_per_sample}-byte container size."
483 | )
484 |
485 | _handle_pad_byte(fid, size)
486 |
487 | if channels > 1:
488 | data = data.reshape(-1, channels)
489 | return data
490 |
491 |
492 | def _skip_unknown_chunk(fid, is_big_endian):
493 | if is_big_endian:
494 | fmt = ">I"
495 | else:
496 |         fmt = "<I"
612 |     >>> from os.path import dirname, join as pjoin
613 | >>> from scipy.io import wavfile
614 | >>> import scipy.io
615 |
616 | Get the filename for an example .wav file from the tests/data directory.
617 |
618 | >>> data_dir = pjoin(dirname(scipy.io.__file__), 'tests', 'data')
619 | >>> wav_fname = pjoin(data_dir, 'test-44100Hz-2ch-32bit-float-be.wav')
620 |
621 | Load the .wav file contents.
622 |
623 | >>> samplerate, data = wavfile.read(wav_fname)
624 | >>> print(f"number of channels = {data.shape[1]}")
625 | number of channels = 2
626 | >>> length = data.shape[0] / samplerate
627 | >>> print(f"length = {length}s")
628 | length = 0.01s
629 |
630 | Plot the waveform.
631 |
632 | >>> import matplotlib.pyplot as plt
633 | >>> import numpy as np
634 | >>> time = np.linspace(0., length, data.shape[0])
635 | >>> plt.plot(time, data[:, 0], label="Left channel")
636 | >>> plt.plot(time, data[:, 1], label="Right channel")
637 | >>> plt.legend()
638 | >>> plt.xlabel("Time [s]")
639 | >>> plt.ylabel("Amplitude")
640 | >>> plt.show()
641 |
642 | """
643 | if hasattr(filename, "read"):
644 | fid = filename
645 | mmap = False
646 | else:
647 | # pylint: disable=consider-using-with
648 | fid = open(filename, "rb")
649 |
650 | try:
651 | file_size, is_big_endian = _read_riff_chunk(fid)
652 | fmt_chunk_received = False
653 | data_chunk_received = False
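        # After the 12-byte RIFF header, the file is a sequence of chunks, each
        # a 4-byte ASCII id followed by a 4-byte size and `size` bytes of
        # payload, padded to an even length; the loop below walks those chunks
        # and parses only 'fmt ' and 'data'.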
654 | while fid.tell() < file_size:
655 | # read the next chunk
656 | chunk_id = fid.read(4)
657 |
658 | if not chunk_id:
659 | if data_chunk_received:
660 | # End of file but data successfully read
661 | warnings.warn(
662 | f"Reached EOF prematurely; finished at {fid.tell()} bytes, "
663 |                         f"expected {file_size} bytes from header.",
664 | WavFileWarning,
665 | stacklevel=2,
666 | )
667 | break
668 |
669 | raise ValueError("Unexpected end of file.")
670 | if len(chunk_id) < 4:
671 | msg = f"Incomplete chunk ID: {repr(chunk_id)}"
672 | # If we have the data, ignore the broken chunk
673 | if fmt_chunk_received and data_chunk_received:
674 | warnings.warn(msg + ", ignoring it.", WavFileWarning, stacklevel=2)
675 | else:
676 | raise ValueError(msg)
677 |
678 | if chunk_id == b"fmt ":
679 | fmt_chunk_received = True
680 | fmt_chunk = _read_fmt_chunk(fid, is_big_endian)
681 | format_tag, channels, fs = fmt_chunk[1:4]
682 | bit_depth = fmt_chunk[6]
683 | block_align = fmt_chunk[5]
684 | elif chunk_id == b"fact":
685 | _skip_unknown_chunk(fid, is_big_endian)
686 | elif chunk_id == b"data":
687 | data_chunk_received = True
688 | if not fmt_chunk_received:
689 | raise ValueError("No fmt chunk before data")
690 | data = _read_data_chunk(
691 | fid,
692 | format_tag,
693 | channels,
694 | bit_depth,
695 | is_big_endian,
696 | block_align,
697 | mmap,
698 | )
699 | elif chunk_id == b"LIST":
700 | # Someday this could be handled properly but for now skip it
701 | _skip_unknown_chunk(fid, is_big_endian)
702 | elif chunk_id in {b"JUNK", b"Fake"}:
703 | # Skip alignment chunks without warning
704 | _skip_unknown_chunk(fid, is_big_endian)
705 | else:
706 | warnings.warn(
707 | "Chunk (non-data) not understood, skipping it.",
708 | WavFileWarning,
709 | stacklevel=2,
710 | )
711 | _skip_unknown_chunk(fid, is_big_endian)
712 | finally:
713 | if not hasattr(filename, "read"):
714 | fid.close()
715 | else:
716 | fid.seek(0)
717 |
718 | return fs, data
719 |
720 |
721 | def write(filename, rate, data):
722 | """
723 | Write a NumPy array as a WAV file.
724 |
725 | Parameters
726 | ----------
727 | filename : string or open file handle
728 | Output wav file.
729 | rate : int
730 | The sample rate (in samples/sec).
731 | data : ndarray
732 | A 1-D or 2-D NumPy array of either integer or float data-type.
733 |
734 | Notes
735 | -----
736 | * Writes a simple uncompressed WAV file.
737 | * To write multiple-channels, use a 2-D array of shape
738 | (Nsamples, Nchannels).
739 | * The bits-per-sample and PCM/float will be determined by the data-type.
740 |
741 | Common data types: [1]_
742 |
743 | ===================== =========== =========== =============
744 | WAV format Min Max NumPy dtype
745 | ===================== =========== =========== =============
746 | 32-bit floating-point -1.0 +1.0 float32
747 | 32-bit PCM -2147483648 +2147483647 int32
748 | 16-bit PCM -32768 +32767 int16
749 | 8-bit PCM 0 255 uint8
750 | ===================== =========== =========== =============
751 |
752 | Note that 8-bit PCM is unsigned.
753 |
754 | References
755 | ----------
756 | .. [1] IBM Corporation and Microsoft Corporation, "Multimedia Programming
757 | Interface and Data Specifications 1.0", section "Data Format of the
758 | Samples", August 1991
759 | http://www.tactilemedia.com/info/MCI_Control_Info.html
760 |
761 | Examples
762 | --------
763 | Create a 100Hz sine wave, sampled at 44100Hz.
764 | Write to 16-bit PCM, Mono.
765 |
766 | >>> from scipy.io.wavfile import write
767 |     >>> import numpy as np; samplerate = 44100; fs = 100
768 | >>> t = np.linspace(0., 1., samplerate)
769 | >>> amplitude = np.iinfo(np.int16).max
770 | >>> data = amplitude * np.sin(2. * np.pi * fs * t)
771 | >>> write("example.wav", samplerate, data.astype(np.int16))
772 |
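    As a further illustrative example (hypothetical data and filename), a
    stereo 32-bit float file can be written from any float32 array with
    values in [-1.0, 1.0]:

    >>> rng = np.random.default_rng()
    >>> stereo = rng.uniform(-1.0, 1.0, size=(samplerate, 2)).astype(np.float32)
    >>> write("example_stereo.wav", samplerate, stereo)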
773 | """
774 | if hasattr(filename, "write"):
775 | fid = filename
776 | else:
777 | # pylint: disable=consider-using-with
778 | fid = open(filename, "wb")
779 |
780 | fs = rate
781 |
782 | try:
783 | dkind = data.dtype.kind
784 | if not (
785 | dkind == "i" or dkind == "f" or (dkind == "u" and data.dtype.itemsize == 1)
786 | ):
787 | raise ValueError(f"Unsupported data type '{data.dtype}'")
788 |
789 | header_data = b""
790 |
791 | header_data += b"RIFF"
792 | header_data += b"\x00\x00\x00\x00"
793 | header_data += b"WAVE"
794 |
795 | # fmt chunk
796 | header_data += b"fmt "
797 | if dkind == "f":
798 | format_tag = WAVE_FORMAT.IEEE_FLOAT
799 | else:
800 | format_tag = WAVE_FORMAT.PCM
801 | if data.ndim == 1:
802 | channels = 1
803 | else:
804 | channels = data.shape[1]
805 | bit_depth = data.dtype.itemsize * 8
806 | bytes_per_second = fs * (bit_depth // 8) * channels
807 | block_align = channels * (bit_depth // 8)
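        # For instance, mono float32 at fs=48000 would give bit_depth=32,
        # block_align=1 * 4 = 4 and bytes_per_second=48000 * 4 = 192000.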
808 |
809 | fmt_chunk_data = struct.pack(
810 |             "<HHIIHH", format_tag, channels, fs, bytes_per_second, block_align, bit_depth
811 |         )
831 |         if ((len(header_data) - 4 - 4) + (4 + 4 + data.nbytes)) > 0xFFFFFFFF:
832 | raise ValueError("Data exceeds wave file size limit")
833 |
834 | fid.write(header_data)
835 |
836 | # data chunk
837 | fid.write(b"data")
838 |         fid.write(struct.pack("<I", data.nbytes))
839 |         if data.dtype.byteorder == ">" or (
840 | data.dtype.byteorder == "=" and sys.byteorder == "big"
841 | ):
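            # WAV sample data is little-endian on disk, so big-endian (or
            # native big-endian) arrays are byteswapped before being written.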
842 | data = data.byteswap()
843 | _array_tofile(fid, data)
844 |
845 | # Determine file size and place it in correct
846 | # position at start of the file.
847 | size = fid.tell()
848 | fid.seek(4)
849 |         fid.write(struct.pack("<I", size - 8))
850 |     finally:
851 |         if not hasattr(filename, "write"):
852 |             fid.close()
853 |         else:
854 |             fid.seek(0)