├── README.md
├── styletts2_colab.ipynb
├── styletts2_train_colab.ipynb
├── styletts2_test.ipynb
├── bucilianus1.ipynb
└── styletts2_test_colab.ipynb

/README.md:
--------------------------------------------------------------------------------
🐣 Please follow me for new updates https://twitter.com/camenduru
🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
🥳 Please join my patreon community https://patreon.com/camenduru

# 🚦 WIP 🚦

## 🦒 Colab

| Colab | Info |
| --- | --- |
| [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_colab.ipynb) | styletts2_colab (Thanks to [@realmrfakename](https://twitter.com/realmrfakename) ❤) |

## Tutorial

## Main Repo
https://github.com/yl4579/StyleTTS2

## Paper
https://arxiv.org/abs/2306.07691

## Page
https://styletts2.github.io/

## Output

https://github.com/camenduru/styletts-colab/assets/54370274/f30b3d16-4260-49c7-808e-934451a5a8dc

## Sponsor
https://modelslab.com
--------------------------------------------------------------------------------
/styletts2_colab.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_colab.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VjYy0F2gZIPR"
   },
   "outputs": [],
   "source": [
    "%cd /content\n",
    "!git clone -b dev https://github.com/camenduru/styletts2-hf\n",
    "%cd /content/styletts2-hf\n",
    "\n",
    "!pip install -q gradio cached-path munch einops einops-exts phonemizer tortoise-tts\n",
    "!pip install -q git+https://github.com/resemble-ai/monotonic_align\n",
    "\n",
    "!apt -y install -qq aria2 espeak-ng\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-1.wav?download=true -d /content/styletts2-hf/voices -o f-us-1.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-2.wav?download=true -d /content/styletts2-hf/voices -o f-us-2.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-3.wav?download=true -d /content/styletts2-hf/voices -o f-us-3.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/f-us-4.wav?download=true -d /content/styletts2-hf/voices -o f-us-4.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-1.wav?download=true -d /content/styletts2-hf/voices -o m-us-1.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-2.wav?download=true -d /content/styletts2-hf/voices -o m-us-2.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-3.wav?download=true -d /content/styletts2-hf/voices -o m-us-3.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices/m-us-4.wav?download=true -d /content/styletts2-hf/voices -o m-us-4.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices.pkl?download=true -d /content/styletts2-hf -o voices.pkl\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/PLBERT/step_1000000.t7?download=true -d /content/styletts2-hf/Utils/PLBERT -o step_1000000.t7\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/JDC/bst.t7?download=true -d /content/styletts2-hf/Utils/JDC -o bst.t7\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Utils/ASR/epoch_00080.pth?download=true -d /content/styletts2-hf/Utils/ASR -o epoch_00080.pth\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/reference_audio.zip?download=true -d /content/styletts2-hf -o reference_audio.zip\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/spaces/styletts2/styletts2/resolve/main/Data/OOD_texts.txt?download=true -d /content/styletts2-hf/Data -o OOD_texts.txt\n",
    "\n",
    "!python app.py"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
--------------------------------------------------------------------------------
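A note on the voice downloads in `styletts2_colab.ipynb` above: the eight `voices/*.wav` files follow a single URL pattern, so the same fetch can be driven from a short loop. A sketch, assuming aria2c is installed and the URLs and target directory match the cell above:

```python
# Equivalent to the eight per-file aria2c lines above, driven from a loop.
import subprocess

BASE = "https://huggingface.co/spaces/styletts2/styletts2/resolve/main/voices"
DEST = "/content/styletts2-hf/voices"

for prefix in ("f", "m"):
    for i in range(1, 5):
        name = f"{prefix}-us-{i}.wav"
        subprocess.run(
            ["aria2c", "--console-log-level=error", "-c", "-x", "16", "-s", "16",
             "-k", "1M", f"{BASE}/{name}?download=true", "-d", DEST, "-o", name],
            check=True,
        )
```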
/styletts2_train_colab.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_train_colab.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VjYy0F2gZIPR"
   },
   "outputs": [],
   "source": [
    "%cd /content\n",
    "!apt install -y espeak-ng aria2\n",
    "!git clone https://dagshub.com/StyleTTS/StyleTTS2\n",
    "%cd /content/StyleTTS2\n",
    "\n",
    "!pip install -q SoundFile torchaudio munch torch pydub pyyaml librosa matplotlib accelerate transformers phonemizer einops wandb\n",
    "!pip install -q gruut accelerate mlflow einops-exts tqdm typing-extensions git+https://github.com/resemble-ai/monotonic_align.git\n",
    "!pip install -q nltk -U\n",
    "\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/adedc90b53fb0676f83d4c9f1e01f1e8650ba15d/data/Models/LibriTTS/epochs_2nd_00020.pth -d /content/StyleTTS2/Models/LibriTTS -o epochs_2nd_00020.pth\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/adedc90b53fb0676f83d4c9f1e01f1e8650ba15d/data/Models/LibriTTS/config.yml -d /content/StyleTTS2/Models/LibriTTS -o config.yml\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/StyleTTS2-LibriTTS/raw/20b8b78c7e27b46d8525f3b1bb0d2725a12a56ff/data/Data.zip -d /content/StyleTTS2 -o Data.zip\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/config.json -d /content/StyleTTS2/wavlm-base-plus -o config.json\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/preprocessor_config.json -d /content/StyleTTS2/wavlm-base-plus -o preprocessor_config.json\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/audio/wavlm-base-plus/raw/624830cf7f4bc949d33bf94fa45895037e78c693/data/pytorch_model.bin -d /content/StyleTTS2/wavlm-base-plus -o pytorch_model.bin\n",
    "# !unzip -o Data.zip\n",
    "\n",
    "!rm -rf /content/StyleTTS2/Data\n",
    "!mkdir -p /content/StyleTTS2/Data/wavs\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_0.wav -O /content/StyleTTS2/Data/wavs/0_0.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_1.wav -O /content/StyleTTS2/Data/wavs/0_1.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_2.wav -O /content/StyleTTS2/Data/wavs/0_2.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_3.wav -O /content/StyleTTS2/Data/wavs/0_3.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_4.wav -O /content/StyleTTS2/Data/wavs/0_4.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_5.wav -O /content/StyleTTS2/Data/wavs/0_5.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_6.wav -O /content/StyleTTS2/Data/wavs/0_6.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_7.wav -O /content/StyleTTS2/Data/wavs/0_7.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_8.wav -O /content/StyleTTS2/Data/wavs/0_8.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_9.wav -O /content/StyleTTS2/Data/wavs/0_9.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_10.wav -O /content/StyleTTS2/Data/wavs/0_10.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_11.wav -O /content/StyleTTS2/Data/wavs/0_11.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_12.wav -O /content/StyleTTS2/Data/wavs/0_12.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_13.wav -O /content/StyleTTS2/Data/wavs/0_13.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_14.wav -O /content/StyleTTS2/Data/wavs/0_14.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_15.wav -O /content/StyleTTS2/Data/wavs/0_15.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_16.wav -O /content/StyleTTS2/Data/wavs/0_16.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_17.wav -O /content/StyleTTS2/Data/wavs/0_17.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_18.wav -O /content/StyleTTS2/Data/wavs/0_18.wav\n",
    "!wget https://dagshub.com/StyleTTS/Data/raw/b1eedf16d1f162b3081c0f10c6fba11a8420cb26/data/wavs/0_19.wav -O /content/StyleTTS2/Data/wavs/0_19.wav\n",
    "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/OOD_texts.txt\n",
    "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/train_list.txt\n",
    "!wget https://gist.github.com/camenduru/aea7c0f76cfe3ed79c521c27374f613d/raw/958d6e0035d89896d2f6c35d500d870b66075005/gistfile1.txt -O /content/StyleTTS2/Data/val_list.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd /content/StyleTTS2\n",
    "\n",
    "config_path = \"Configs/config_ft.yml\"\n",
    "import yaml\n",
    "config = yaml.safe_load(open(config_path))\n",
    "config['data_params']['OOD_data'] = \"Data/OOD_texts.txt\"\n",
    "config['data_params']['root_path'] = \"Data/wavs/\"\n",
    "config['data_params']['train_data'] = \"Data/train_list.txt\"\n",
    "config['data_params']['val_data'] = \"Data/val_list.txt\"\n",
    "config['batch_size'] = 2  # not enough RAM\n",
    "config['max_len'] = 100  # not enough RAM\n",
    "config['loss_params']['joint_epoch'] = 150  # skip SLM adversarial training (not enough RAM)\n",
    "config['save_freq'] = 1\n",
    "config['data_params']['logger'] = \"mlflow\"\n",
    "\n",
    "with open(config_path, 'w') as outfile:\n",
    "    yaml.dump(config, outfile, default_flow_style=True)\n",
    "\n",
    "# placeholders below: set your own MLflow tracking credentials\n",
    "import os\n",
    "os.environ['MLFLOW_TRACKING_URI'] = 'MLFLOW_TRACKING_URI'\n",
    "os.environ['MLFLOW_TRACKING_USERNAME'] = 'MLFLOW_TRACKING_USERNAME'\n",
    "os.environ['MLFLOW_TRACKING_PASSWORD'] = 'MLFLOW_TRACKING_PASSWORD'\n",
    "\n",
    "!cat Configs/config_ft.yml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd /content/StyleTTS2\n",
    "# !accelerate launch --mixed_precision=fp16 --multi_gpu train_finetune_accelerate.py --config_path ./Configs/config_ft.yml\n",
    "!python train_finetune.py --config_path ./Configs/config_ft.yml"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
--------------------------------------------------------------------------------
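The second cell of `styletts2_train_colab.ipynb` above rewrites `Configs/config_ft.yml` in place before the long fine-tuning run starts. A quick sanity check that the patched file round-trips and that the data paths it names actually exist can save a wasted run. A sketch, assuming it is run from `/content/StyleTTS2` after the cells above:

```python
# Re-load the config written by the cell above and verify its data paths.
import os
import yaml

config = yaml.safe_load(open("Configs/config_ft.yml"))
for key in ("train_data", "val_data", "OOD_data", "root_path"):
    path = config["data_params"][key]
    assert os.path.exists(path), f"missing: {path}"
print("batch_size:", config["batch_size"], "| max_len:", config["max_len"])
```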
23 | "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n", 24 | "!apt install espeak-ng aria2\n", 25 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n", 26 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n", 27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n", 28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n", 29 | "\n", 30 | "%cd StyleTTS2\n", 31 | "\n", 32 | "import nltk\n", 33 | "nltk.download('punkt')\n", 34 | "\n", 35 | "import torch\n", 36 | "torch.manual_seed(0)\n", 37 | "torch.backends.cudnn.benchmark = False\n", 38 | "torch.backends.cudnn.deterministic = True\n", 39 | "\n", 40 | "import random\n", 41 | "random.seed(0)\n", 42 | "\n", 43 | "import numpy as np\n", 44 | "np.random.seed(0)\n", 45 | "\n", 46 | "# load packages\n", 47 | "import time\n", 48 | "import random\n", 49 | "import yaml\n", 50 | "from munch import Munch\n", 51 | "import numpy as np\n", 52 | "import torch\n", 53 | "from torch import nn\n", 54 | "import torch.nn.functional as F\n", 55 | "import torchaudio\n", 56 | "import librosa\n", 57 | "from nltk.tokenize import word_tokenize\n", 58 | "\n", 59 | "from models import *\n", 60 | "from utils import *\n", 61 | "from text_utils import TextCleaner\n", 62 | "textclenaer = TextCleaner()\n", 63 | "\n", 64 | "%matplotlib inline\n", 65 | "\n", 66 | "to_mel = torchaudio.transforms.MelSpectrogram(\n", 67 | " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n", 68 | "mean, std = -4, 4\n", 69 | "\n", 70 | "def length_to_mask(lengths):\n", 71 | " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n", 72 | " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n", 73 | " return mask\n", 74 | "\n", 75 | "def preprocess(wave):\n", 76 | " wave_tensor = torch.from_numpy(wave).float()\n", 77 | " mel_tensor = to_mel(wave_tensor)\n", 78 | " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n", 79 | " return mel_tensor\n", 80 | "\n", 81 | "def compute_style(path):\n", 82 | " wave, sr = librosa.load(path, sr=24000)\n", 83 | " audio, index = librosa.effects.trim(wave, top_db=30)\n", 84 | " if sr != 24000:\n", 85 | " audio = librosa.resample(audio, sr, 24000)\n", 86 | " mel_tensor = preprocess(audio).to(device)\n", 87 | "\n", 88 | " with torch.no_grad():\n", 89 | " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n", 90 | " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n", 91 | "\n", 92 | " return torch.cat([ref_s, ref_p], dim=1)\n", 93 | "\n", 94 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 95 | "\n", 96 | "# load phonemizer\n", 97 | "import phonemizer\n", 98 | "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n", 99 | "\n", 100 | "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n", 101 | "\n", 102 | "# load pretrained ASR model\n", 103 | "ASR_config = config.get('ASR_config', False)\n", 104 | "ASR_path = config.get('ASR_path', 
False)\n", 105 | "text_aligner = load_ASR_models(ASR_path, ASR_config)\n", 106 | "\n", 107 | "# load pretrained F0 model\n", 108 | "F0_path = config.get('F0_path', False)\n", 109 | "pitch_extractor = load_F0_models(F0_path)\n", 110 | "\n", 111 | "# load BERT model\n", 112 | "from Utils.PLBERT.util import load_plbert\n", 113 | "BERT_path = config.get('PLBERT_dir', False)\n", 114 | "plbert = load_plbert(BERT_path)\n", 115 | "\n", 116 | "model_params = recursive_munch(config['model_params'])\n", 117 | "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n", 118 | "_ = [model[key].eval() for key in model]\n", 119 | "_ = [model[key].to(device) for key in model]\n", 120 | "\n", 121 | "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n", 122 | "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n", 123 | "\n", 124 | "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n", 125 | "params = params_whole['net']\n", 126 | "\n", 127 | "for key in model:\n", 128 | " if key in params:\n", 129 | " print('%s loaded' % key)\n", 130 | " try:\n", 131 | " model[key].load_state_dict(params[key])\n", 132 | " except:\n", 133 | " from collections import OrderedDict\n", 134 | " state_dict = params[key]\n", 135 | " new_state_dict = OrderedDict()\n", 136 | " for k, v in state_dict.items():\n", 137 | " name = k[7:] # remove `module.`\n", 138 | " new_state_dict[name] = v\n", 139 | " # load params\n", 140 | " model[key].load_state_dict(new_state_dict, strict=False)\n", 141 | "# except:\n", 142 | "# _load(params[key], model[key])\n", 143 | "_ = [model[key].eval() for key in model]\n", 144 | "\n", 145 | "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n", 146 | "\n", 147 | "sampler = DiffusionSampler(\n", 148 | " model.diffusion.diffusion,\n", 149 | " sampler=ADPM2Sampler(),\n", 150 | " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n", 151 | " clamp=False\n", 152 | ")\n", 153 | "\n", 154 | "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n", 155 | " text = text.strip()\n", 156 | " ps = global_phonemizer.phonemize([text])\n", 157 | " ps = word_tokenize(ps[0])\n", 158 | " ps = ' '.join(ps)\n", 159 | " tokens = textclenaer(ps)\n", 160 | " tokens.insert(0, 0)\n", 161 | " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", 162 | "\n", 163 | " with torch.no_grad():\n", 164 | " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n", 165 | " text_mask = length_to_mask(input_lengths).to(device)\n", 166 | "\n", 167 | " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", 168 | " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", 169 | " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n", 170 | "\n", 171 | " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n", 172 | " embedding=bert_dur,\n", 173 | " embedding_scale=embedding_scale,\n", 174 | " features=ref_s, # reference from the same speaker as the embedding\n", 175 | " num_steps=diffusion_steps).squeeze(1)\n", 176 | "\n", 177 | "\n", 178 | " s = s_pred[:, 128:]\n", 179 | " ref = s_pred[:, :128]\n", 180 | "\n", 181 | " ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n", 182 | " s = beta * s + (1 - beta) * ref_s[:, 128:]\n", 183 | "\n", 184 | " d = model.predictor.text_encoder(d_en,\n", 185 | " s, input_lengths, text_mask)\n", 186 | "\n", 187 | " x, _ = 
model.predictor.lstm(d)\n", 188 | " duration = model.predictor.duration_proj(x)\n", 189 | "\n", 190 | " duration = torch.sigmoid(duration).sum(axis=-1)\n", 191 | " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", 192 | "\n", 193 | " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", 194 | " c_frame = 0\n", 195 | " for i in range(pred_aln_trg.size(0)):\n", 196 | " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", 197 | " c_frame += int(pred_dur[i].data)\n", 198 | "\n", 199 | " # encode prosody\n", 200 | " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", 201 | " if model_params.decoder.type == \"hifigan\":\n", 202 | " asr_new = torch.zeros_like(en)\n", 203 | " asr_new[:, :, 0] = en[:, :, 0]\n", 204 | " asr_new[:, :, 1:] = en[:, :, 0:-1]\n", 205 | " en = asr_new\n", 206 | "\n", 207 | " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", 208 | "\n", 209 | " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n", 210 | " if model_params.decoder.type == \"hifigan\":\n", 211 | " asr_new = torch.zeros_like(asr)\n", 212 | " asr_new[:, :, 0] = asr[:, :, 0]\n", 213 | " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n", 214 | " asr = asr_new\n", 215 | "\n", 216 | " out = model.decoder(asr,\n", 217 | " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", 218 | "\n", 219 | "\n", 220 | " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.'''\n", 230 | "# path = \"Data/wavs/LJ001-0110.wav\"\n", 231 | "path = \"Data/wavs/ash.wav\"\n", 232 | "ref_s = compute_style(path)\n", 233 | "\n", 234 | "start = time.time()\n", 235 | "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n", 236 | "wav = inference(text, ref_s)\n", 237 | "rtf = (time.time() - start) / (len(wav) / 24000)\n", 238 | "print(f\"RTF = {rtf:5f}\")\n", 239 | "import IPython.display as ipd\n", 240 | "display(ipd.Audio(wav, rate=24000, normalize=False))\n", 241 | "sorted_files[-1]" 242 | ] 243 | } 244 | ], 245 | "metadata": { 246 | "accelerator": "GPU", 247 | "colab": { 248 | "gpuType": "T4", 249 | "provenance": [] 250 | }, 251 | "kernelspec": { 252 | "display_name": "Python 3", 253 | "name": "python3" 254 | }, 255 | "language_info": { 256 | "name": "python" 257 | } 258 | }, 259 | "nbformat": 4, 260 | "nbformat_minor": 0 261 | } 262 | -------------------------------------------------------------------------------- /bucilianus1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/bucilianus1.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "VjYy0F2gZIPR" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%cd /content\n", 21 | "!git clone https://github.com/yl4579/StyleTTS2.git\n", 22 | "%cd StyleTTS2\n", 23 | "!pip install -q munch pydub accelerate phonemizer einops einops-exts 
git+https://github.com/resemble-ai/monotonic_align.git\n", 24 | "!apt install espeak-ng aria2\n", 25 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/bucilianus-1/resolve/main/epoch_2nd_00030.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00030.pth\n", 26 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n", 27 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n", 28 | "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n", 29 | "!wget https://replicate.delivery/pbxt/KB9Q5ER6n39loKuLKdkmPlOr8uR0eCNHY9dFYZQvL3MIfsBH/demo_speaker2.mp3 -O /content/demo_speaker2.mp3\n", 30 | "!wget https://replicate.delivery/pbxt/KB9YGZ4v1woTUDd8rSRlizdqIQjxqGjDnosgvlQkTZo7IRI7/demo_speaker0.mp3 -O /content/demo_speaker0.mp3\n", 31 | "\n", 32 | "%cd StyleTTS2\n", 33 | "\n", 34 | "import nltk\n", 35 | "nltk.download('punkt')\n", 36 | "\n", 37 | "import torch\n", 38 | "torch.manual_seed(0)\n", 39 | "torch.backends.cudnn.benchmark = False\n", 40 | "torch.backends.cudnn.deterministic = True\n", 41 | "\n", 42 | "import random\n", 43 | "random.seed(0)\n", 44 | "\n", 45 | "import numpy as np\n", 46 | "np.random.seed(0)\n", 47 | "\n", 48 | "# load packages\n", 49 | "import time\n", 50 | "import random\n", 51 | "import yaml\n", 52 | "from munch import Munch\n", 53 | "import numpy as np\n", 54 | "import torch\n", 55 | "from torch import nn\n", 56 | "import torch.nn.functional as F\n", 57 | "import torchaudio\n", 58 | "import librosa\n", 59 | "from nltk.tokenize import word_tokenize\n", 60 | "\n", 61 | "from models import *\n", 62 | "from utils import *\n", 63 | "from text_utils import TextCleaner\n", 64 | "textclenaer = TextCleaner()\n", 65 | "\n", 66 | "%matplotlib inline\n", 67 | "\n", 68 | "to_mel = torchaudio.transforms.MelSpectrogram(\n", 69 | " n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n", 70 | "mean, std = -4, 4\n", 71 | "\n", 72 | "def length_to_mask(lengths):\n", 73 | " mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n", 74 | " mask = torch.gt(mask+1, lengths.unsqueeze(1))\n", 75 | " return mask\n", 76 | "\n", 77 | "def preprocess(wave):\n", 78 | " wave_tensor = torch.from_numpy(wave).float()\n", 79 | " mel_tensor = to_mel(wave_tensor)\n", 80 | " mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n", 81 | " return mel_tensor\n", 82 | "\n", 83 | "def compute_style(path):\n", 84 | " wave, sr = librosa.load(path, sr=24000)\n", 85 | " audio, index = librosa.effects.trim(wave, top_db=30)\n", 86 | " if sr != 24000:\n", 87 | " audio = librosa.resample(audio, sr, 24000)\n", 88 | " mel_tensor = preprocess(audio).to(device)\n", 89 | "\n", 90 | " with torch.no_grad():\n", 91 | " ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n", 92 | " ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n", 93 | "\n", 94 | " return torch.cat([ref_s, ref_p], dim=1)\n", 95 | "\n", 96 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 97 | "\n", 98 | "# load phonemizer\n", 99 | "import phonemizer\n", 100 | "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n", 101 | "\n", 102 | 
"config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n", 103 | "\n", 104 | "# load pretrained ASR model\n", 105 | "ASR_config = config.get('ASR_config', False)\n", 106 | "ASR_path = config.get('ASR_path', False)\n", 107 | "text_aligner = load_ASR_models(ASR_path, ASR_config)\n", 108 | "\n", 109 | "# load pretrained F0 model\n", 110 | "F0_path = config.get('F0_path', False)\n", 111 | "pitch_extractor = load_F0_models(F0_path)\n", 112 | "\n", 113 | "# load BERT model\n", 114 | "from Utils.PLBERT.util import load_plbert\n", 115 | "BERT_path = config.get('PLBERT_dir', False)\n", 116 | "plbert = load_plbert(BERT_path)\n", 117 | "\n", 118 | "model_params = recursive_munch(config['model_params'])\n", 119 | "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n", 120 | "_ = [model[key].eval() for key in model]\n", 121 | "_ = [model[key].to(device) for key in model]\n", 122 | "\n", 123 | "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n", 124 | "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n", 125 | "\n", 126 | "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n", 127 | "params = params_whole['net']\n", 128 | "\n", 129 | "for key in model:\n", 130 | " if key in params:\n", 131 | " print('%s loaded' % key)\n", 132 | " try:\n", 133 | " model[key].load_state_dict(params[key])\n", 134 | " except:\n", 135 | " from collections import OrderedDict\n", 136 | " state_dict = params[key]\n", 137 | " new_state_dict = OrderedDict()\n", 138 | " for k, v in state_dict.items():\n", 139 | " name = k[7:] # remove `module.`\n", 140 | " new_state_dict[name] = v\n", 141 | " # load params\n", 142 | " model[key].load_state_dict(new_state_dict, strict=False)\n", 143 | "# except:\n", 144 | "# _load(params[key], model[key])\n", 145 | "_ = [model[key].eval() for key in model]\n", 146 | "\n", 147 | "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n", 148 | "\n", 149 | "sampler = DiffusionSampler(\n", 150 | " model.diffusion.diffusion,\n", 151 | " sampler=ADPM2Sampler(),\n", 152 | " sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters\n", 153 | " clamp=False\n", 154 | ")\n", 155 | "\n", 156 | "def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):\n", 157 | " text = text.strip()\n", 158 | " ps = global_phonemizer.phonemize([text])\n", 159 | " ps = word_tokenize(ps[0])\n", 160 | " ps = ' '.join(ps)\n", 161 | " tokens = textclenaer(ps)\n", 162 | " tokens.insert(0, 0)\n", 163 | " tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n", 164 | "\n", 165 | " with torch.no_grad():\n", 166 | " input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n", 167 | " text_mask = length_to_mask(input_lengths).to(device)\n", 168 | "\n", 169 | " t_en = model.text_encoder(tokens, input_lengths, text_mask)\n", 170 | " bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n", 171 | " d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n", 172 | "\n", 173 | " s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),\n", 174 | " embedding=bert_dur,\n", 175 | " embedding_scale=embedding_scale,\n", 176 | " features=ref_s, # reference from the same speaker as the embedding\n", 177 | " num_steps=diffusion_steps).squeeze(1)\n", 178 | "\n", 179 | "\n", 180 | " s = s_pred[:, 128:]\n", 181 | " ref = s_pred[:, :128]\n", 182 | "\n", 183 | " ref = alpha * ref + (1 - alpha) * 
ref_s[:, :128]\n", 184 | " s = beta * s + (1 - beta) * ref_s[:, 128:]\n", 185 | "\n", 186 | " d = model.predictor.text_encoder(d_en,\n", 187 | " s, input_lengths, text_mask)\n", 188 | "\n", 189 | " x, _ = model.predictor.lstm(d)\n", 190 | " duration = model.predictor.duration_proj(x)\n", 191 | "\n", 192 | " duration = torch.sigmoid(duration).sum(axis=-1)\n", 193 | " pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n", 194 | "\n", 195 | " pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n", 196 | " c_frame = 0\n", 197 | " for i in range(pred_aln_trg.size(0)):\n", 198 | " pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n", 199 | " c_frame += int(pred_dur[i].data)\n", 200 | "\n", 201 | " # encode prosody\n", 202 | " en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n", 203 | " if model_params.decoder.type == \"hifigan\":\n", 204 | " asr_new = torch.zeros_like(en)\n", 205 | " asr_new[:, :, 0] = en[:, :, 0]\n", 206 | " asr_new[:, :, 1:] = en[:, :, 0:-1]\n", 207 | " en = asr_new\n", 208 | "\n", 209 | " F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n", 210 | "\n", 211 | " asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n", 212 | " if model_params.decoder.type == \"hifigan\":\n", 213 | " asr_new = torch.zeros_like(asr)\n", 214 | " asr_new[:, :, 0] = asr[:, :, 0]\n", 215 | " asr_new[:, :, 1:] = asr[:, :, 0:-1]\n", 216 | " asr = asr_new\n", 217 | "\n", 218 | " out = model.decoder(asr,\n", 219 | " F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n", 220 | "\n", 221 | "\n", 222 | " return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "text = '''Supporters of Caesar argue that he was a successful military leader who brought stability and reforms to Rome. 
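The `inference()` defined in `styletts2_test.ipynb` above synthesizes one passage per call. For longer input, one common pattern is to split on sentence boundaries and concatenate the 24 kHz waveforms. A sketch (`inference_long` is a hypothetical helper name; `inference`, `compute_style`, and the nltk punkt download all come from the notebook above):

```python
import numpy as np
from nltk.tokenize import sent_tokenize

def inference_long(text, ref_s, **kwargs):
    # Synthesize sentence by sentence and join the float waveforms.
    pieces = [inference(sent, ref_s, **kwargs) for sent in sent_tokenize(text)]
    return np.concatenate(pieces)
```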
/bucilianus1.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/bucilianus1.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VjYy0F2gZIPR"
   },
   "outputs": [],
   "source": [
    "%cd /content\n",
    "!git clone https://github.com/yl4579/StyleTTS2.git\n",
    "%cd StyleTTS2\n",
    "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n",
    "!apt -y install espeak-ng aria2\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/bucilianus-1/resolve/main/epoch_2nd_00030.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00030.pth\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n",
    "!wget https://replicate.delivery/pbxt/KB9Q5ER6n39loKuLKdkmPlOr8uR0eCNHY9dFYZQvL3MIfsBH/demo_speaker2.mp3 -O /content/demo_speaker2.mp3\n",
    "!wget https://replicate.delivery/pbxt/KB9YGZ4v1woTUDd8rSRlizdqIQjxqGjDnosgvlQkTZo7IRI7/demo_speaker0.mp3 -O /content/demo_speaker0.mp3\n",
    "\n",
    "%cd /content/StyleTTS2\n",
    "\n",
    "import nltk\n",
    "nltk.download('punkt')\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "import random\n",
    "random.seed(0)\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "# load packages\n",
    "import os\n",
    "import time\n",
    "import random\n",
    "import yaml\n",
    "from munch import Munch\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "import librosa\n",
    "from nltk.tokenize import word_tokenize\n",
    "\n",
    "from models import *\n",
    "from utils import *\n",
    "from text_utils import TextCleaner\n",
    "textcleaner = TextCleaner()\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "to_mel = torchaudio.transforms.MelSpectrogram(\n",
    "    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
    "mean, std = -4, 4\n",
    "\n",
    "def length_to_mask(lengths):\n",
    "    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
    "    mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
    "    return mask\n",
    "\n",
    "def preprocess(wave):\n",
    "    wave_tensor = torch.from_numpy(wave).float()\n",
    "    mel_tensor = to_mel(wave_tensor)\n",
    "    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
    "    return mel_tensor\n",
    "\n",
    "def compute_style(path):\n",
    "    wave, sr = librosa.load(path, sr=24000)\n",
    "    audio, index = librosa.effects.trim(wave, top_db=30)\n",
    "    if sr != 24000:\n",
    "        audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)\n",
    "    mel_tensor = preprocess(audio).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
    "        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
    "\n",
    "    return torch.cat([ref_s, ref_p], dim=1)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# load phonemizer\n",
    "import phonemizer\n",
    "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
    "\n",
    "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
    "\n",
    "# load pretrained ASR model\n",
    "ASR_config = config.get('ASR_config', False)\n",
    "ASR_path = config.get('ASR_path', False)\n",
    "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
    "\n",
    "# load pretrained F0 model\n",
    "F0_path = config.get('F0_path', False)\n",
    "pitch_extractor = load_F0_models(F0_path)\n",
    "\n",
    "# load BERT model\n",
    "from Utils.PLBERT.util import load_plbert\n",
    "BERT_path = config.get('PLBERT_dir', False)\n",
    "plbert = load_plbert(BERT_path)\n",
    "\n",
    "model_params = recursive_munch(config['model_params'])\n",
    "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
    "_ = [model[key].eval() for key in model]\n",
    "_ = [model[key].to(device) for key in model]\n",
    "\n",
    "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
    "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n",
    "\n",
    "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
    "params = params_whole['net']\n",
    "\n",
    "for key in model:\n",
    "    if key in params:\n",
    "        print('%s loaded' % key)\n",
    "        try:\n",
    "            model[key].load_state_dict(params[key])\n",
    "        except:\n",
    "            from collections import OrderedDict\n",
    "            state_dict = params[key]\n",
    "            new_state_dict = OrderedDict()\n",
    "            for k, v in state_dict.items():\n",
    "                name = k[7:]  # remove `module.`\n",
    "                new_state_dict[name] = v\n",
    "            # load params\n",
    "            model[key].load_state_dict(new_state_dict, strict=False)\n",
    "# except:\n",
    "#     _load(params[key], model[key])\n",
    "_ = [model[key].eval() for key in model]\n",
    "\n",
    "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
    "\n",
    "sampler = DiffusionSampler(\n",
    "    model.diffusion.diffusion,\n",
    "    sampler=ADPM2Sampler(),\n",
    "    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters\n",
    "    clamp=False\n",
    ")\n",
    "\n",
    "def inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1):\n",
    "    text = text.strip()\n",
    "    ps = global_phonemizer.phonemize([text])\n",
    "    ps = word_tokenize(ps[0])\n",
    "    ps = ' '.join(ps)\n",
    "    tokens = textcleaner(ps)\n",
    "    tokens.insert(0, 0)\n",
    "    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
    "        text_mask = length_to_mask(input_lengths).to(device)\n",
    "\n",
    "        t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
    "        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
    "        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
    "\n",
    "        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),\n",
    "                         embedding=bert_dur,\n",
    "                         embedding_scale=embedding_scale,\n",
    "                         features=ref_s,  # reference from the same speaker as the embedding\n",
    "                         num_steps=diffusion_steps).squeeze(1)\n",
    "\n",
    "        s = s_pred[:, 128:]\n",
    "        ref = s_pred[:, :128]\n",
    "\n",
    "        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
    "        s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
    "\n",
    "        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
    "\n",
    "        x, _ = model.predictor.lstm(d)\n",
    "        duration = model.predictor.duration_proj(x)\n",
    "\n",
    "        duration = torch.sigmoid(duration).sum(axis=-1)\n",
    "        pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
    "\n",
    "        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
    "        c_frame = 0\n",
    "        for i in range(pred_aln_trg.size(0)):\n",
    "            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
    "            c_frame += int(pred_dur[i].data)\n",
    "\n",
    "        # encode prosody\n",
    "        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
    "        if model_params.decoder.type == \"hifigan\":\n",
    "            asr_new = torch.zeros_like(en)\n",
    "            asr_new[:, :, 0] = en[:, :, 0]\n",
    "            asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
    "            en = asr_new\n",
    "\n",
    "        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
    "\n",
    "        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
    "        if model_params.decoder.type == \"hifigan\":\n",
    "            asr_new = torch.zeros_like(asr)\n",
    "            asr_new[:, :, 0] = asr[:, :, 0]\n",
    "            asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
    "            asr = asr_new\n",
    "\n",
    "        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
    "\n",
    "    return out.squeeze().cpu().numpy()[..., :-50]  # trim the weird pulse at the end of the generated audio; to be fixed later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = '''Supporters of Caesar argue that he was a successful military leader who brought stability and reforms to Rome. They may view his assassination as a betrayal by jealous and power-hungry senators who feared his increasing authority.'''\n",
    "# path = \"Data/wavs/LJ001-0110.wav\"\n",
    "path = \"/content/demo_speaker2.mp3\"\n",
    "# path = \"/content/demo_speaker0.mp3\"\n",
    "# path = \"Data/wavs/ash.wav\"\n",
    "ref_s = compute_style(path)\n",
    "\n",
    "start = time.time()\n",
    "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
    "wav = inference(text, ref_s)\n",
    "rtf = (time.time() - start) / (len(wav) / 24000)\n",
    "print(f\"RTF = {rtf:.5f}\")\n",
    "import IPython.display as ipd\n",
    "display(ipd.Audio(wav, rate=24000, normalize=False))\n",
    "sorted_files[-1]"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
--------------------------------------------------------------------------------
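In `inference()` above, `alpha` blends the diffusion-sampled acoustic style with the reference clip's style half (`ref_s[:, :128]`), and `beta` does the same for the prosody half (`ref_s[:, 128:]`); higher values lean toward the sampled style, lower values stay closer to the reference. A small sweep makes the trade-off audible; a sketch, assuming the setup cell of `bucilianus1.ipynb` has run and `text` and `ref_s` are defined as in its last cell:

```python
import IPython.display as ipd

# Audition how far the output drifts from the reference clip.
for alpha in (0.1, 0.5, 0.9):
    for beta in (0.1, 0.5, 0.9):
        wav = inference(text, ref_s, alpha=alpha, beta=beta)
        print(f"alpha={alpha}, beta={beta}")
        display(ipd.Audio(wav, rate=24000, normalize=False))
```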
/styletts2_test_colab.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github"
   },
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/styletts-colab/blob/main/styletts2_test_colab.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VjYy0F2gZIPR"
   },
   "outputs": [],
   "source": [
    "%cd /content\n",
    "!git clone https://github.com/yl4579/StyleTTS2.git\n",
    "%cd StyleTTS2\n",
    "!pip install -q munch pydub accelerate phonemizer einops einops-exts git+https://github.com/resemble-ai/monotonic_align.git\n",
    "!apt -y install espeak-ng aria2\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00000.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00000.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00001.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00001.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00002.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00002.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00004.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00004.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00005.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00005.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00006.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00006.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00007.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00007.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00008.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00008.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_2.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_2.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_3.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_3.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00009_4.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009_4.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00010.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00010.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00011.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00011.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00013.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00013.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00014.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00014.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00015.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00015.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00016.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00016.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00017.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00017.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00018.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00018.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00019.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00019.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model2/raw/22c1bada3b2eee2bee63dd8a98792cb3ab6aa0c9/model/epoch_2nd_00020.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00020.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://dagshub.com/StyleTTS/Model3/raw/9734ed6cbac07b09747b795811a081ce9896d8b3/model/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00000.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00000.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00001.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00001.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00002.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00002.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00003.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00003.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00004.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00004.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00005.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00005.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00006.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00006.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00007.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00007.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00008.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00008.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00009.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00009.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00010.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00010.pth\n",
    "# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00011.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00011.pth\n",
    ""!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/camenduru/StyleTTS2/resolve/main/epoch_2nd_00012.pth -d /content/StyleTTS2/Models/LJSpeech -o epoch_2nd_00012.pth\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/raw/main/config_ft.yml -d /content/StyleTTS2/Models/LJSpeech -o config_ft.yml\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/LJ001-0110.wav -d /content/StyleTTS2/Data/wavs -o LJ001-0110.wav\n",
    "!aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/audo/all_test/resolve/main/ash.wav -d /content/StyleTTS2/Data/wavs -o ash.wav\n",
    "\n",
    "%cd /content/StyleTTS2\n",
    "\n",
    "import nltk\n",
    "nltk.download('punkt')\n",
    "\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "\n",
    "import random\n",
    "random.seed(0)\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "# load packages\n",
    "import os\n",
    "import time\n",
    "import random\n",
    "import yaml\n",
    "from munch import Munch\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "import librosa\n",
    "from nltk.tokenize import word_tokenize\n",
    "\n",
    "from models import *\n",
    "from utils import *\n",
    "from text_utils import TextCleaner\n",
    "textcleaner = TextCleaner()\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "to_mel = torchaudio.transforms.MelSpectrogram(\n",
    "    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)\n",
    "mean, std = -4, 4\n",
    "\n",
    "def length_to_mask(lengths):\n",
    "    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)\n",
    "    mask = torch.gt(mask+1, lengths.unsqueeze(1))\n",
    "    return mask\n",
    "\n",
    "def preprocess(wave):\n",
    "    wave_tensor = torch.from_numpy(wave).float()\n",
    "    mel_tensor = to_mel(wave_tensor)\n",
    "    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std\n",
    "    return mel_tensor\n",
    "\n",
    "def compute_style(path):\n",
    "    wave, sr = librosa.load(path, sr=24000)\n",
    "    audio, index = librosa.effects.trim(wave, top_db=30)\n",
    "    if sr != 24000:\n",
    "        audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)\n",
    "    mel_tensor = preprocess(audio).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))\n",
    "        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))\n",
    "\n",
    "    return torch.cat([ref_s, ref_p], dim=1)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# load phonemizer\n",
    "import phonemizer\n",
    "global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)\n",
    "\n",
    "config = yaml.safe_load(open(\"Models/LJSpeech/config_ft.yml\"))\n",
    "\n",
    "# load pretrained ASR model\n",
    "ASR_config = config.get('ASR_config', False)\n",
    "ASR_path = config.get('ASR_path', False)\n",
    "text_aligner = load_ASR_models(ASR_path, ASR_config)\n",
    "\n",
    "# load pretrained F0 model\n",
    "F0_path = config.get('F0_path', False)\n",
    "pitch_extractor = load_F0_models(F0_path)\n",
    "\n",
    "# load BERT model\n",
    "from Utils.PLBERT.util import load_plbert\n",
    "BERT_path = config.get('PLBERT_dir', False)\n",
    "plbert = load_plbert(BERT_path)\n",
    "\n",
    "model_params = recursive_munch(config['model_params'])\n",
    "model = build_model(model_params, text_aligner, pitch_extractor, plbert)\n",
    "_ = [model[key].eval() for key in model]\n",
    "_ = [model[key].to(device) for key in model]\n",
    "\n",
    "files = [f for f in os.listdir(\"Models/LJSpeech/\") if f.endswith('.pth')]\n",
    "sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n",
    "\n",
    "params_whole = torch.load(\"Models/LJSpeech/\" + sorted_files[-1], map_location='cpu')\n",
    "params = params_whole['net']\n",
    "\n",
    "for key in model:\n",
    "    if key in params:\n",
    "        print('%s loaded' % key)\n",
    "        try:\n",
    "            model[key].load_state_dict(params[key])\n",
    "        except:\n",
    "            from collections import OrderedDict\n",
    "            state_dict = params[key]\n",
    "            new_state_dict = OrderedDict()\n",
    "            for k, v in state_dict.items():\n",
    "                name = k[7:]  # remove `module.`\n",
    "                new_state_dict[name] = v\n",
    "            # load params\n",
    "            model[key].load_state_dict(new_state_dict, strict=False)\n",
    "# except:\n",
    "#     _load(params[key], model[key])\n",
    "_ = [model[key].eval() for key in model]\n",
    "\n",
    "from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule\n",
    "\n",
    "sampler = DiffusionSampler(\n",
    "    model.diffusion.diffusion,\n",
    "    sampler=ADPM2Sampler(),\n",
    "    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters\n",
    "    clamp=False\n",
    ")\n",
    "\n",
    "def inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1):\n",
    "    text = text.strip()\n",
    "    ps = global_phonemizer.phonemize([text])\n",
    "    ps = word_tokenize(ps[0])\n",
    "    ps = ' '.join(ps)\n",
    "    tokens = textcleaner(ps)\n",
    "    tokens.insert(0, 0)\n",
    "    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)\n",
    "        text_mask = length_to_mask(input_lengths).to(device)\n",
    "\n",
    "        t_en = model.text_encoder(tokens, input_lengths, text_mask)\n",
    "        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())\n",
    "        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)\n",
    "\n",
    "        s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),\n",
    "                         embedding=bert_dur,\n",
    "                         embedding_scale=embedding_scale,\n",
    "                         features=ref_s,  # reference from the same speaker as the embedding\n",
    "                         num_steps=diffusion_steps).squeeze(1)\n",
    "\n",
    "        s = s_pred[:, 128:]\n",
    "        ref = s_pred[:, :128]\n",
    "\n",
    "        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]\n",
    "        s = beta * s + (1 - beta) * ref_s[:, 128:]\n",
    "\n",
    "        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)\n",
    "\n",
    "        x, _ = model.predictor.lstm(d)\n",
    "        duration = model.predictor.duration_proj(x)\n",
    "\n",
    "        duration = torch.sigmoid(duration).sum(axis=-1)\n",
    "        pred_dur = torch.round(duration.squeeze()).clamp(min=1)\n",
    "\n",
    "        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))\n",
    "        c_frame = 0\n",
    "        for i in range(pred_aln_trg.size(0)):\n",
    "            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1\n",
    "            c_frame += int(pred_dur[i].data)\n",
    "\n",
    "        # encode prosody\n",
    "        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))\n",
    "        if model_params.decoder.type == \"hifigan\":\n",
    "            asr_new = torch.zeros_like(en)\n",
    "            asr_new[:, :, 0] = en[:, :, 0]\n",
    "            asr_new[:, :, 1:] = en[:, :, 0:-1]\n",
    "            en = asr_new\n",
    "\n",
    "        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)\n",
    "\n",
    "        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))\n",
    "        if model_params.decoder.type == \"hifigan\":\n",
    "            asr_new = torch.zeros_like(asr)\n",
    "            asr_new[:, :, 0] = asr[:, :, 0]\n",
    "            asr_new[:, :, 1:] = asr[:, :, 0:-1]\n",
    "            asr = asr_new\n",
    "\n",
    "        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))\n",
    "\n",
    "    return out.squeeze().cpu().numpy()[..., :-50]  # trim the weird pulse at the end of the generated audio; to be fixed later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = '''Maltby and Company would issue warrants on them deliverable to the importer, and the goods were then passed to be stored in neighboring warehouses.'''\n",
    "# path = \"Data/wavs/LJ001-0110.wav\"\n",
    "path = \"Data/wavs/ash.wav\"\n",
    "ref_s = compute_style(path)\n",
    "\n",
    "start = time.time()\n",
    "# wav = inference(text, ref_s, alpha=0.9, beta=0.9, diffusion_steps=10, embedding_scale=1)\n",
    "wav = inference(text, ref_s)\n",
    "rtf = (time.time() - start) / (len(wav) / 24000)\n",
    "print(f\"RTF = {rtf:.5f}\")\n",
    "import IPython.display as ipd\n",
    "display(ipd.Audio(wav, rate=24000, normalize=False))\n",
    "sorted_files[-1]"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
--------------------------------------------------------------------------------
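All of the test notebooks above only play the result inline; to keep a take, write the float waveform to disk. A sketch; `soundfile` is not in the test notebooks' pip line, so it may need installing first:

```python
import soundfile as sf  # pip install soundfile if it is missing

# `wav` is the float numpy waveform returned by inference() above;
# every checkpoint in this repo runs at 24 kHz.
sf.write("output.wav", wav, 24000)
```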