├── README.md
├── styletts2_colab.ipynb
├── styletts2_train_colab.ipynb
├── styletts2_test.ipynb
├── bucilianus1.ipynb
└── styletts2_test_colab.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my patreon community https://patreon.com/camenduru
4 |
5 | # 🚦 WIP 🚦
6 |
7 | ## 🦒 Colab
8 |
9 | | Colab | Info |
10 | | --- | --- |
11 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_colab.ipynb) | styletts2_colab (Thanks to [@realmrfakename](https://twitter.com/realmrfakename) ❤) |
12 |
13 | ## Tutorial
14 |
15 | ## Main Repo
16 | https://github.com/yl4579/StyleTTS2
17 |
18 | ## Paper
19 | https://arxiv.org/abs/2306.07691
20 |
21 | ## Page
22 | https://styletts2.github.io/
23 |
24 | ## Output
25 |
26 | https://github.com/camenduru/styletts-colab/assets/54370274/f30b3d16-4260-49c7-808e-934451a5a8dc
27 |
28 | ## Sponsor
29 | https://modelslab.com
30 |
--------------------------------------------------------------------------------
/styletts2_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone -b dev https://github.com/camenduru/styletts2-hf\n",
22 | "%cd /content/styletts2-hf\n",
23 | "\n",
24 | "!pip install -q gradio cached-path munch einops einops-exts phonemizer tortoise-tts\n",
25 | "!pip install -q git+https://github.com/resemble-ai/monotonic_align\n",
26 | "\n",
27 | "!apt -y install -qq aria2 espeak-ng\n",
28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-1.wav?download=true -d /content/styletts2-hf/voices -o f-us-1.wav\n",
29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-2.wav?download=true -d /content/styletts2-hf/voices -o f-us-2.wav\n",
30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-3.wav?download=true -d /content/styletts2-hf/voices -o f-us-3.wav\n",
31 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-4.wav?download=true -d /content/styletts2-hf/voices -o f-us-4.wav\n",
32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-1.wav?download=true -d /content/styletts2-hf/voices -o m-us-1.wav\n",
33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-2.wav?download=true -d /content/styletts2-hf/voices -o m-us-2.wav\n",
34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-3.wav?download=true -d /content/styletts2-hf/voices -o m-us-3.wav\n",
35 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-4.wav?download=true -d /content/styletts2-hf/voices -o m-us-4.wav\n",
36 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices.pkl?download=true -d /content/styletts2-hf -o voices.pkl\n",
37 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/PLBERT/step_1000000.t7?download=true -d /content/styletts2-hf/Utils/PLBERT -o step_1000000.t7\n",
38 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/JDC/bst.t7?download=true -d /content/styletts2-hf/Utils/JDC -o bst.t7\n",
39 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/ASR/epoch_00080.pth?download=true -d /content/styletts2-hf/Utils/ASR -o epoch_00080.pth\n",
40 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/reference_audio.zip?download=true -d /content/styletts2-hf -o reference_audio.zip\n",
41 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Data/OOD_texts.txt?download=true -d /content/styletts2-hf/Data -o OOD_texts.txt\n",
42 | "\n",
43 | "!python app.py"
44 | ]
45 | }
46 | ],
47 | "metadata": {
48 | "accelerator": "GPU",
49 | "colab": {
50 | "gpuType": "T4",
51 | "provenance": []
52 | },
53 | "kernelspec": {
54 | "display_name": "Python 3",
55 | "name": "python3"
56 | },
57 | "language_info": {
58 | "name": "python"
59 | }
60 | },
61 | "nbformat": 4,
62 | "nbformat_minor": 0
63 | }
64 |
--------------------------------------------------------------------------------
/styletts2_train_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_train_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!apt install -y espeak-ng aria2 -y\n",
22 | "!git clone https://dagshub.com/StyleTTS/StyleTTS2\n",
23 | "%cd /content/StyleTTS2\n",
24 | "\n",
25 | "!pip install -q SoundFile torchaudio munch torch pydub pyyaml librosa matplotlib accelerate transformers phonemizer einops wandb\n",
26 | "!pip install -q gruut accelerate mlflow einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
27 | "!pip install -q nltk -U\n",
28 | " \n",
29 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/adedc90b53fb0676f83d4c9f1e01f1e8650ba15d/data/Models/LibriTTS/epochs_2nd_00020.pth -d /content/StyleTTS2/Models/LibriTTS -o epochs_2nd_00020.pth\n",
30 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/adedc90b53fb0676f83d4c9f1e01f1e8650ba15d/data/Models/LibriTTS/config.yml -d /content/StyleTTS2/Models/LibriTTS -o config.yml\n",
31 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/20b8b78c7e27b46d8525f3b1bb0d2725a12a56ff/data/Data.zip -d /content/StyleTTS2 -o Data.zip\n",
32 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/config.json -d /content/StyleTTS2/wavlm-base-plus -o config.json\n",
33 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/preprocessor_config.json -d /content/StyleTTS2/wavlm-base-plus -o preprocessor_config.json\n",
34 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/pytorch_model.bin -d /content/StyleTTS2/wavlm-base-plus -o pytorch_model.bin\n",
35 | "# !unzip -o Data.zip\n",
36 | "\n",
37 | "!rm -rf /content/StyleTTS2/Data\n",
38 | "!mkdir -p /content/StyleTTS2/Data/wavs\n",
39 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_0.wav -O /content/StyleTTS2/Data/wavs/0_0.wav\n",
40 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_1.wav -O /content/StyleTTS2/Data/wavs/0_1.wav\n",
41 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_2.wav -O /content/StyleTTS2/Data/wavs/0_2.wav\n",
42 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_3.wav -O /content/StyleTTS2/Data/wavs/0_3.wav\n",
43 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_4.wav -O /content/StyleTTS2/Data/wavs/0_4.wav\n",
44 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_5.wav -O /content/StyleTTS2/Data/wavs/0_5.wav\n",
45 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_6.wav -O /content/StyleTTS2/Data/wavs/0_6.wav\n",
46 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_7.wav -O /content/StyleTTS2/Data/wavs/0_7.wav\n",
47 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_8.wav -O /content/StyleTTS2/Data/wavs/0_8.wav\n",
48 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_9.wav -O /content/StyleTTS2/Data/wavs/0_9.wav\n",
49 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_10.wav -O /content/StyleTTS2/Data/wavs/0_10.wav\n",
50 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_11.wav -O /content/StyleTTS2/Data/wavs/0_11.wav\n",
51 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_12.wav -O /content/StyleTTS2/Data/wavs/0_12.wav\n",
52 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_13.wav -O /content/StyleTTS2/Data/wavs/0_13.wav\n",
53 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_14.wav -O /content/StyleTTS2/Data/wavs/0_14.wav\n",
54 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_15.wav -O /content/StyleTTS2/Data/wavs/0_15.wav\n",
55 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_16.wav -O /content/StyleTTS2/Data/wavs/0_16.wav\n",
56 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_17.wav -O /content/StyleTTS2/Data/wavs/0_17.wav\n",
57 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_18.wav -O /content/StyleTTS2/Data/wavs/0_18.wav\n",
58 | "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_19.wav -O /content/StyleTTS2/Data/wavs/0_19.wav\n",
59 | "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/OOD_texts.txt\n",
60 | "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/train_list.txt\n",
61 | "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/val_list.txt"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "%cd /content/StyleTTS2\n",
71 | "\n",
72 | "config_path = \"Configs/config_ft.yml\"\n",
73 | "import yaml\n",
74 | "config = yaml.safe_load(open(config_path))\n",
75 | "config['data_params']['OOD_data'] = \"Data/OOD_texts.txt\"\n",
76 | "config['data_params']['root_path'] = \"Data/wavs/\"\n",
77 | "config['data_params']['train_data'] = \"Data/train_list.txt\"\n",
78 | "config['data_params']['val_data'] = \"Data/val_list.txt\"\n",
79 | "config['batch_size'] = 2 # not enough RAM\n",
80 | "config['max_len'] = 100 # not enough RAM\n",
81 | "config['loss_params']['joint_epoch'] = 150 # we do not do SLM adversarial training due to not enough RAM\n",
82 | "config['save_freq'] = 1\n",
83 | "config['data_params']['logger'] = \"mlflow\"\n",
84 | "\n",
85 | "with open(config_path, 'w') as outfile:\n",
86 | " yaml.dump(config, outfile, default_flow_style=True)\n",
87 | "\n",
88 | "import os\n",
89 | "os.environ['MLFLOW_TRACKING_URI'] = 'MLFLOW_TRACKING_URI'\n",
90 | "os.environ['MLFLOW_TRACKING_USERNAME'] = 'MLFLOW_TRACKING_USERNAME'\n",
91 | "os.environ['MLFLOW_TRACKING_PASSWORD'] = 'MLFLOW_TRACKING_PASSWORD'\n",
92 | "\n",
93 | "!cat Configs/config_ft.yml"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "%cd /content/StyleTTS2\n",
103 | "# !accelerate launch --mixed_precision=fp16 --multi_gpu train_finetune_accelerate.py --config_path ./Configs/config_ft.yml\n",
104 | "!python train_finetune.py --config_path ./Configs/config_ft.yml"
105 | ]
106 | }
107 | ],
108 | "metadata": {
109 | "accelerator": "GPU",
110 | "colab": {
111 | "gpuType": "T4",
112 | "provenance": []
113 | },
114 | "kernelspec": {
115 | "display_name": "Python 3",
116 | "name": "python3"
117 | },
118 | "language_info": {
119 | "name": "python"
120 | }
121 | },
122 | "nbformat": 4,
123 | "nbformat_minor": 0
124 | }
125 |
--------------------------------------------------------------------------------
/styletts2_test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_test.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone https://github.com/yl4579/StyleTTS2.git\n",
22 | "%cd StyleTTS2\n",
23 | "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n",
24 | "!apt install espeak-ng aria2\n",
25 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n",
26 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n",
27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n",
28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n",
29 | "\n",
30 | "%cd StyleTTS2\n",
31 | "\n",
32 | "import nltk\n",
33 | "nltk.download('punkt')\n",
34 | "\n",
35 | "import torch\n",
36 | "torch.manual_seed(0)\n",
37 | "torch.backends.cudnn.benchmark = False\n",
38 | "torch.backends.cudnn.deterministic = True\n",
39 | "\n",
40 | "import random\n",
41 | "random.seed(0)\n",
42 | "\n",
43 | "import numpy as np\n",
44 | "np.random.seed(0)\n",
45 | "\n",
46 | "# load packages\n",
47 | "import time\n",
48 | "import random\n",
49 | "import yaml\n",
50 | "from munch import Munch\n",
51 | "import numpy as np\n",
52 | "import torch\n",
53 | "from torch import nn\n",
54 | "import torch.nn.functional as F\n",
55 | "import torchaudio\n",
56 | "import librosa\n",
57 | "from nltk.tokenize import word_tokenize\n",
58 | "\n",
59 | "from models import *\n",
60 | "from utils import *\n",
61 | "from text_utils import TextCleaner\n",
62 | "textclenaer = TextCleaner()\n",
63 | "\n",
64 | "%matplotlib inline\n",
65 | "\n",
66 | "to_mel = torchaudio.transforms.MelSpectrogram(\n",
67 | " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
68 | "mean, std = -4, 4\n",
69 | "\n",
70 | "def length_to_mask(lengths):\n",
71 | " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
72 | " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
73 | " return mask\n",
74 | "\n",
75 | "def preprocess(wave):\n",
76 | " wave_tensor = torch.from_numpy(wave).float()\n",
77 | " mel_tensor = to_mel(wave_tensor)\n",
78 | " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
79 | " return mel_tensor\n",
80 | "\n",
81 | "def compute_style(path):\n",
82 | " wave, sr = librosa.load(path, sr=24000)\n",
83 | " audio, index = librosa.effects.trim(wave, top_db=30)\n",
84 | " if sr != 24000:\n",
85 | " audio = librosa.resample(audio, sr, 24000)\n",
86 | " mel_tensor = preprocess(audio).to(device)\n",
87 | "\n",
88 | " with torch.no_grad():\n",
89 | " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
90 | " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
91 | "\n",
92 | " return torch.cat([ref_s, ref_p], dim=1)\n",
93 | "\n",
94 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
95 | "\n",
96 | "# load phonemizer\n",
97 | "import phonemizer\n",
98 | "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
99 | "\n",
100 | "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
101 | "\n",
102 | "# load pretrained ASR model\n",
103 | "ASR_config = config.get('ASR_config', False)\n",
104 | "ASR_path = config.get('ASR_path', False)\n",
105 | "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
106 | "\n",
107 | "# load pretrained F0 model\n",
108 | "F0_path = config.get('F0_path', False)\n",
109 | "pitch_extractor = load_F0_models(F0_path)\n",
110 | "\n",
111 | "# load BERT model\n",
112 | "from Utils.PLBERT.util import load_plbert\n",
113 | "BERT_path = config.get('PLBERT_dir', False)\n",
114 | "plbert = load_plbert(BERT_path)\n",
115 | "\n",
116 | "model_params = recursive_munch(config['model_params'])\n",
117 | "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
118 | "_ = [model[key].eval() for key in model]\n",
119 | "_ = [model[key].to(device) for key in model]\n",
120 | "\n",
121 | "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
122 | "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n",
123 | "\n",
124 | "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
125 | "params = params_whole['net']\n",
126 | "\n",
127 | "for key in model:\n",
128 | " if key in params:\n",
129 | " print('%s loaded' % key)\n",
130 | " try:\n",
131 | " model[key].load_state_dict(params[key])\n",
132 | " except:\n",
133 | " from collections import OrderedDict\n",
134 | " state_dict = params[key]\n",
135 | " new_state_dict = OrderedDict()\n",
136 | " for k, v in state_dict.items():\n",
137 | " name = k[7:] # remove `module.`\n",
138 | " new_state_dict[name] = v\n",
139 | " # load params\n",
140 | " model[key].load_state_dict(new_state_dict, strict=False)\n",
141 | "# except:\n",
142 | "# _load(params[key], model[key])\n",
143 | "_ = [model[key].eval() for key in model]\n",
144 | "\n",
145 | "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
146 | "\n",
147 | "sampler = DiffusionSampler(\n",
148 | " model.diffusion.diffusion,\n",
149 | " sampler=ADPM2Sampler(),\n",
150 | " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
151 | " clamp=False\n",
152 | ")\n",
153 | "\n",
154 | "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
155 | " text = text.strip()\n",
156 | " ps = global_phonemizer.phonemize([text])\n",
157 | " ps = word_tokenize(ps[0])\n",
158 | " ps = ' '.join(ps)\n",
159 | " tokens = textclenaer(ps)\n",
160 | " tokens.insert(0, 0)\n",
161 | " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
162 | "\n",
163 | " with torch.no_grad():\n",
164 | " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
165 | " text_mask = length_to_mask(input_lengths).to(device)\n",
166 | "\n",
167 | " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
168 | " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
169 | " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
170 | "\n",
171 | " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
172 | " embedding=bert_dur,\n",
173 | " embedding_scale=embedding_scale,\n",
174 | " features=ref_s, # reference from the same speaker as the embedding\n",
175 | " num_steps=diffusion_steps).squeeze(1)\n",
176 | "\n",
177 | "\n",
178 | " s = s_pred[:, 128:]\n",
179 | " ref = s_pred[:, :128]\n",
180 | "\n",
181 | " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
182 | " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
183 | "\n",
184 | " d = model.predictor.text_encoder(d_en,\n",
185 | " s, input_lengths, text_mask)\n",
186 | "\n",
187 | " x, _ = model.predictor.lstm(d)\n",
188 | " duration = model.predictor.duration_proj(x)\n",
189 | "\n",
190 | " duration = torch.sigmoid(duration).sum(axis=-1)\n",
191 | " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
192 | "\n",
193 | " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
194 | " c_frame = 0\n",
195 | " for i in range(pred_aln_trg.size(0)):\n",
196 | " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
197 | " c_frame += int(pred_dur[i].data)\n",
198 | "\n",
199 | " # encode prosody\n",
200 | " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
201 | " if model_params.decoder.type == \"hifigan\":\n",
202 | " asr_new = torch.zeros_like(en)\n",
203 | " asr_new[:, :, 0] = en[:, :, 0]\n",
204 | " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
205 | " en = asr_new\n",
206 | "\n",
207 | " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
208 | "\n",
209 | " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
210 | " if model_params.decoder.type == \"hifigan\":\n",
211 | " asr_new = torch.zeros_like(asr)\n",
212 | " asr_new[:, :, 0] = asr[:, :, 0]\n",
213 | " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
214 | " asr = asr_new\n",
215 | "\n",
216 | " out = model.decoder(asr,\n",
217 | " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
218 | "\n",
219 | "\n",
220 | " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.'''\n",
230 | "# path = \"Data/wavs/LJ001-0110.wav\"\n",
231 | "path = \"Data/wavs/ash.wav\"\n",
232 | "ref_s = compute_style(path)\n",
233 | "\n",
234 | "start = time.time()\n",
235 | "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
236 | "wav = inference(text, ref_s)\n",
237 | "rtf = (time.time() - start) / (len(wav) / 24000)\n",
238 | "print(f\"RTF = {rtf:5f}\")\n",
239 | "import IPython.display as ipd\n",
240 | "display(ipd.Audio(wav, rate=24000, normalize=False))\n",
241 | "sorted_files[-1]"
242 | ]
243 | }
244 | ],
245 | "metadata": {
246 | "accelerator": "GPU",
247 | "colab": {
248 | "gpuType": "T4",
249 | "provenance": []
250 | },
251 | "kernelspec": {
252 | "display_name": "Python 3",
253 | "name": "python3"
254 | },
255 | "language_info": {
256 | "name": "python"
257 | }
258 | },
259 | "nbformat": 4,
260 | "nbformat_minor": 0
261 | }
262 |
--------------------------------------------------------------------------------
/bucilianus1.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/bucilianus1.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone https://github.com/yl4579/StyleTTS2.git\n",
22 | "%cd StyleTTS2\n",
23 | "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n",
24 | "!apt install espeak-ng aria2\n",
25 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/bucilianus-1/resolve/main/epoch_2nd_00030.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00030.pth\n",
26 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n",
27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n",
28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n",
29 | "!wget https://replicate.delivery/pbxt/KB9Q5ER6n39loKuLKdkmPlOr8uR0eCNHY9dFYZQvL3MIfsBH/demo_speaker2.mp3 -O /content/demo_speaker2.mp3\n",
30 | "!wget https://replicate.delivery/pbxt/KB9YGZ4v1woTUDd8rSRlizdqIQjxqGjDnosgvlQkTZo7IRI7/demo_speaker0.mp3 -O /content/demo_speaker0.mp3\n",
31 | "\n",
32 | "%cd StyleTTS2\n",
33 | "\n",
34 | "import nltk\n",
35 | "nltk.download('punkt')\n",
36 | "\n",
37 | "import torch\n",
38 | "torch.manual_seed(0)\n",
39 | "torch.backends.cudnn.benchmark = False\n",
40 | "torch.backends.cudnn.deterministic = True\n",
41 | "\n",
42 | "import random\n",
43 | "random.seed(0)\n",
44 | "\n",
45 | "import numpy as np\n",
46 | "np.random.seed(0)\n",
47 | "\n",
48 | "# load packages\n",
49 | "import time\n",
50 | "import random\n",
51 | "import yaml\n",
52 | "from munch import Munch\n",
53 | "import numpy as np\n",
54 | "import torch\n",
55 | "from torch import nn\n",
56 | "import torch.nn.functional as F\n",
57 | "import torchaudio\n",
58 | "import librosa\n",
59 | "from nltk.tokenize import word_tokenize\n",
60 | "\n",
61 | "from models import *\n",
62 | "from utils import *\n",
63 | "from text_utils import TextCleaner\n",
64 | "textclenaer = TextCleaner()\n",
65 | "\n",
66 | "%matplotlib inline\n",
67 | "\n",
68 | "to_mel = torchaudio.transforms.MelSpectrogram(\n",
69 | " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
70 | "mean, std = -4, 4\n",
71 | "\n",
72 | "def length_to_mask(lengths):\n",
73 | " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
74 | " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
75 | " return mask\n",
76 | "\n",
77 | "def preprocess(wave):\n",
78 | " wave_tensor = torch.from_numpy(wave).float()\n",
79 | " mel_tensor = to_mel(wave_tensor)\n",
80 | " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
81 | " return mel_tensor\n",
82 | "\n",
83 | "def compute_style(path):\n",
84 | " wave, sr = librosa.load(path, sr=24000)\n",
85 | " audio, index = librosa.effects.trim(wave, top_db=30)\n",
86 | " if sr != 24000:\n",
87 | " audio = librosa.resample(audio, sr, 24000)\n",
88 | " mel_tensor = preprocess(audio).to(device)\n",
89 | "\n",
90 | " with torch.no_grad():\n",
91 | " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
92 | " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
93 | "\n",
94 | " return torch.cat([ref_s, ref_p], dim=1)\n",
95 | "\n",
96 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
97 | "\n",
98 | "# load phonemizer\n",
99 | "import phonemizer\n",
100 | "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
101 | "\n",
102 | "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
103 | "\n",
104 | "# load pretrained ASR model\n",
105 | "ASR_config = config.get('ASR_config', False)\n",
106 | "ASR_path = config.get('ASR_path', False)\n",
107 | "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
108 | "\n",
109 | "# load pretrained F0 model\n",
110 | "F0_path = config.get('F0_path', False)\n",
111 | "pitch_extractor = load_F0_models(F0_path)\n",
112 | "\n",
113 | "# load BERT model\n",
114 | "from Utils.PLBERT.util import load_plbert\n",
115 | "BERT_path = config.get('PLBERT_dir', False)\n",
116 | "plbert = load_plbert(BERT_path)\n",
117 | "\n",
118 | "model_params = recursive_munch(config['model_params'])\n",
119 | "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
120 | "_ = [model[key].eval() for key in model]\n",
121 | "_ = [model[key].to(device) for key in model]\n",
122 | "\n",
123 | "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
124 | "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n",
125 | "\n",
126 | "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
127 | "params = params_whole['net']\n",
128 | "\n",
129 | "for key in model:\n",
130 | " if key in params:\n",
131 | " print('%s loaded' % key)\n",
132 | " try:\n",
133 | " model[key].load_state_dict(params[key])\n",
134 | " except:\n",
135 | " from collections import OrderedDict\n",
136 | " state_dict = params[key]\n",
137 | " new_state_dict = OrderedDict()\n",
138 | " for k, v in state_dict.items():\n",
139 | " name = k[7:] # remove `module.`\n",
140 | " new_state_dict[name] = v\n",
141 | " # load params\n",
142 | " model[key].load_state_dict(new_state_dict, strict=False)\n",
143 | "# except:\n",
144 | "# _load(params[key], model[key])\n",
145 | "_ = [model[key].eval() for key in model]\n",
146 | "\n",
147 | "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
148 | "\n",
149 | "sampler = DiffusionSampler(\n",
150 | " model.diffusion.diffusion,\n",
151 | " sampler=ADPM2Sampler(),\n",
152 | " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
153 | " clamp=False\n",
154 | ")\n",
155 | "\n",
156 | "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
157 | " text = text.strip()\n",
158 | " ps = global_phonemizer.phonemize([text])\n",
159 | " ps = word_tokenize(ps[0])\n",
160 | " ps = ' '.join(ps)\n",
161 | " tokens = textclenaer(ps)\n",
162 | " tokens.insert(0, 0)\n",
163 | " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
164 | "\n",
165 | " with torch.no_grad():\n",
166 | " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
167 | " text_mask = length_to_mask(input_lengths).to(device)\n",
168 | "\n",
169 | " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
170 | " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
171 | " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
172 | "\n",
173 | " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
174 | " embedding=bert_dur,\n",
175 | " embedding_scale=embedding_scale,\n",
176 | " features=ref_s, # reference from the same speaker as the embedding\n",
177 | " num_steps=diffusion_steps).squeeze(1)\n",
178 | "\n",
179 | "\n",
180 | " s = s_pred[:, 128:]\n",
181 | " ref = s_pred[:, :128]\n",
182 | "\n",
183 | " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
184 | " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
185 | "\n",
186 | " d = model.predictor.text_encoder(d_en,\n",
187 | " s, input_lengths, text_mask)\n",
188 | "\n",
189 | " x, _ = model.predictor.lstm(d)\n",
190 | " duration = model.predictor.duration_proj(x)\n",
191 | "\n",
192 | " duration = torch.sigmoid(duration).sum(axis=-1)\n",
193 | " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
194 | "\n",
195 | " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
196 | " c_frame = 0\n",
197 | " for i in range(pred_aln_trg.size(0)):\n",
198 | " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
199 | " c_frame += int(pred_dur[i].data)\n",
200 | "\n",
201 | " # encode prosody\n",
202 | " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
203 | " if model_params.decoder.type == \"hifigan\":\n",
204 | " asr_new = torch.zeros_like(en)\n",
205 | " asr_new[:, :, 0] = en[:, :, 0]\n",
206 | " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
207 | " en = asr_new\n",
208 | "\n",
209 | " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
210 | "\n",
211 | " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
212 | " if model_params.decoder.type == \"hifigan\":\n",
213 | " asr_new = torch.zeros_like(asr)\n",
214 | " asr_new[:, :, 0] = asr[:, :, 0]\n",
215 | " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
216 | " asr = asr_new\n",
217 | "\n",
218 | " out = model.decoder(asr,\n",
219 | " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
220 | "\n",
221 | "\n",
222 | " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
223 | ]
224 | },
225 | {
226 | "cell_type": "code",
227 | "execution_count": null,
228 | "metadata": {},
229 | "outputs": [],
230 | "source": [
231 | "text = '''Supporters of Caesar argue that he was a successful military leader who brought stability and reforms to Rome. They may view his assassination as a betrayal by jealous and power-hungry senators who feared his increasing authority.'''\n",
232 | "# path = \"Data/wavs/LJ001-0110.wav\"\n",
233 | "path = \"/content/demo_speaker2.mp3\"\n",
234 | "# path = \"/content/demo_speaker0.mp3\"\n",
235 | "# path = \"Data/wavs/ash.wav\"\n",
236 | "ref_s = compute_style(path)\n",
237 | "\n",
238 | "start = time.time()\n",
239 | "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
240 | "wav = inference(text, ref_s)\n",
241 | "rtf = (time.time() - start) / (len(wav) / 24000)\n",
242 | "print(f\"RTF = {rtf:5f}\")\n",
243 | "import IPython.display as ipd\n",
244 | "display(ipd.Audio(wav, rate=24000, normalize=False))\n",
245 | "sorted_files[-1]"
246 | ]
247 | }
248 | ],
249 | "metadata": {
250 | "accelerator": "GPU",
251 | "colab": {
252 | "gpuType": "T4",
253 | "provenance": []
254 | },
255 | "kernelspec": {
256 | "display_name": "Python 3",
257 | "name": "python3"
258 | },
259 | "language_info": {
260 | "name": "python"
261 | }
262 | },
263 | "nbformat": 4,
264 | "nbformat_minor": 0
265 | }
266 |
--------------------------------------------------------------------------------
/styletts2_test_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_test_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone https://github.com/yl4579/StyleTTS2.git\n",
22 | "%cd StyleTTS2\n",
23 | "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n",
24 | "!apt install espeak-ng aria2\n",
25 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00000.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00000.pth\n",
26 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00001.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00001.pth\n",
27 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00002.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00002.pth\n",
28 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
29 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00004.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00004.pth\n",
30 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00005.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00005.pth\n",
31 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00006.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00006.pth\n",
32 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00007.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00007.pth\n",
33 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00008.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00008.pth\n",
34 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009.pth\n",
35 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_2.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_2.pth\n",
36 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_3.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_3.pth\n",
37 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_4.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_4.pth\n",
38 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00010.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00010.pth\n",
39 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00011.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00011.pth\n",
40 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n",
41 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00013.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00013.pth\n",
42 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00014.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00014.pth\n",
43 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00015.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00015.pth\n",
44 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00016.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00016.pth\n",
45 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00017.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00017.pth\n",
46 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00018.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00018.pth\n",
47 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00019.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00019.pth\n",
48 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00020.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00020.pth\n",
49 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model3/raw/9734ed6cbac07b09747b795811a081ce9896d8b3/model/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
50 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00000.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00000.pth\n",
51 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00001.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00001.pth\n",
52 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00002.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00002.pth\n",
53 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
54 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00004.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00004.pth\n",
55 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00005.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00005.pth\n",
56 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00006.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00006.pth\n",
57 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00007.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00007.pth\n",
58 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00008.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00008.pth\n",
59 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00009.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009.pth\n",
60 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00010.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00010.pth\n",
61 | "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00011.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00011.pth\n",
62 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n",
63 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n",
64 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n",
65 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n",
66 | "\n",
67 | "%cd StyleTTS2\n",
68 | "\n",
69 | "import nltk\n",
70 | "nltk.download('punkt')\n",
71 | "\n",
72 | "import torch\n",
73 | "torch.manual_seed(0)\n",
74 | "torch.backends.cudnn.benchmark = False\n",
75 | "torch.backends.cudnn.deterministic = True\n",
76 | "\n",
77 | "import random\n",
78 | "random.seed(0)\n",
79 | "\n",
80 | "import numpy as np\n",
81 | "np.random.seed(0)\n",
82 | "\n",
83 | "# load packages\n",
84 | "import time\n",
85 | "import random\n",
86 | "import yaml\n",
87 | "from munch import Munch\n",
88 | "import numpy as np\n",
89 | "import torch\n",
90 | "from torch import nn\n",
91 | "import torch.nn.functional as F\n",
92 | "import torchaudio\n",
93 | "import librosa\n",
94 | "from nltk.tokenize import word_tokenize\n",
95 | "\n",
96 | "from models import *\n",
97 | "from utils import *\n",
98 | "from text_utils import TextCleaner\n",
99 | "textclenaer = TextCleaner()\n",
100 | "\n",
101 | "%matplotlib inline\n",
102 | "\n",
103 | "to_mel = torchaudio.transforms.MelSpectrogram(\n",
104 | " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
105 | "mean, std = -4, 4\n",
106 | "\n",
107 | "def length_to_mask(lengths):\n",
108 | " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
109 | " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
110 | " return mask\n",
111 | "\n",
112 | "def preprocess(wave):\n",
113 | " wave_tensor = torch.from_numpy(wave).float()\n",
114 | " mel_tensor = to_mel(wave_tensor)\n",
115 | " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
116 | " return mel_tensor\n",
117 | "\n",
118 | "def compute_style(path):\n",
119 | " wave, sr = librosa.load(path, sr=24000)\n",
120 | " audio, index = librosa.effects.trim(wave, top_db=30)\n",
121 | " if sr != 24000:\n",
122 | " audio = librosa.resample(audio, sr, 24000)\n",
123 | " mel_tensor = preprocess(audio).to(device)\n",
124 | "\n",
125 | " with torch.no_grad():\n",
126 | " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
127 | " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
128 | "\n",
129 | " return torch.cat([ref_s, ref_p], dim=1)\n",
130 | "\n",
131 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
132 | "\n",
133 | "# load phonemizer\n",
134 | "import phonemizer\n",
135 | "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
136 | "\n",
137 | "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
138 | "\n",
139 | "# load pretrained ASR model\n",
140 | "ASR_config = config.get('ASR_config', False)\n",
141 | "ASR_path = config.get('ASR_path', False)\n",
142 | "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
143 | "\n",
144 | "# load pretrained F0 model\n",
145 | "F0_path = config.get('F0_path', False)\n",
146 | "pitch_extractor = load_F0_models(F0_path)\n",
147 | "\n",
148 | "# load BERT model\n",
149 | "from Utils.PLBERT.util import load_plbert\n",
150 | "BERT_path = config.get('PLBERT_dir', False)\n",
151 | "plbert = load_plbert(BERT_path)\n",
152 | "\n",
153 | "model_params = recursive_munch(config['model_params'])\n",
154 | "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
155 | "_ = [model[key].eval() for key in model]\n",
156 | "_ = [model[key].to(device) for key in model]\n",
157 | "\n",
158 | "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
159 | "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n",
160 | "\n",
161 | "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
162 | "params = params_whole['net']\n",
163 | "\n",
164 | "for key in model:\n",
165 | " if key in params:\n",
166 | " print('%s loaded' % key)\n",
167 | " try:\n",
168 | " model[key].load_state_dict(params[key])\n",
169 | " except:\n",
170 | " from collections import OrderedDict\n",
171 | " state_dict = params[key]\n",
172 | " new_state_dict = OrderedDict()\n",
173 | " for k, v in state_dict.items():\n",
174 | " name = k[7:] # remove `module.`\n",
175 | " new_state_dict[name] = v\n",
176 | " # load params\n",
177 | " model[key].load_state_dict(new_state_dict, strict=False)\n",
178 | "# except:\n",
179 | "# _load(params[key], model[key])\n",
180 | "_ = [model[key].eval() for key in model]\n",
181 | "\n",
182 | "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
183 | "\n",
184 | "sampler = DiffusionSampler(\n",
185 | " model.diffusion.diffusion,\n",
186 | " sampler=ADPM2Sampler(),\n",
187 | " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n",
188 | " clamp=False\n",
189 | ")\n",
190 | "\n",
191 | "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n",
192 | " text = text.strip()\n",
193 | " ps = global_phonemizer.phonemize([text])\n",
194 | " ps = word_tokenize(ps[0])\n",
195 | " ps = ' '.join(ps)\n",
196 | " tokens = textclenaer(ps)\n",
197 | " tokens.insert(0, 0)\n",
198 | " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
199 | "\n",
200 | " with torch.no_grad():\n",
201 | " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
202 | " text_mask = length_to_mask(input_lengths).to(device)\n",
203 | "\n",
204 | " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
205 | " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
206 | " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
207 | "\n",
208 | " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n",
209 | " embedding=bert_dur,\n",
210 | " embedding_scale=embedding_scale,\n",
211 | " features=ref_s, # reference from the same speaker as the embedding\n",
212 | " num_steps=diffusion_steps).squeeze(1)\n",
213 | "\n",
214 | "\n",
215 | " s = s_pred[:, 128:]\n",
216 | " ref = s_pred[:, :128]\n",
217 | "\n",
218 | " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
219 | " s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
220 | "\n",
221 | " d = model.predictor.text_encoder(d_en,\n",
222 | " s, input_lengths, text_mask)\n",
223 | "\n",
224 | " x, _ = model.predictor.lstm(d)\n",
225 | " duration = model.predictor.duration_proj(x)\n",
226 | "\n",
227 | " duration = torch.sigmoid(duration).sum(axis=-1)\n",
228 | " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
229 | "\n",
230 | " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
231 | " c_frame = 0\n",
232 | " for i in range(pred_aln_trg.size(0)):\n",
233 | " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
234 | " c_frame += int(pred_dur[i].data)\n",
235 | "\n",
236 | " # encode prosody\n",
237 | " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
238 | " if model_params.decoder.type == \"hifigan\":\n",
239 | " asr_new = torch.zeros_like(en)\n",
240 | " asr_new[:, :, 0] = en[:, :, 0]\n",
241 | " asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
242 | " en = asr_new\n",
243 | "\n",
244 | " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
245 | "\n",
246 | " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
247 | " if model_params.decoder.type == \"hifigan\":\n",
248 | " asr_new = torch.zeros_like(asr)\n",
249 | " asr_new[:, :, 0] = asr[:, :, 0]\n",
250 | " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
251 | " asr = asr_new\n",
252 | "\n",
253 | " out = model.decoder(asr,\n",
254 | " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
255 | "\n",
256 | "\n",
257 | " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.'''\n",
267 | "# path = \"Data/wavs/LJ001-0110.wav\"\n",
268 | "path = \"Data/wavs/ash.wav\"\n",
269 | "ref_s = compute_style(path)\n",
270 | "\n",
271 | "start = time.time()\n",
272 | "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
273 | "wav = inference(text, ref_s)\n",
274 | "rtf = (time.time() - start) / (len(wav) / 24000)\n",
275 | "print(f\"RTF = {rtf:5f}\")\n",
276 | "import IPython.display as ipd\n",
277 | "display(ipd.Audio(wav, rate=24000, normalize=False))\n",
278 | "sorted_files[-1]"
279 | ]
280 | }
281 | ],
282 | "metadata": {
283 | "accelerator": "GPU",
284 | "colab": {
285 | "gpuType": "T4",
286 | "provenance": []
287 | },
288 | "kernelspec": {
289 | "display_name": "Python 3",
290 | "name": "python3"
291 | },
292 | "language_info": {
293 | "name": "python"
294 | }
295 | },
296 | "nbformat": 4,
297 | "nbformat_minor": 0
298 | }
299 |
--------------------------------------------------------------------------------