├── Controllable_TalkNet.ipynb ├── README.md └── TalkNet_Training_A100.ipynb /Controllable_TalkNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "include_colab_link": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "Gss5Ox_RNiba" 33 | }, 34 | "source": [ 35 | "# Controllable TalkNet\n", 36 | "To run TalkNet, click on Runtime -> Run all. Step 2 intentionally crashes (restarts) the runtime after installing the dependencies, which stops Run all; once the runtime has restarted, run Step 3 to launch the interface. The interface will appear at the bottom of the page when it's ready.\n", 37 | "\n", 38 | "## Instructions\n", 39 | "\n", 40 | "* Once the notebook is running, click on Files (the folder icon on the left edge).\n", 41 | "* Upload audio clips of a singing or speaking voice by dragging and dropping them onto the sidebar.\n", 42 | "* Click on \"Update file list\" in the TalkNet interface. Select an audio file from the dropdown and type what it says into the Transcript box.\n", 43 | "* Select a character and press Generate.
The first line will take a little longer to generate.\n", 44 | "\n", 45 | "## Tips and tricks\n", 46 | "* If you want to use TalkNet as a regular text-to-speech system, without any reference audio, tick the \"Disable reference audio\" checkbox.\n", 47 | "* You can use [ARPABET](http://www.speech.cs.cmu.edu/cgi-bin/cmudict) to override the pronunciation of words, like this: *She made a little bow, then she picked up her {B OW}.*\n", 48 | "* If you run out of memory while generating lines, try working with shorter clips.\n", 49 | "* The singing models are trained on very little data and can have a hard time pronouncing certain words. Try experimenting with ARPABET and punctuation.\n", 50 | "* If the voice is off-key, the problem is usually with the extracted pitch. Press \"Debug pitch\" to listen to it. Reference audio with lots of echo/reverb or background noise, or singers with a very high vocal range, can cause issues.\n", 51 | "* If the singing voice sounds strained, try enabling \"Change input pitch\" and adjusting it up or down a few semitones.
If you're remixing a song, remember to pitch-shift your background track as well.\n", 52 | "\n", 53 | ">Maintained by: `justinjohn0306`\n", 54 | "\n", 55 | ">Special thanks to: `Tapiocapioca#6641` and `effusiveperiscope`" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "metadata": { 61 | "id": "LCqXqFgP2ri0", 62 | "cellView": "form" 63 | }, 64 | "source": [ 65 | "#@markdown **Step 1:** Check which GPU you've been allocated.\n", 66 | "\n", 67 | "!nvidia-smi -L\n", 68 | "!nvidia-smi" 69 | ], 70 | "execution_count": null, 71 | "outputs": [] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "metadata": { 76 | "id": "teF-Ut8Z7Gjp", 77 | "cellView": "form" 78 | }, 79 | "source": [ 80 | "#@markdown **Step 2:** Download dependencies.\n", 81 | "\n", 82 | "#@markdown >### Note: The runtime will crash after this cell completes, which is intended.\n", 83 | "\n", 84 | "import nltk\n", 85 | "nltk.download('averaged_perceptron_tagger_eng')\n", 86 | "\n", 87 | "import os\n", 88 | "\n", 89 | "custom_lists = [\n", 90 | " #\"https://gist.githubusercontent.com/SortAnon/997cda157954a189259c9876fd804e53/raw/example_models.json\",\n", 91 | "]\n", 92 | "\n", 93 | "!apt-get install sox libsndfile1 ffmpeg\n", 94 | "\n", 95 | "# 3.10 pytorch-lightning fix\n", 96 | "!pip install -U torch --index-url https://download.pytorch.org/whl/cu118\n", 97 | "!pip install pytorch-lightning==1.6.5\n", 98 | "!pip install torchmetrics==0.11.4\n", 99 | "# 3.10 fix\n", 100 | "!pip install numpy==1.23.5 scipy==1.10.1 librosa==0.8.1\n", 101 | "!pip install tensorflow dash==1.21.0 dash-bootstrap-components==0.13.0 jupyter-dash==0.4.0 psola wget unidecode pysptk frozendict torchvision torchaudio torchtext torch_stft kaldiio pydub pyannote.audio g2p_en pesq pystoi crepe resampy ffmpeg-python torchcrepe einops taming-transformers-rom1504==0.0.6 tensorflow-hub\n", 102 | "!pip uninstall gdown -y\n", 103 | "!pip install git+https://github.com/wkentaro/gdown.git\n", 104 | "!python -m pip install 
git+https://github.com/effusiveperiscope/NeMo.git\n", 105 | "if not os.path.exists(\"hifi-gan\"):\n", 106 | " !git clone -q --recursive https://github.com/justinjohn0306/talknet-hifi-gan hifi-gan\n", 107 | "!git clone -q https://github.com/effusiveperiscope/ControllableTalkNet\n", 108 | "os.chdir(\"/content/ControllableTalkNet\")\n", 109 | "!git archive --output=./files.tar --format=tar HEAD\n", 110 | "os.chdir(\"/content\")\n", 111 | "!tar xf ControllableTalkNet/files.tar\n", 112 | "!rm -rf ControllableTalkNet\n", 113 | "\n", 114 | "# 3.10 werkzeug fix\n", 115 | "!python -m pip install werkzeug==2.0.0 flask==2.1.2\n", 116 | "# pytorch cuda fix???\n", 117 | "!pip install -U torch torchtext torchvision torchaudio torch_stft torchcrepe --index-url https://download.pytorch.org/whl/cu121\n", 118 | "!pip install librosa==0.8.1\n", 119 | "\n", 120 | "os.chdir(\"/content/model_lists\")\n", 121 | "for c in custom_lists:\n", 122 | " !wget \"{c}\"\n", 123 | "os.chdir(\"/content\")\n", 124 | "\n", 125 | "\n", 126 | "# restart the runtime\n", 127 | "os.kill(os.getpid(), 9)\n" 128 | ], 129 | "execution_count": null, 130 | "outputs": [] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "metadata": { 135 | "id": "tOXejargIPTq", 136 | "cellView": "form" 137 | }, 138 | "source": [ 139 | "# @markdown **Step 3:** **Start GUI**\n", 140 | "using_inline = True\n", 141 | "import pkg_resources\n", 142 | "from pkg_resources import DistributionNotFound, VersionConflict\n", 143 | "\"\"\"dependencies = [\n", 144 | "\"tensorflow==2.4.1\",\n", 145 | "\"dash\",\n", 146 | "\"jupyter-dash\",\n", 147 | "\"psola\",\n", 148 | "\"wget\",\n", 149 | "\"unidecode\",\n", 150 | "\"pysptk\",\n", 151 | "\"frozendict\",\n", 152 | "\"torchvision==0.9.1\",\n", 153 | "\"torchaudio==0.8.1\",\n", 154 | "\"torchtext==0.9.1\",\n", 155 | "\"torch_stft\",\n", 156 | "\"kaldiio\",\n", 157 | "\"pydub\",\n", 158 | "\"pyannote.audio\",\n", 159 | "\"g2p_en\",\n", 160 | "\"pesq\",\n", 161 | "\"pystoi\",\n", 162 | 
"\"crepe\",\n", 163 | "\"resampy\",\n", 164 | "\"ffmpeg-python\",\n", 165 | "\"numpy\",\n", 166 | "\"scipy\",\n", 167 | "\"nemo_toolkit\",\n", 168 | "\"tqdm\",\n", 169 | "\"gdown\",\n", 170 | "]\n", 171 | "pkg_resources.require(dependencies)\"\"\"\n", 172 | "\n", 173 | "from controllable_talknet import *\n", 174 | "app.run_server(\n", 175 | " mode=\"inline\",\n", 176 | " #dev_tools_ui=True,\n", 177 | " #dev_tools_hot_reload=True,\n", 178 | " threaded=True,\n", 179 | ")" 180 | ], 181 | "execution_count": null, 182 | "outputs": [] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "metadata": { 187 | "cellView": "form", 188 | "id": "F4WDzfRIgb5Z" 189 | }, 190 | "source": [ 191 | "# @markdown **Step 3B:** If the above fails with a 403 error, do the following:\n", 192 | "# @markdown * Go to Runtime -> Restart runtime\n", 193 | "# @markdown * Run this cell (click the play button)\n", 194 | "# @markdown * Click on the googleusercontent.com link to use TalkNet in a separate tab\n", 195 | "try:\n", 196 | " using_inline\n", 197 | "except:\n", 198 | " using_inline = False\n", 199 | "if not using_inline:\n", 200 | " from controllable_talknet import *\n", 201 | " from google.colab.output import eval_js\n", 202 | "\n", 203 | " print(eval_js(\"google.colab.kernel.proxyPort(8050)\"))\n", 204 | " app.run_server(\n", 205 | " mode=\"external\",\n", 206 | " debug=False,\n", 207 | " #dev_tools_ui=True,\n", 208 | " #dev_tools_hot_reload=True,\n", 209 | " threaded=True,\n", 210 | " )" 211 | ], 212 | "execution_count": null, 213 | "outputs": [] 214 | } 215 | ] 216 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TalkNET-colab 2 | 3 | **TalkNET Model Training Notebook**: Open In Colab 4 | 5 | **TalkNET Synthesis Notebook (when you have a model)**: Open In Colab 6 | 7 | 8 | Taken from 
[TalkNET](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/models/talknet.py) and made trainable on Colab by [SortAnon](https://github.com/SortAnon/ControllableTalkNet), with fixes by [Justin John](https://github.com/justinjohn0306). 9 | 10 | If you are looking for assistance, feel free to join the [Discord server](https://discord.com/invite/NhJZGtH) and ask in the tech support channel there. 11 | -------------------------------------------------------------------------------- /TalkNet_Training_A100.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "include_colab_link": true 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": { 32 | "id": "Gss5Ox_RNiba" 33 | }, 34 | "source": [ 35 | "# TalkNet Training\n", 36 | "Last updated: 2023-05-17\n", 37 | "\n", 38 | "To train a 22 kHz TalkNet, run the cells below and follow the instructions.\n", 39 | "\n", 40 | "This will take a while, and you might have to do it in multiple Colab sessions. The notebook will automatically resume training any models from the last saved checkpoint. If you're resuming from a new session, always re-run steps 1 through 5 first.\n", 41 | "\n", 42 | "## **IMPORTANT:**\n", 43 | "Your Trash folder on Drive will fill up with old checkpoints\n", 44 | "as you train the various models.
Keep an eye on your Drive storage and empty the trash if it starts to fill up.\n", 45 | "\n", 46 | "- Fixes by ``justinjohn0306`` and ``Tapiocapioca#6641``" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "metadata": { 52 | "id": "LCqXqFgP2ri0", 53 | "cellView": "form" 54 | }, 55 | "source": [ 56 | "#@markdown **Step 1:** Check which GPU you've been allocated.\n", 57 | "\n", 58 | "#@markdown You want a P100, V100, T4, or A100.\n", 59 | "!nvidia-smi -L\n", 60 | "!nvidia-smi" 61 | ], 62 | "execution_count": null, 63 | "outputs": [] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "metadata": { 68 | "id": "j9TI-Q6m3qlx", 69 | "cellView": "form" 70 | }, 71 | "source": [ 72 | "#@markdown **Step 2:** Mount Google Drive.\n", 73 | "from google.colab import drive\n", 74 | "drive.mount(\"/content/drive\")" 75 | ], 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "metadata": { 82 | "id": "nfSawDUD5tqv", 83 | "cellView": "form" 84 | }, 85 | "source": [ 86 | "#@markdown **Step 3:** Configure training data paths.
Upload the following to your Drive and change the paths below:\n", 87 | "#@markdown * A dataset of .wav files, packaged as a .zip or .tar file\n", 88 | "#@markdown * Training and validation filelists, in LJSpeech format with relative paths (note: ARPABET transcripts are not supported)\n", 89 | "#@markdown * An output path for checkpoints\n", 90 | "\n", 91 | "import os\n", 92 | "\n", 93 | "dataset = \"/content/drive/My Drive/path_to_dataset.zip\" #@param {type:\"string\"}\n", 94 | "train_filelist = \"/content/drive/My Drive/train_filelist.txt\" #@param {type:\"string\"}\n", 95 | "val_filelist = \"/content/drive/My Drive/val_filelist.txt\" #@param {type:\"string\"}\n", 96 | "output_dir = \"/content/drive/My Drive/talknet/name_of_character\" #@param {type:\"string\"}\n", 97 | "assert os.path.exists(dataset), \"Cannot find dataset\"\n", 98 | "assert os.path.exists(train_filelist), \"Cannot find training filelist\"\n", 99 | "assert os.path.exists(val_filelist), \"Cannot find validation filelist\"\n", 100 | "if not os.path.exists(output_dir):\n", 101 | " os.makedirs(output_dir)\n", 102 | "print(\"OK\")\n" 103 | ], 104 | "execution_count": null, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "markdown", 109 | "source": [ 110 | " **Note: The runtime will crash after step 4, which is intentional.** **Just re-run step 3 and proceed to step 5 and further.**" 111 | ], 112 | "metadata": { 113 | "id": "d02YMYEnqsGJ" 114 | } 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "teF-Ut8Z7Gjp", 120 | "cellView": "form" 121 | }, 122 | "source": [ 123 | "#@markdown **Step 4:** Download NVIDIA NeMo.\n", 124 | "!pip install pip==24.0\n", 125 | "!pip uninstall torch torchvision torchaudio -y\n", 126 | "!pip uninstall gdown -y\n", 127 | "!pip install git+https://github.com/wkentaro/gdown.git\n", 128 | "import os\n", 129 | "import time\n", 130 | "import gdown\n", 131 | "\n", 132 | "os.chdir('/content')\n", 133 | "!apt-get install sox libsndfile1 ffmpeg\n", 
134 | "!pip install \"cython<3.0.0\" && pip install --no-build-isolation \"pyyaml<6.0\" && pip install --no-build-isolation \"pysptk==0.2.0\"\n", 135 | "!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117\n", 136 | "!pip install wget unidecode tensorflow==2.9 tensorboardX pysptk frozendict torch_stft pytorch-lightning==1.3.8 kaldiio pydub pyannote.audio g2p_en pesq pystoi crepe ffmpeg-python\n", 137 | "!python -m pip install git+https://github.com/SortAnon/NeMo.git\n", 138 | "!git clone -q https://github.com/SortAnon/hifi-gan.git\n", 139 | "!pip install --pre torchtext==0.6.0 --no-deps --quiet\n", 140 | "!pip install -q numpy --upgrade torchmetrics==0.6 omegaconf==2.2.3 hmmlearn==0.2.7 crepe==0.0.12 tensorboard==2.9 librosa==0.9.1 protobuf==3.20.0 torch_stft==0.1.4\n", 141 | "\n", 142 | "\n", 143 | "!mkdir -p conf && cd conf \\\n", 144 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-durs.yaml \\\n", 145 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-pitch.yaml \\\n", 146 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-spect.yaml \\\n", 147 | "&& cd ..\n", 148 | "\n", 149 | "# Download pre-trained models\n", 150 | "zip_path = \"tts_en_talknet_1.0.0rc1.zip\"\n", 151 | "for i in range(10):\n", 152 | " if not os.path.exists(zip_path) or os.stat(zip_path).st_size < 100:\n", 153 | " gdown.download(\n", 154 | " \"https://drive.google.com/uc?id=19wSym9mNEnmzLS9XdPlfNAW9_u-mP1hR\",\n", 155 | " zip_path,\n", 156 | " quiet=False,\n", 157 | " )\n", 158 | "!unzip -qo {zip_path}\n", 159 | "\n", 160 | "# restart the runtime\n", 161 | "os.kill(os.getpid(), 9)" 162 | ], 163 | "execution_count": null, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "bxFr3Fdi_kOC", 170 | "cellView": "form" 171 | }, 172 | "source": [ 173 | 
"#@markdown **Step 5:** Dataset processing, part 1.\n", 174 | "\n", 175 | "#@markdown If this step fails, try the following:\n", 176 | "#@markdown * Make sure your filelists are correct. They should have relative\n", 177 | "#@markdown paths that match the contents of the archive.\n", 178 | "\n", 179 | "import os\n", 180 | "import shutil\n", 181 | "import sys\n", 182 | "import json\n", 183 | "import nemo\n", 184 | "import torch\n", 185 | "import torchaudio\n", 186 | "import numpy as np\n", 187 | "from pysptk import sptk\n", 188 | "from pathlib import Path\n", 189 | "from tqdm.notebook import tqdm\n", 190 | "import ffmpeg\n", 191 | "\n", 192 | "def fix_transcripts(inpath):\n", 193 | " found_arpabet = False\n", 194 | " found_grapheme = False\n", 195 | " with open(inpath, \"r\", encoding=\"utf8\") as f:\n", 196 | " lines = f.readlines()\n", 197 | " with open(inpath, \"w\", encoding=\"utf8\") as f:\n", 198 | " for l in lines:\n", 199 | " if l.strip() == \"\":\n", 200 | " continue\n", 201 | " if \"{\" in l:\n", 202 | " if not found_arpabet:\n", 203 | " print(\"Warning: Skipping ARPABET lines (not supported).\")\n", 204 | " found_arpabet = True\n", 205 | " else:\n", 206 | " f.write(l)\n", 207 | " found_grapheme = True\n", 208 | " assert found_grapheme, \"No non-ARPABET lines found in \" + inpath\n", 209 | "\n", 210 | "def generate_json(inpath, outpath):\n", 211 | " output = \"\"\n", 212 | " sample_rate = 22050\n", 213 | " with open(inpath, \"r\", encoding=\"utf8\") as f:\n", 214 | " for l in f.readlines():\n", 215 | " lpath = l.split(\"|\")[0].strip()\n", 216 | " if lpath[:5] != \"wavs/\":\n", 217 | " lpath = \"wavs/\" + lpath\n", 218 | " size = os.stat(\n", 219 | " os.path.join(os.path.dirname(inpath), lpath)\n", 220 | " ).st_size\n", 221 | " x = {\n", 222 | " \"audio_filepath\": lpath,\n", 223 | " \"duration\": size / (sample_rate * 2),\n", 224 | " \"text\": l.split(\"|\")[1].strip(),\n", 225 | " }\n", 226 | " output += json.dumps(x) + \"\\n\"\n", 227 | " with 
open(outpath, \"w\", encoding=\"utf8\") as w:\n", 228 | " w.write(output)\n", 229 | "\n", 230 | "def convert_to_22k(inpath):\n", 231 | " if inpath.strip()[-4:].lower() != \".wav\":\n", 232 | " print(\"Warning: \" + inpath.strip() + \" is not a .wav file!\")\n", 233 | " return\n", 234 | " ffmpeg.input(inpath).output(\n", 235 | " inpath + \"_22k.wav\",\n", 236 | " ar=\"22050\",\n", 237 | " ac=\"1\",\n", 238 | " acodec=\"pcm_s16le\",\n", 239 | " map_metadata=\"-1\",\n", 240 | " fflags=\"+bitexact\",\n", 241 | " ).overwrite_output().run(quiet=True)\n", 242 | " os.remove(inpath)\n", 243 | " os.rename(inpath + \"_22k.wav\", inpath)\n", 244 | "\n", 245 | "# Extract dataset\n", 246 | "os.chdir('/content')\n", 247 | "if os.path.exists(\"/content/wavs\"):\n", 248 | " shutil.rmtree(\"/content/wavs\")\n", 249 | "os.mkdir(\"wavs\")\n", 250 | "os.chdir(\"wavs\")\n", 251 | "if dataset[-4:] == \".zip\":\n", 252 | " !unzip -q \"{dataset}\"\n", 253 | "elif dataset[-4:] == \".tar\":\n", 254 | " !tar -xf \"{dataset}\"\n", 255 | "else:\n", 256 | " raise Exception(\"Unknown extension for dataset\")\n", 257 | "if os.path.exists(\"/content/wavs/wavs\"):\n", 258 | " shutil.move(\"/content/wavs/wavs\", \"/content/tempwavs\")\n", 259 | " shutil.rmtree(\"/content/wavs\")\n", 260 | " shutil.move(\"/content/tempwavs\", \"/content/wavs\")\n", 261 | "\n", 262 | "# Filelist for preprocessing\n", 263 | "os.chdir('/content')\n", 264 | "shutil.copy(train_filelist, \"trainfiles.txt\")\n", 265 | "shutil.copy(val_filelist, \"valfiles.txt\")\n", 266 | "fix_transcripts(\"trainfiles.txt\")\n", 267 | "fix_transcripts(\"valfiles.txt\")\n", 268 | "seen_files = []\n", 269 | "with open(\"trainfiles.txt\", encoding=\"utf-8\") as f:\n", 270 | " t = f.read().split(\"\\n\")\n", 271 | "with open(\"valfiles.txt\", encoding=\"utf-8\") as f:\n", 272 | " v = f.read().split(\"\\n\")\n", 273 | " all_filelist = t[:] + v[:]\n", 274 | "with open(\"/content/allfiles.txt\", \"w\", encoding=\"utf-8\") as f:\n", 275 | " for x in 
all_filelist:\n", 276 | " if x.strip() == \"\":\n", 277 | " continue\n", 278 | " if x.split(\"|\")[0] not in seen_files:\n", 279 | " seen_files.append(x.split(\"|\")[0])\n", 280 | " f.write(x.strip() + \"\\n\")\n", 281 | "\n", 282 | "# Ensure audio is 22k\n", 283 | "print(\"Converting audio...\")\n", 284 | "for r, _, f in os.walk(\"/content/wavs\"):\n", 285 | " for name in tqdm(f):\n", 286 | " convert_to_22k(os.path.join(r, name))\n", 287 | "\n", 288 | "# Convert to JSON\n", 289 | "generate_json(\"trainfiles.txt\", \"trainfiles.json\")\n", 290 | "generate_json(\"valfiles.txt\", \"valfiles.json\")\n", 291 | "generate_json(\"allfiles.txt\", \"allfiles.json\")\n", 292 | "\n", 293 | "print(\"OK\")" 294 | ], 295 | "execution_count": null, 296 | "outputs": [] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "metadata": { 301 | "id": "sos9vsxPkIN7", 302 | "cellView": "form" 303 | }, 304 | "source": [ 305 | "#@markdown **Step 6:** Dataset processing, part 2. This takes a while, but\n", 306 | "#@markdown you only have to run this once per dataset (results are saved to Drive).\n", 307 | "\n", 308 | "#@markdown If this step fails, try the following:\n", 309 | "#@markdown * Make sure your dataset only contains WAV files.\n", 310 | "\n", 311 | "# Extract phoneme duration\n", 312 | "\n", 313 | "import nltk\n", 314 | "nltk.download('averaged_perceptron_tagger_eng')\n", 315 | "\n", 316 | "import json\n", 317 | "from nemo.collections.asr.models import EncDecCTCModel\n", 318 | "asr_model = EncDecCTCModel.from_pretrained(model_name=\"asr_talknet_aligner\").cpu().eval()\n", 319 | "\n", 320 | "def forward_extractor(tokens, log_probs, blank):\n", 321 | " \"\"\"Computes states f and p.\"\"\"\n", 322 | " n, m = len(tokens), log_probs.shape[0]\n", 323 | " # `f[s, t]` -- max sum of log probs for `s` first codes\n", 324 | " # with `t` first timesteps with ending in `tokens[s]`.\n", 325 | " f = np.empty((n + 1, m + 1), dtype=float)\n", 326 | " f.fill(-(10 ** 9))\n", 327 | " p = np.empty((n + 
1, m + 1), dtype=int)\n", 328 | " f[0, 0] = 0.0 # Start\n", 329 | " for s in range(1, n + 1):\n", 330 | " c = tokens[s - 1]\n", 331 | " for t in range((s + 1) // 2, m + 1):\n", 332 | " f[s, t] = log_probs[t - 1, c]\n", 333 | " # Option #1: prev char is equal to current one.\n", 334 | " if s == 1 or c == blank or c == tokens[s - 3]:\n", 335 | " options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]\n", 336 | " else: # Is not equal to current one.\n", 337 | " options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]\n", 338 | " f[s, t] += np.max(options)\n", 339 | " p[s, t] = np.argmax(options)\n", 340 | " return f, p\n", 341 | "\n", 342 | "\n", 343 | "def backward_extractor(f, p):\n", 344 | " \"\"\"Computes durs from f and p.\"\"\"\n", 345 | " n, m = f.shape\n", 346 | " n -= 1\n", 347 | " m -= 1\n", 348 | " durs = np.zeros(n, dtype=int)\n", 349 | " if f[-1, -1] >= f[-2, -1]:\n", 350 | " s, t = n, m\n", 351 | " else:\n", 352 | " s, t = n - 1, m\n", 353 | " while s > 0:\n", 354 | " durs[s - 1] += 1\n", 355 | " s -= p[s, t]\n", 356 | " t -= 1\n", 357 | " assert durs.shape[0] == n\n", 358 | " assert np.sum(durs) == m\n", 359 | " assert np.all(durs[1::2] > 0)\n", 360 | " return durs\n", 361 | "\n", 362 | "def preprocess_tokens(tokens, blank):\n", 363 | " new_tokens = [blank]\n", 364 | " for c in tokens:\n", 365 | " new_tokens.extend([c, blank])\n", 366 | " tokens = new_tokens\n", 367 | " return tokens\n", 368 | "\n", 369 | "data_config = {\n", 370 | " 'manifest_filepath': \"allfiles.json\",\n", 371 | " 'sample_rate': 22050,\n", 372 | " 'labels': asr_model.decoder.vocabulary,\n", 373 | " 'batch_size': 1,\n", 374 | "}\n", 375 | "\n", 376 | "parser = nemo.collections.asr.data.audio_to_text.AudioToCharWithDursF0Dataset.make_vocab(\n", 377 | " notation='phonemes', punct=True, spaces=True, stresses=False, add_blank_at=\"last\"\n", 378 | ")\n", 379 | "\n", 380 | "dataset = nemo.collections.asr.data.audio_to_text._AudioTextDataset(\n", 381 | " 
manifest_filepath=data_config['manifest_filepath'], sample_rate=data_config['sample_rate'], parser=parser,\n", 382 | ")\n", 383 | "\n", 384 | "dl = torch.utils.data.DataLoader(\n", 385 | " dataset=dataset, batch_size=data_config['batch_size'], collate_fn=dataset.collate_fn, shuffle=False,\n", 386 | ")\n", 387 | "\n", 388 | "blank_id = asr_model.decoder.num_classes_with_blank - 1\n", 389 | "\n", 390 | "if os.path.exists(os.path.join(output_dir, \"durations.pt\")):\n", 391 | " print(\"durations.pt already exists; skipping\")\n", 392 | "else:\n", 393 | " dur_data = {}\n", 394 | " for sample_idx, test_sample in tqdm(enumerate(dl), total=len(dl)):\n", 395 | " log_probs, _, greedy_predictions = asr_model(\n", 396 | " input_signal=test_sample[0], input_signal_length=test_sample[1]\n", 397 | " )\n", 398 | "\n", 399 | " log_probs = log_probs[0].cpu().detach().numpy()\n", 400 | " seq_ids = test_sample[2][0].cpu().detach().numpy()\n", 401 | "\n", 402 | " target_tokens = preprocess_tokens(seq_ids, blank_id)\n", 403 | "\n", 404 | " f, p = forward_extractor(target_tokens, log_probs, blank_id)\n", 405 | " durs = backward_extractor(f, p)\n", 406 | "\n", 407 | " dur_key = Path(dl.dataset.collection[sample_idx].audio_file).stem\n", 408 | " dur_data[dur_key] = {\n", 409 | " 'blanks': torch.tensor(durs[::2], dtype=torch.long).cpu().detach(),\n", 410 | " 'tokens': torch.tensor(durs[1::2], dtype=torch.long).cpu().detach()\n", 411 | " }\n", 412 | "\n", 413 | " del test_sample\n", 414 | "\n", 415 | " torch.save(dur_data, os.path.join(output_dir, \"durations.pt\"))\n", 416 | "\n", 417 | "#Extract F0 (pitch)\n", 418 | "import crepe\n", 419 | "from scipy.io import wavfile\n", 420 | "\n", 421 | "def crepe_f0(audio_file, hop_length=256):\n", 422 | " sr, audio = wavfile.read(audio_file)\n", 423 | " audio_x = np.arange(0, len(audio)) / 22050.0\n", 424 | " time, frequency, confidence, activation = crepe.predict(audio, sr, viterbi=True)\n", 425 | "\n", 426 | " x = np.arange(0, len(audio), 
hop_length) / 22050.0\n", 427 | " freq_interp = np.interp(x, time, frequency)\n", 428 | " conf_interp = np.interp(x, time, confidence)\n", 429 | " audio_interp = np.interp(x, audio_x, np.absolute(audio)) / 32768.0\n", 430 | " weights = [0.5, 0.25, 0.25]\n", 431 | " audio_smooth = np.convolve(audio_interp, np.array(weights)[::-1], \"same\")\n", 432 | "\n", 433 | " conf_threshold = 0.25\n", 434 | " audio_threshold = 0.0005\n", 435 | " for i in range(len(freq_interp)):\n", 436 | " if conf_interp[i] < conf_threshold:\n", 437 | " freq_interp[i] = 0.0\n", 438 | " if audio_smooth[i] < audio_threshold:\n", 439 | " freq_interp[i] = 0.0\n", 440 | "\n", 441 | " # Hack to make f0 and mel lengths equal\n", 442 | " if len(audio) % hop_length == 0:\n", 443 | " freq_interp = np.pad(freq_interp, pad_width=[0, 1])\n", 444 | " return torch.from_numpy(freq_interp.astype(np.float32))\n", 445 | "\n", 446 | "if os.path.exists(os.path.join(output_dir, \"f0s.pt\")):\n", 447 | " print(\"f0s.pt already exists; skipping\")\n", 448 | "else:\n", 449 | " f0_data = {}\n", 450 | " with open(\"allfiles.json\") as f:\n", 451 | " for i, l in enumerate(f.readlines()):\n", 452 | " print(str(i))\n", 453 | " audio_path = json.loads(l)[\"audio_filepath\"]\n", 454 | " f0_data[Path(audio_path).stem] = crepe_f0(audio_path)\n", 455 | "\n", 456 | " # calculate f0 stats (mean & std) only for train set\n", 457 | " with open(\"trainfiles.json\") as f:\n", 458 | " train_ids = {Path(json.loads(l)[\"audio_filepath\"]).stem for l in f}\n", 459 | " all_f0 = torch.cat([f0[f0 >= 1e-5] for f0_id, f0 in f0_data.items() if f0_id in train_ids])\n", 460 | "\n", 461 | " F0_MEAN, F0_STD = all_f0.mean().item(), all_f0.std().item()\n", 462 | " print(\"F0_MEAN: \" + str(F0_MEAN) + \", F0_STD: \" + str(F0_STD))\n", 463 | " torch.save(f0_data, os.path.join(output_dir, \"f0s.pt\"))\n", 464 | " with open(os.path.join(output_dir, \"f0_info.json\"), \"w\") as f:\n", 465 | " f.write(json.dumps({\"FO_MEAN\": F0_MEAN, \"F0_STD\": 
F0_STD}))" 466 | ], 467 | "execution_count": null, 468 | "outputs": [] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "metadata": { 473 | "id": "nM7-bMpKO7U2", 474 | "cellView": "form" 475 | }, 476 | "source": [ 477 | "#@markdown **Step 7:** Train duration predictor.\n", 478 | "\n", 479 | "#@markdown If CUDA runs out of memory, try the following:\n", 480 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n", 481 | "#@markdown * If that doesn't help, reduce the batch size (default 64).\n", 482 | "batch_size = 64 #@param {type:\"integer\"}\n", 483 | "\n", 484 | "epochs = 20\n", 485 | "learning_rate = 1e-3\n", 486 | "min_learning_rate = 3e-6\n", 487 | "load_checkpoints = True\n", 488 | "\n", 489 | "import os\n", 490 | "from hydra.experimental import compose, initialize\n", 491 | "from hydra.core.global_hydra import GlobalHydra\n", 492 | "from omegaconf import OmegaConf\n", 493 | "import pytorch_lightning as pl\n", 494 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n", 495 | "from nemo.collections.tts.models import TalkNetDursModel\n", 496 | "from nemo.core.config import hydra_runner\n", 497 | "from nemo.utils.exp_manager import exp_manager\n", 498 | "\n", 499 | "def train(cfg):\n", 500 | " cfg.sample_rate = 22050\n", 501 | " cfg.train_dataset = \"trainfiles.json\"\n", 502 | " cfg.validation_datasets = \"valfiles.json\"\n", 503 | " cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n", 504 | " cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n", 505 | " cfg.trainer.accelerator = \"dp\"\n", 506 | " cfg.trainer.max_epochs = epochs\n", 507 | " cfg.trainer.check_val_every_n_epoch = 5\n", 508 | " cfg.model.train_ds.dataloader_params.batch_size = batch_size\n", 509 | " cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n", 510 | " cfg.model.optim.lr = learning_rate\n", 511 | " cfg.model.optim.sched.min_lr = min_learning_rate\n", 512 | " cfg.exp_manager.exp_dir = output_dir\n", 513 | "\n", 514 
| "    # Find checkpoints\n",
515 | "    ckpt_path = \"\"\n",
516 | "    if load_checkpoints:\n",
517 | "        path0 = os.path.join(output_dir, \"TalkNetDurs\")\n",
518 | "        if os.path.exists(path0):\n",
519 | "            path1 = sorted(os.listdir(path0))\n",
520 | "            for run_dir in reversed(path1):  # newest run first (assumes run folder names sort chronologically)\n",
521 | "                path2 = os.path.join(path0, run_dir, \"checkpoints\")\n",
522 | "                if os.path.exists(path2):\n",
523 | "                    match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
524 | "                    if len(match) > 0:\n",
525 | "                        ckpt_path = os.path.join(path2, match[0])\n",
526 | "                        print(\"Resuming training from \" + match[0])\n",
527 | "                        break\n",
528 | "\n",
529 | "    if ckpt_path != \"\":  # resume the interrupted run; otherwise warm-start from the pretrained .nemo\n",
530 | "        trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
531 | "        model = TalkNetDursModel(cfg=cfg.model, trainer=trainer)\n",
532 | "    else:\n",
533 | "        warmstart_path = \"/content/talknet_durs.nemo\"\n",
534 | "        trainer = pl.Trainer(**cfg.trainer)\n",
535 | "        model = TalkNetDursModel.restore_from(warmstart_path, override_config_path=cfg)\n",
536 | "        model.set_trainer(trainer)\n",
537 | "        model.setup_training_data(cfg.model.train_ds)\n",
538 | "        model.setup_validation_data(cfg.model.validation_ds)\n",
539 | "        model.setup_optimization(cfg.model.optim)\n",
540 | "        print(\"Warm-starting from \" + warmstart_path)\n",
541 | "    exp_manager(trainer, cfg.get('exp_manager', None))\n",
542 | "    trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()])  # noqa\n",
543 | "    trainer.fit(model)\n",
544 | "\n",
545 | "GlobalHydra().clear()\n",
546 | "initialize(config_path=\"conf\")\n",
547 | "cfg = compose(config_name=\"talknet-durs\")\n",
548 | "train(cfg)\n"
549 | ],
550 | "execution_count": null,
551 | "outputs": []
552 | },
553 | {
554 | "cell_type": "code",
555 | "metadata": {
556 | "id": "JLfm00NuJfon",
557 | "cellView": "form"
558 | },
559 | "source": [
560 | "#@markdown **Step 8:** Train pitch predictor.\n",
561 | "\n",
562 | "#@markdown If CUDA runs out of memory, try the following:\n",
563 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
564 | "#@markdown * If that doesn't help, reduce the batch size (default 64).\n",
565 | "batch_size = 64 #@param {type:\"integer\"}\n",
566 | "epochs = 50\n",
567 | "\n",
568 | "import json\n",
569 | "import os\n",
570 | "from hydra.experimental import compose, initialize\n",
571 | "from hydra.core.global_hydra import GlobalHydra\n",
572 | "from omegaconf import OmegaConf\n",
573 | "import pytorch_lightning as pl\n",
574 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n",
575 | "from nemo.collections.tts.models import TalkNetPitchModel\n",
576 | "from nemo.core.config import hydra_runner\n",
577 | "from nemo.utils.exp_manager import exp_manager\n",
578 | "\n",
579 | "learning_rate = 1e-3\n",
580 | "min_learning_rate = 3e-6\n",
581 | "load_checkpoints = True\n",
582 | "\n",
583 | "# Pitch (F0) normalisation stats used for cfg.model.f0_mean / f0_std; f0_info.json must already exist in output_dir.\n",
584 | "with open(os.path.join(output_dir, \"f0_info.json\"), \"r\") as f:\n",
585 | "    f0_info = json.load(f)\n",
586 | "    f0_mean = f0_info[\"FO_MEAN\"]  # NOTE(review): key uses a letter O (\"FO\"), unlike F0_STD; must match whatever wrote f0_info.json\n",
587 | "    f0_std = f0_info[\"F0_STD\"]\n",
588 | "\n",
589 | "def train(cfg):\n",
590 | "    cfg.sample_rate = 22050\n",
591 | "    cfg.train_dataset = \"trainfiles.json\"\n",
592 | "    cfg.validation_datasets = \"valfiles.json\"\n",
593 | "    cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n",
594 | "    cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n",
595 | "    cfg.trainer.accelerator = \"dp\"\n",
596 | "    cfg.trainer.max_epochs = epochs\n",
597 | "    cfg.trainer.check_val_every_n_epoch = 5\n",
598 | "    cfg.model.f0_mean = f0_mean\n",
599 | "    cfg.model.f0_std = f0_std\n",
600 | "    cfg.model.train_ds.dataloader_params.batch_size = batch_size\n",
601 | "    cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n",
602 | "    cfg.model.optim.lr = learning_rate\n",
603 | "    cfg.model.optim.sched.min_lr = min_learning_rate\n",
604 | "    cfg.exp_manager.exp_dir = output_dir\n",
605 | "\n",
606 | "    # Find checkpoints\n",
607 | "    ckpt_path = \"\"\n",
608 | "    if load_checkpoints:\n",
609 | "        path0 = os.path.join(output_dir, \"TalkNetPitch\")\n",
610 | "        if os.path.exists(path0):\n",
611 | "            path1 = sorted(os.listdir(path0))\n",
612 | "            for run_dir in reversed(path1):  # newest run first (assumes run folder names sort chronologically)\n",
613 | "                path2 = os.path.join(path0, run_dir, \"checkpoints\")\n",
614 | "                if os.path.exists(path2):\n",
615 | "                    match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
616 | "                    if len(match) > 0:\n",
617 | "                        ckpt_path = os.path.join(path2, match[0])\n",
618 | "                        print(\"Resuming training from \" + match[0])\n",
619 | "                        break\n",
620 | "\n",
621 | "    if ckpt_path != \"\":  # resume the interrupted run; otherwise warm-start from the pretrained .nemo\n",
622 | "        trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
623 | "        model = TalkNetPitchModel(cfg=cfg.model, trainer=trainer)\n",
624 | "    else:\n",
625 | "        warmstart_path = \"/content/talknet_pitch.nemo\"\n",
626 | "        trainer = pl.Trainer(**cfg.trainer)\n",
627 | "        model = TalkNetPitchModel.restore_from(warmstart_path, override_config_path=cfg)\n",
628 | "        model.set_trainer(trainer)\n",
629 | "        model.setup_training_data(cfg.model.train_ds)\n",
630 | "        model.setup_validation_data(cfg.model.validation_ds)\n",
631 | "        model.setup_optimization(cfg.model.optim)\n",
632 | "        print(\"Warm-starting from \" + warmstart_path)\n",
633 | "    exp_manager(trainer, cfg.get('exp_manager', None))\n",
634 | "    trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()])  # noqa\n",
635 | "    trainer.fit(model)\n",
636 | "\n",
637 | "GlobalHydra().clear()\n",
638 | "initialize(config_path=\"conf\")\n",
639 | "cfg = compose(config_name=\"talknet-pitch\")\n",
640 | "train(cfg)"
641 | ],
642 | "execution_count": null,
643 | "outputs": []
644 | },
645 | {
646 | "cell_type": "code",
647 | "metadata": {
648 | "id": "N9hh4WPHbCcn",
649 | "cellView": "form"
650 | },
651 | "source": [
652 | "#@markdown **Step 9:** Train spectrogram generator. 200+ epochs are recommended.\n",
653 | "\n",
654 | "#@markdown This is the slowest of the three models to train, and the hardest to\n",
655 | "#@markdown get good results from. If your character sounds noisy or robotic,\n",
656 | "#@markdown try improving the dataset, or adjusting the epochs and learning rate.\n",
657 | "\n",
658 | "epochs = 200 #@param {type:\"integer\"}\n",
659 | "\n",
660 | "#@markdown If CUDA runs out of memory, try the following:\n",
661 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
662 | "#@markdown * If that doesn't help, reduce the batch size (default 32).\n",
663 | "batch_size = 32 #@param {type:\"integer\"}\n",
664 | "\n",
665 | "#@markdown Advanced settings. You can probably leave these at their defaults (1e-3, 3e-6, empty, checked).\n",
666 | "learning_rate = 1e-3 #@param {type:\"number\"}\n",
667 | "min_learning_rate = 3e-6 #@param {type:\"number\"}\n",
668 | "pretrained_path = \"\" #@param {type:\"string\"}\n",
669 | "load_checkpoints = True #@param {type:\"boolean\"}\n",
670 | "\n",
671 | "import os\n",
672 | "from hydra.experimental import compose, initialize\n",
673 | "from hydra.core.global_hydra import GlobalHydra\n",
674 | "from omegaconf import OmegaConf\n",
675 | "import pytorch_lightning as pl\n",
676 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n",
677 | "from nemo.collections.tts.models import TalkNetSpectModel\n",
678 | "from nemo.core.config import hydra_runner\n",
679 | "from nemo.utils.exp_manager import exp_manager\n",
680 | "\n",
681 | "def train(cfg):\n",
682 | "    cfg.sample_rate = 22050\n",
683 | "    cfg.train_dataset = \"trainfiles.json\"\n",
684 | "    cfg.validation_datasets = \"valfiles.json\"\n",
685 | "    cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n",
686 | "    cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n",
687 | "    cfg.trainer.accelerator = \"dp\"\n",
688 | "    cfg.trainer.max_epochs = epochs\n",
689 | "    cfg.trainer.check_val_every_n_epoch = 5\n",
690 | "    cfg.model.train_ds.dataloader_params.batch_size = batch_size\n",
691 | "    cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n",
692 | "    cfg.model.optim.lr = learning_rate\n",
693 | "    cfg.model.optim.sched.min_lr = min_learning_rate\n",
694 | "    cfg.exp_manager.exp_dir = output_dir\n",
695 | "\n",
696 | "    # Find checkpoints\n",
697 | "    ckpt_path = \"\"\n",
698 | "    if load_checkpoints:\n",
699 | "        path0 = os.path.join(output_dir, \"TalkNetSpect\")\n",
700 | "        if os.path.exists(path0):\n",
701 | "            path1 = sorted(os.listdir(path0))\n",
702 | "            for run_dir in reversed(path1):  # newest run first (assumes run folder names sort chronologically)\n",
703 | "                path2 = os.path.join(path0, run_dir, \"checkpoints\")\n",
704 | "                if os.path.exists(path2):\n",
705 | "                    match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
706 | "                    if len(match) > 0:\n",
707 | "                        ckpt_path = os.path.join(path2, match[0])\n",
708 | "                        print(\"Resuming training from \" + match[0])\n",
709 | "                        break\n",
710 | "\n",
711 | "    if ckpt_path != \"\":  # resume the interrupted run; otherwise warm-start from a pretrained .nemo\n",
712 | "        trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
713 | "        model = TalkNetSpectModel(cfg=cfg.model, trainer=trainer)\n",
714 | "    else:\n",
715 | "        if pretrained_path != \"\":\n",
716 | "            warmstart_path = pretrained_path\n",
717 | "        else:\n",
718 | "            warmstart_path = \"/content/talknet_spect.nemo\"\n",
719 | "        trainer = pl.Trainer(**cfg.trainer)\n",
720 | "        model = TalkNetSpectModel.restore_from(warmstart_path, override_config_path=cfg)\n",
721 | "        model.set_trainer(trainer)\n",
722 | "        model.setup_training_data(cfg.model.train_ds)\n",
723 | "        model.setup_validation_data(cfg.model.validation_ds)\n",
724 | "        model.setup_optimization(cfg.model.optim)\n",
725 | "        print(\"Warm-starting from \" + warmstart_path)\n",
726 | "    exp_manager(trainer, cfg.get('exp_manager', None))\n",
727 | "    trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()])  # noqa\n",
728 | "    trainer.fit(model)\n",
729 | "\n",
730 | "GlobalHydra().clear()\n",
731 | "initialize(config_path=\"conf\")\n",
732 | "cfg = compose(config_name=\"talknet-spect\")\n",
733 | "train(cfg)"
734 | ],
735 | "execution_count": null,
736 | "outputs": []
737 | },
738 | {
739 | "cell_type": "code",
740 | "metadata": {
741 | "id": "Ajfyfz2p9Ior",
742 | "cellView": "form"
743 | },
744 | "source": [
745 | "#@markdown **Step 10:** Generate GTA spectrograms. This will help HiFi-GAN learn what your TalkNet model sounds like.\n",
746 | "\n",
747 | "#@markdown If this step fails, make sure you've finished training the spectrogram generator.\n",
748 | "\n",
749 | "import sys\n",
750 | "import os\n",
751 | "import torch\n",
752 | "import numpy as np\n",
753 | "from tqdm import tqdm\n",
754 | "from nemo.collections.tts.models import TalkNetSpectModel\n",
755 | "import shutil\n",
756 | "\n",
757 | "def fix_paths(inpath):\n",
758 | "    \"\"\"Ensure every non-blank line of the filelist at inpath starts with \"wavs/\" (rewritten in place); blank lines are dropped.\"\"\"\n",
759 | "    output = \"\"\n",
760 | "    with open(inpath, \"r\", encoding=\"utf8\") as f:\n",
761 | "        for l in f.readlines():\n",
762 | "            if not l.strip():\n",
763 | "                continue  # a blank line would otherwise become a bogus \"wavs/\" entry\n",
764 | "            output += l if l[:5].lower() == \"wavs/\" else \"wavs/\" + l\n",
765 | "    with open(inpath, \"w\", encoding=\"utf8\") as w:\n",
766 | "        w.write(output)\n",
767 | "\n",
768 | "shutil.copyfile(train_filelist, \"/content/hifi-gan/training.txt\")\n",
769 | "shutil.copyfile(val_filelist, \"/content/hifi-gan/validation.txt\")\n",
770 | "fix_paths(\"/content/hifi-gan/training.txt\")\n",
771 | "fix_paths(\"/content/hifi-gan/validation.txt\")\n",
772 | "fix_paths(\"/content/allfiles.txt\")\n",
773 | "\n",
774 | "os.chdir('/content')\n",
775 | "outdir = \"hifi-gan/wavs\"\n",
776 | "os.makedirs(outdir, exist_ok=True)\n",
777 | "\n",
778 | "model_path = \"\"\n",
779 | "path0 = os.path.join(output_dir, \"TalkNetSpect\")\n",
780 | "if os.path.exists(path0):\n",
781 | "    path1 = sorted(os.listdir(path0))\n",
782 | "    for run_dir in reversed(path1):  # newest run first (assumes run folder names sort chronologically)\n",
783 | "        path2 = os.path.join(path0, run_dir, \"checkpoints\")\n",
784 | "        if os.path.exists(path2):\n",
785 | "            match = [x for x in os.listdir(path2) if \"TalkNetSpect.nemo\" in x]\n",
786 | "            if len(match) > 0:\n",
787 | "                model_path = os.path.join(path2, match[0])\n",
788 | "                break\n",
789 | "assert model_path != \"\", \"TalkNetSpect.nemo not found\"\n",
790 | "\n",
791 | "dur_path = os.path.join(output_dir, \"durations.pt\")\n",
792 | "f0_path = os.path.join(output_dir, \"f0s.pt\")\n",
793 | "\n",
794 | "model = TalkNetSpectModel.restore_from(model_path)\n",
795 | "model.eval()\n",
796 | "with open(\"allfiles.txt\", \"r\", encoding=\"utf-8\") as f:\n",
797 | "    dataset = f.readlines()\n",
798 | "durs = torch.load(dur_path)\n",
799 | "f0s = torch.load(f0_path)\n",
800 | "\n",
801 | "with torch.no_grad():  # inference only: don't build an autograd graph for every clip\n",
802 | "    for x in tqdm(dataset):\n",
803 | "        x_name = os.path.splitext(os.path.basename(x.split(\"|\")[0].strip()))[0]\n",
804 | "        x_tokens = model.parse(text=x.split(\"|\")[1].strip())\n",
805 | "        x_durs = (  # interleave blank/token durations, dropping the trailing pad; presumably 1-D tensors per clip\n",
806 | "            torch.stack(\n",
807 | "                (\n",
808 | "                    durs[x_name][\"blanks\"],\n",
809 | "                    torch.cat((durs[x_name][\"tokens\"], torch.zeros(1).int())),\n",
810 | "                ),\n",
811 | "                dim=1,\n",
812 | "            )\n",
813 | "            .view(-1)[:-1]\n",
814 | "            .view(1, -1)\n",
815 | "            .to(\"cuda:0\")\n",
816 | "        )\n",
817 | "        x_f0s = f0s[x_name].view(1, -1).to(\"cuda:0\")\n",
818 | "        x_spect = model.force_spectrogram(tokens=x_tokens, durs=x_durs, f0=x_f0s)\n",
819 | "        rel_path = os.path.splitext(x.split(\"|\")[0].strip())[0][5:]  # strip the \"wavs/\" prefix guaranteed by fix_paths\n",
820 | "        abs_dir = os.path.join(outdir, os.path.dirname(rel_path))\n",
821 | "        if abs_dir != \"\" and not os.path.exists(abs_dir):\n",
822 | "            os.makedirs(abs_dir, exist_ok=True)\n",
823 | "        np.save(os.path.join(outdir, rel_path + \".npy\"), x_spect.detach().cpu().numpy())\n"
824 | ],
825 | "execution_count": null,
826 | "outputs": []
827 | },
828 | {
829 | "cell_type": "code",
830 | "metadata": {
831 | "id": "yVBjGhRB9hUJ",
832 | "cellView": "form"
833 | },
834 | "source": [
835 | "#@markdown **Step 11:** Train HiFi-GAN. 2,000+ steps are recommended.\n",
836 | "#@markdown Stop this cell to finish training the model.\n",
837 | "\n",
838 | "#@markdown If CUDA runs out of memory, click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
839 | "#@markdown If this step still fails to start, make sure step 10 finished successfully.\n",
840 | "\n",
841 | "#@markdown Note: If the training process starts at step 2500000, delete the HiFiGAN folder and try again.\n",
842 | "\n",
843 | "import os\n",
844 | "import gdown\n",
845 | "d = 'https://drive.google.com/uc?id='\n",
846 | "\n",
847 | "os.chdir('/content/hifi-gan')\n",
848 | "assert os.path.exists(\"wavs\"), \"Spectrogram folder not found\"\n",
849 | "\n",
850 | "if not os.path.exists(os.path.join(output_dir, \"HiFiGAN\")):\n",
851 | "    os.makedirs(os.path.join(output_dir, \"HiFiGAN\"))\n",
852 | "if not os.path.exists(os.path.join(output_dir, \"HiFiGAN\", \"do_00000000\")):\n",
853 | "    print(\"Downloading universal model...\")\n",
854 | "    gdown.download(d+\"1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\", os.path.join(output_dir, \"HiFiGAN\", \"g_00000000\"), quiet=False)\n",
855 | "    gdown.download(d+\"1O63eHZR9t1haCdRHQcEgMfMNxiOciSru\", os.path.join(output_dir, \"HiFiGAN\", \"do_00000000\"), quiet=False)\n",
856 | "    start_from_universal = \"--warm_start True \"\n",
857 | "else:\n",
858 | "    start_from_universal = \"\"\n",
859 | "\n",
860 | "!python train.py --fine_tuning True --config config_v1b.json \\\n",
861 | "{start_from_universal} \\\n",
862 | "--checkpoint_interval 250 --checkpoint_path \"{os.path.join(output_dir, 'HiFiGAN')}\" \\\n",
863 | "--input_training_file \"/content/hifi-gan/training.txt\" \\\n",
864 | "--input_validation_file \"/content/hifi-gan/validation.txt\" \\\n",
865 | "--input_wavs_dir \"..\" --input_mels_dir \"wavs\"\n"
866 | ],
867 | "execution_count": null,
868 | "outputs": []
869 | },
870 | {
871 | "cell_type": "code",
872 | "metadata": {
873 | "id": "5OtwfrTT-blU",
874 | "cellView": "form"
875 | },
876 | "source": [
877
| "#@markdown **Step 12:** Package the models. They'll be saved to the output directory as [character_name]_TalkNet.zip.\n",
878 | "\n",
879 | "character_name = \"Character\" #@param {type:\"string\"}\n",
880 | "\n",
881 | "#@markdown When done, generate a Drive share link, with permissions set to \"Anyone with the link\".\n",
882 | "#@markdown You can then use it with the [Controllable TalkNet notebook](https://colab.research.google.com/drive/1aj6Jk8cpRw7SsN3JSYCv57CrR6s0gYPB)\n",
883 | "#@markdown by selecting \"Custom model\" as your character.\n",
884 | "\n",
885 | "#@markdown This cell will also move the training checkpoints and logs to the trash.\n",
886 | "#@markdown That should free up roughly 2 GB of space on your Drive (remember to empty your trash).\n",
887 | "#@markdown If you wish to keep them, uncheck this box.\n",
888 | "\n",
889 | "delete_checkpoints = True #@param {type:\"boolean\"}\n",
890 | "\n",
891 | "import os\n",
892 | "import shutil\n",
893 | "from zipfile import ZipFile\n",
894 | "\n",
895 | "def find_talknet(model_dir):\n",
896 | "    \"\"\"Return the path of a .nemo file from the last-sorting run folder under output_dir/model_dir that has one; asserts if none is found.\"\"\"\n",
897 | "    ckpt_path = \"\"\n",
898 | "    path0 = os.path.join(output_dir, model_dir)\n",
899 | "    if os.path.exists(path0):\n",
900 | "        path1 = sorted(os.listdir(path0))\n",
901 | "        for run_dir in reversed(path1):  # newest run first (assumes run folder names sort chronologically)\n",
902 | "            path2 = os.path.join(path0, run_dir, \"checkpoints\")\n",
903 | "            if os.path.exists(path2):\n",
904 | "                match = [x for x in os.listdir(path2) if \".nemo\" in x]\n",
905 | "                if len(match) > 0:\n",
906 | "                    ckpt_path = os.path.join(path2, match[0])\n",
907 | "                    break\n",
908 | "    assert ckpt_path != \"\", \"Couldn't find \" + model_dir\n",
909 | "    return ckpt_path\n",
910 | "\n",
911 | "durs_path = find_talknet(\"TalkNetDurs\")\n",
912 | "pitch_path = find_talknet(\"TalkNetPitch\")\n",
913 | "spect_path = find_talknet(\"TalkNetSpect\")\n",
914 | "hifigan_dir = os.path.join(output_dir, \"HiFiGAN\")\n",
915 | "assert os.path.exists(os.path.join(hifigan_dir, \"g_00000000\")), \"Couldn't find HiFi-GAN\"\n",
916 | "# Check every remaining input before opening the archive, so a missing file can't leave a half-written zip behind.\n",
917 | "assert os.path.exists(os.path.join(hifigan_dir, \"config.json\")), \"Couldn't find HiFi-GAN config.json\"\n",
918 | "assert os.path.exists(os.path.join(output_dir, \"f0_info.json\")), \"Couldn't find f0_info.json\"\n",
919 | "\n",
920 | "archive_path = os.path.join(output_dir, character_name + \"_TalkNet.zip\")\n",
921 | "# `with` closes the archive even if a write fails; naming it `archive` avoids shadowing the builtin zip().\n",
922 | "with ZipFile(archive_path, 'w') as archive:\n",
923 | "    archive.write(durs_path, \"TalkNetDurs.nemo\")\n",
924 | "    archive.write(pitch_path, \"TalkNetPitch.nemo\")\n",
925 | "    archive.write(spect_path, \"TalkNetSpect.nemo\")\n",
926 | "    archive.write(os.path.join(hifigan_dir, \"g_00000000\"), \"hifiganmodel\")\n",
927 | "    archive.write(os.path.join(hifigan_dir, \"config.json\"), \"config.json\")\n",
928 | "    archive.write(os.path.join(output_dir, \"f0_info.json\"), \"f0_info.json\")\n",
929 | "print(\"Archived model to \" + archive_path)\n",
930 | "\n",
931 | "# Only reached once the archive has been written, so training folders are never removed without a package.\n",
932 | "if delete_checkpoints:\n",
933 | "    for model_dir in (\"TalkNetDurs\", \"TalkNetPitch\", \"TalkNetSpect\", \"HiFiGAN\"):\n",
934 | "        shutil.rmtree(os.path.join(output_dir, model_dir))\n"
935 | ],
936 | "execution_count": null,
937 | "outputs": []
938 | }
939 | ]
940 | } --------------------------------------------------------------------------------