├── Controllable_TalkNet.ipynb
├── README.md
└── TalkNet_Training_A100.ipynb
/Controllable_TalkNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "include_colab_link": true
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "Gss5Ox_RNiba"
33 | },
34 | "source": [
35 | "# Controllable TalkNet\n",
36 | "To run TalkNet, click on Runtime -> Run all. The interface will appear at the bottom of the page when it's ready.\n",
37 | "\n",
38 | "## Instructions\n",
39 | "\n",
40 | "* Once the notebook is running, click on Files (the folder icon on the left edge).\n",
41 | "* Upload audio clips of a singing or speaking voice by dragging and dropping them onto the sidebar.\n",
42 | "* Click on \"Update file list\" in the TalkNet interface. Select an audio file from the dropdown, and type what it says into the Transcript box.\n",
43 | "* Select a character, and press Generate. The first line will take a little longer to generate.\n",
44 | "\n",
45 | "## Tips and tricks\n",
46 | "* If you want to use TalkNet as regular text-to-speech system, without any reference audio, tick the \"Disable reference audio\" checkbox.\n",
47 | "* You can use [ARPABET](http://www.speech.cs.cmu.edu/cgi-bin/cmudict) to override the pronunciation of words, like this: *She made a little bow, then she picked up her {B OW}.*\n",
48 | "* If you're running out of memory generating lines, try to work with shorter clips.\n",
49 | "* The singing models are trained on very little data, and can have a hard time pronouncing certain words. Try experimenting with ARPABET and punctuation.\n",
50 | "* If the voice is off-key, the problem is usually with the extracted pitch. Press \"Debug pitch\" to listen to it. Reference audio with lots of echo/reverb or background noise, or singers with a very high vocal range can cause issues.\n",
51 | "* If the singing voice sounds strained, try enabling \"Change input pitch\" and adjusting it up or down a few semitones. If you're remixing a song, remember to pitch-shift your background track as well.\n",
52 | "\n",
53 | ">Maintained by: `justinjohn0306`\n",
54 | "\n",
55 | ">Special thanks to: `Tapiocapioca#6641` and `effusiveperiscope`"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "metadata": {
61 | "id": "LCqXqFgP2ri0",
62 | "cellView": "form"
63 | },
64 | "source": [
65 | "#@markdown **Step 1:** Check which GPU you've been allocated.\n",
66 | "\n",
67 | "!nvidia-smi -L\n",
68 | "!nvidia-smi"
69 | ],
70 | "execution_count": null,
71 | "outputs": []
72 | },
73 | {
74 | "cell_type": "code",
75 | "metadata": {
76 | "id": "teF-Ut8Z7Gjp",
77 | "cellView": "form"
78 | },
79 | "source": [
80 | "#@markdown **Step 2:** Download dependencies.\n",
81 | "\n",
82 | "#@markdown >### Note: The runtime will crash after this cell completes, which is intended.\n",
83 | "\n",
84 | "import nltk\n",
85 | "nltk.download('averaged_perceptron_tagger_eng')\n",
86 | "\n",
87 | "import os\n",
88 | "\n",
89 | "custom_lists = [\n",
90 | " #\"https://gist.githubusercontent.com/SortAnon/997cda157954a189259c9876fd804e53/raw/example_models.json\",\n",
91 | "]\n",
92 | "\n",
93 | "!apt-get install sox libsndfile1 ffmpeg\n",
94 | "\n",
95 | "# 3.10 pytorch-lightning fix\n",
96 | "!pip install -U torch --index-url https://download.pytorch.org/whl/cu118\n",
97 | "!pip install pytorch-lightning==1.6.5\n",
98 | "!pip install torchmetrics==0.11.4\n",
99 | "# 3.10 fix\n",
100 | "!pip install numpy==1.23.5 scipy==1.10.1 librosa==0.8.1\n",
101 | "!pip install tensorflow dash==1.21.0 dash-bootstrap-components==0.13.0 jupyter-dash==0.4.0 psola wget unidecode pysptk frozendict torchvision torchaudio torchtext torch_stft kaldiio pydub pyannote.audio g2p_en pesq pystoi crepe resampy ffmpeg-python torchcrepe einops taming-transformers-rom1504==0.0.6 tensorflow-hub\n",
102 | "!pip uninstall gdown -y\n",
103 | "!pip install git+https://github.com/wkentaro/gdown.git\n",
104 | "!python -m pip install git+https://github.com/effusiveperiscope/NeMo.git\n",
105 | "if not os.path.exists(\"hifi-gan\"):\n",
106 | " !git clone -q --recursive https://github.com/justinjohn0306/talknet-hifi-gan hifi-gan\n",
107 | "!git clone -q https://github.com/effusiveperiscope/ControllableTalkNet\n",
108 | "os.chdir(\"/content/ControllableTalkNet\")\n",
109 | "!git archive --output=./files.tar --format=tar HEAD\n",
110 | "os.chdir(\"/content\")\n",
111 | "!tar xf ControllableTalkNet/files.tar\n",
112 | "!rm -rf ControllableTalkNet\n",
113 | "\n",
114 | "# 3.10 werkzeug fix\n",
115 | "!python -m pip install werkzeug==2.0.0 flask==2.1.2\n",
116 | "# pytorch cuda fix???\n",
117 | "!pip install -U torch torchtext torchvision torchaudio torch_stft torchcrepe --index-url https://download.pytorch.org/whl/cu121\n",
118 | "!pip install librosa==0.8.1\n",
119 | "\n",
120 | "os.chdir(\"/content/model_lists\")\n",
121 | "for c in custom_lists:\n",
122 | " !wget \"{c}\"\n",
123 | "os.chdir(\"/content\")\n",
124 | "\n",
125 | "\n",
126 | "# restart the runtime\n",
127 | "os.kill(os.getpid(), 9)\n"
128 | ],
129 | "execution_count": null,
130 | "outputs": []
131 | },
132 | {
133 | "cell_type": "code",
134 | "metadata": {
135 | "id": "tOXejargIPTq",
136 | "cellView": "form"
137 | },
138 | "source": [
139 | "# @markdown **Step 3:** **Start GUI**\n",
140 | "using_inline = True\n",
141 | "import pkg_resources\n",
142 | "from pkg_resources import DistributionNotFound, VersionConflict\n",
143 | "\"\"\"dependencies = [\n",
144 | "\"tensorflow==2.4.1\",\n",
145 | "\"dash\",\n",
146 | "\"jupyter-dash\",\n",
147 | "\"psola\",\n",
148 | "\"wget\",\n",
149 | "\"unidecode\",\n",
150 | "\"pysptk\",\n",
151 | "\"frozendict\",\n",
152 | "\"torchvision==0.9.1\",\n",
153 | "\"torchaudio==0.8.1\",\n",
154 | "\"torchtext==0.9.1\",\n",
155 | "\"torch_stft\",\n",
156 | "\"kaldiio\",\n",
157 | "\"pydub\",\n",
158 | "\"pyannote.audio\",\n",
159 | "\"g2p_en\",\n",
160 | "\"pesq\",\n",
161 | "\"pystoi\",\n",
162 | "\"crepe\",\n",
163 | "\"resampy\",\n",
164 | "\"ffmpeg-python\",\n",
165 | "\"numpy\",\n",
166 | "\"scipy\",\n",
167 | "\"nemo_toolkit\",\n",
168 | "\"tqdm\",\n",
169 | "\"gdown\",\n",
170 | "]\n",
171 | "pkg_resources.require(dependencies)\"\"\"\n",
172 | "\n",
173 | "from controllable_talknet import *\n",
174 | "app.run_server(\n",
175 | " mode=\"inline\",\n",
176 | " #dev_tools_ui=True,\n",
177 | " #dev_tools_hot_reload=True,\n",
178 | " threaded=True,\n",
179 | ")"
180 | ],
181 | "execution_count": null,
182 | "outputs": []
183 | },
184 | {
185 | "cell_type": "code",
186 | "metadata": {
187 | "cellView": "form",
188 | "id": "F4WDzfRIgb5Z"
189 | },
190 | "source": [
191 | "# @markdown **Step 3B:** If the above fails with a 403 error, do the following:\n",
192 | "# @markdown * Go to Runtime -> Restart runtime\n",
193 | "# @markdown * Run this cell (click the play button)\n",
194 | "# @markdown * Click on the googleusercontent.com link to use TalkNet in a separate tab\n",
195 | "try:\n",
196 | " using_inline\n",
197 | "except:\n",
198 | " using_inline = False\n",
199 | "if not using_inline:\n",
200 | " from controllable_talknet import *\n",
201 | " from google.colab.output import eval_js\n",
202 | "\n",
203 | " print(eval_js(\"google.colab.kernel.proxyPort(8050)\"))\n",
204 | " app.run_server(\n",
205 | " mode=\"external\",\n",
206 | " debug=False,\n",
207 | " #dev_tools_ui=True,\n",
208 | " #dev_tools_hot_reload=True,\n",
209 | " threaded=True,\n",
210 | " )"
211 | ],
212 | "execution_count": null,
213 | "outputs": []
214 | }
215 | ]
216 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TalkNET-colab
2 |
3 | **TalkNET Model Training Notebook**:
4 |
5 | **TalkNET Synthesis Notebook (when you have a model)**:
6 |
7 |
8 | Taken from [TalkNET](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/tts/models/talknet.py), made trainable on Colab by [SortAnon](https://github.com/SortAnon/ControllableTalkNet), with fixes by [Justin John](https://github.com/justinjohn0306).
9 |
10 | If you are looking for assistance, feel free to join the [Discord server](https://discord.com/invite/NhJZGtH) and ask in the tech-support channel there.
11 |
--------------------------------------------------------------------------------
/TalkNet_Training_A100.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "provenance": [],
7 | "include_colab_link": true
8 | },
9 | "kernelspec": {
10 | "name": "python3",
11 | "display_name": "Python 3"
12 | },
13 | "language_info": {
14 | "name": "python"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "metadata": {
32 | "id": "Gss5Ox_RNiba"
33 | },
34 | "source": [
35 | "# TalkNet Training\n",
36 | "Last updated: 2023-05-17\n",
37 | "\n",
38 | "To train a 22KHz TalkNet, run the cells below and follow the instructions.\n",
39 | "\n",
40 | "This will take a while, and you might have to do it in multiple Colab sessions. The notebook will automatically resume training any models from the last saved checkpoint. If you're resuming from a new session, always re-run steps 1 through 5 first.\n",
41 | "\n",
42 | "##**IMPORTANT:**\n",
43 | "Your Trash folder on Drive will fill up with old checkpoints\n",
44 | "as you train the various models. Keep an eye on your Drive storage, and empty the trash if it starts to become full.\n",
45 | "\n",
46 | "- Fixes by ``justinjohn0306`` and ``Tapiocapioca#6641``"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "metadata": {
52 | "id": "LCqXqFgP2ri0",
53 | "cellView": "form"
54 | },
55 | "source": [
56 | "#@markdown **Step 1:** Check which GPU you've been allocated.\n",
57 | "\n",
58 | "#@markdown You want a P100, V100, T4, V100 or A100.\n",
59 | "!nvidia-smi -L\n",
60 | "!nvidia-smi"
61 | ],
62 | "execution_count": null,
63 | "outputs": []
64 | },
65 | {
66 | "cell_type": "code",
67 | "metadata": {
68 | "id": "j9TI-Q6m3qlx",
69 | "cellView": "form"
70 | },
71 | "source": [
72 | "#@markdown **Step 2:** Mount Google Drive.\n",
73 | "from google.colab import drive\n",
74 | "drive.mount(\"/content/drive\")"
75 | ],
76 | "execution_count": null,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "id": "nfSawDUD5tqv",
83 | "cellView": "form"
84 | },
85 | "source": [
86 | "#@markdown **Step 3:** Configure training data paths. Upload the following to your Drive and change the paths below:\n",
87 | "#@markdown * A dataset of .wav files, packaged as a .zip or .tar file\n",
88 | "#@markdown * Training and validation filelists, in LJSpeech format with relative paths (note: ARPABET transcripts are not supported)\n",
89 | "#@markdown * An output path for checkpoints\n",
90 | "\n",
91 | "import os\n",
92 | "\n",
93 | "dataset = \"/content/drive/My Drive/path_to_dataset.zip\" #@param {type:\"string\"}\n",
94 | "train_filelist = \"/content/drive/My Drive/train_filelist.txt\" #@param {type:\"string\"}\n",
95 | "val_filelist = \"/content/drive/My Drive/val_filelist.txt\" #@param {type:\"string\"}\n",
96 | "output_dir = \"/content/drive/My Drive/talknet/name_of_character\" #@param {type:\"string\"}\n",
97 | "assert os.path.exists(dataset), \"Cannot find dataset\"\n",
98 | "assert os.path.exists(train_filelist), \"Cannot find training filelist\"\n",
99 | "assert os.path.exists(val_filelist), \"Cannot find validation filelist\"\n",
100 | "if not os.path.exists(output_dir):\n",
101 | " os.makedirs(output_dir)\n",
102 | "print(\"OK\")\n"
103 | ],
104 | "execution_count": null,
105 | "outputs": []
106 | },
107 | {
108 | "cell_type": "markdown",
109 | "source": [
110 | " **Note: The runtime will crash after step 4, which is intentional.** **Just re-run step 3 and proceed to step 5 and further.**"
111 | ],
112 | "metadata": {
113 | "id": "d02YMYEnqsGJ"
114 | }
115 | },
116 | {
117 | "cell_type": "code",
118 | "metadata": {
119 | "id": "teF-Ut8Z7Gjp",
120 | "cellView": "form"
121 | },
122 | "source": [
123 | "#@markdown **Step 4:** Download NVIDIA NeMo.\n",
124 | "!pip install pip==24.0\n",
125 | "!pip uninstall torch torchvision torchaudio -y\n",
126 | "!pip uninstall gdown -y\n",
127 | "!pip install git+https://github.com/wkentaro/gdown.git\n",
128 | "import os\n",
129 | "import time\n",
130 | "import gdown\n",
131 | "\n",
132 | "os.chdir('/content')\n",
133 | "!apt-get install sox libsndfile1 ffmpeg\n",
134 | "!pip install \"cython<3.0.0\" && pip install --no-build-isolation \"pyyaml<6.0\" && pip install --no-build-isolation \"pysptk==0.2.0\"\n",
135 | "!pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117\n",
136 | "!pip install wget unidecode tensorflow==2.9 tensorboardX pysptk frozendict torch_stft pytorch-lightning==1.3.8 kaldiio pydub pyannote.audio g2p_en pesq pystoi crepe ffmpeg-python\n",
137 | "!python -m pip install git+https://github.com/SortAnon/NeMo.git\n",
138 | "!git clone -q https://github.com/SortAnon/hifi-gan.git\n",
139 | "!pip install --pre torchtext==0.6.0 --no-deps --quiet\n",
140 | "!pip install -q numpy --upgrade torchmetrics==0.6 omegaconf==2.2.3 hmmlearn==0.2.7 crepe==0.0.12 tensorboard==2.9 librosa==0.9.1 protobuf==3.20.0 torch_stft==0.1.4\n",
141 | "\n",
142 | "\n",
143 | "!mkdir -p conf && cd conf \\\n",
144 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-durs.yaml \\\n",
145 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-pitch.yaml \\\n",
146 | "&& wget https://raw.githubusercontent.com/SortAnon/NeMo/main/examples/tts/conf/talknet-spect.yaml \\\n",
147 | "&& cd ..\n",
148 | "\n",
149 | "# Download pre-trained models\n",
150 | "zip_path = \"tts_en_talknet_1.0.0rc1.zip\"\n",
151 | "for i in range(10):\n",
152 | " if not os.path.exists(zip_path) or os.stat(zip_path).st_size < 100:\n",
153 | " gdown.download(\n",
154 | " \"https://drive.google.com/uc?id=19wSym9mNEnmzLS9XdPlfNAW9_u-mP1hR\",\n",
155 | " zip_path,\n",
156 | " quiet=False,\n",
157 | " )\n",
158 | "!unzip -qo {zip_path}\n",
159 | "\n",
160 | "# restart the runtime\n",
161 | "os.kill(os.getpid(), 9)"
162 | ],
163 | "execution_count": null,
164 | "outputs": []
165 | },
166 | {
167 | "cell_type": "code",
168 | "metadata": {
169 | "id": "bxFr3Fdi_kOC",
170 | "cellView": "form"
171 | },
172 | "source": [
173 | "#@markdown **Step 5:** Dataset processing, part 1.\n",
174 | "\n",
175 | "#@markdown If this step fails, try the following:\n",
176 | "#@markdown * Make sure your filelists are correct. They should have relative\n",
177 | "#@markdown paths that match the contents of the archive.\n",
178 | "\n",
179 | "import os\n",
180 | "import shutil\n",
181 | "import sys\n",
182 | "import json\n",
183 | "import nemo\n",
184 | "import torch\n",
185 | "import torchaudio\n",
186 | "import numpy as np\n",
187 | "from pysptk import sptk\n",
188 | "from pathlib import Path\n",
189 | "from tqdm.notebook import tqdm\n",
190 | "import ffmpeg\n",
191 | "\n",
192 | "def fix_transcripts(inpath):\n",
193 | " found_arpabet = False\n",
194 | " found_grapheme = False\n",
195 | " with open(inpath, \"r\", encoding=\"utf8\") as f:\n",
196 | " lines = f.readlines()\n",
197 | " with open(inpath, \"w\", encoding=\"utf8\") as f:\n",
198 | " for l in lines:\n",
199 | " if l.strip() == \"\":\n",
200 | " continue\n",
201 | " if \"{\" in l:\n",
202 | " if not found_arpabet:\n",
203 | " print(\"Warning: Skipping ARPABET lines (not supported).\")\n",
204 | " found_arpabet = True\n",
205 | " else:\n",
206 | " f.write(l)\n",
207 | " found_grapheme = True\n",
208 | " assert found_grapheme, \"No non-ARPABET lines found in \" + inpath\n",
209 | "\n",
210 | "def generate_json(inpath, outpath):\n",
211 | " output = \"\"\n",
212 | " sample_rate = 22050\n",
213 | " with open(inpath, \"r\", encoding=\"utf8\") as f:\n",
214 | " for l in f.readlines():\n",
215 | " lpath = l.split(\"|\")[0].strip()\n",
216 | " if lpath[:5] != \"wavs/\":\n",
217 | " lpath = \"wavs/\" + lpath\n",
218 | " size = os.stat(\n",
219 | " os.path.join(os.path.dirname(inpath), lpath)\n",
220 | " ).st_size\n",
221 | " x = {\n",
222 | " \"audio_filepath\": lpath,\n",
223 | " \"duration\": size / (sample_rate * 2),\n",
224 | " \"text\": l.split(\"|\")[1].strip(),\n",
225 | " }\n",
226 | " output += json.dumps(x) + \"\\n\"\n",
227 | " with open(outpath, \"w\", encoding=\"utf8\") as w:\n",
228 | " w.write(output)\n",
229 | "\n",
230 | "def convert_to_22k(inpath):\n",
231 | " if inpath.strip()[-4:].lower() != \".wav\":\n",
232 | " print(\"Warning: \" + inpath.strip() + \" is not a .wav file!\")\n",
233 | " return\n",
234 | " ffmpeg.input(inpath).output(\n",
235 | " inpath + \"_22k.wav\",\n",
236 | " ar=\"22050\",\n",
237 | " ac=\"1\",\n",
238 | " acodec=\"pcm_s16le\",\n",
239 | " map_metadata=\"-1\",\n",
240 | " fflags=\"+bitexact\",\n",
241 | " ).overwrite_output().run(quiet=True)\n",
242 | " os.remove(inpath)\n",
243 | " os.rename(inpath + \"_22k.wav\", inpath)\n",
244 | "\n",
245 | "# Extract dataset\n",
246 | "os.chdir('/content')\n",
247 | "if os.path.exists(\"/content/wavs\"):\n",
248 | " shutil.rmtree(\"/content/wavs\")\n",
249 | "os.mkdir(\"wavs\")\n",
250 | "os.chdir(\"wavs\")\n",
251 | "if dataset[-4:] == \".zip\":\n",
252 | " !unzip -q \"{dataset}\"\n",
253 | "elif dataset[-4:] == \".tar\":\n",
254 | " !tar -xf \"{dataset}\"\n",
255 | "else:\n",
256 | " raise Exception(\"Unknown extension for dataset\")\n",
257 | "if os.path.exists(\"/content/wavs/wavs\"):\n",
258 | " shutil.move(\"/content/wavs/wavs\", \"/content/tempwavs\")\n",
259 | " shutil.rmtree(\"/content/wavs\")\n",
260 | " shutil.move(\"/content/tempwavs\", \"/content/wavs\")\n",
261 | "\n",
262 | "# Filelist for preprocessing\n",
263 | "os.chdir('/content')\n",
264 | "shutil.copy(train_filelist, \"trainfiles.txt\")\n",
265 | "shutil.copy(val_filelist, \"valfiles.txt\")\n",
266 | "fix_transcripts(\"trainfiles.txt\")\n",
267 | "fix_transcripts(\"valfiles.txt\")\n",
268 | "seen_files = []\n",
269 | "with open(\"trainfiles.txt\", encoding=\"utf-8\") as f:\n",
270 | " t = f.read().split(\"\\n\")\n",
271 | "with open(\"valfiles.txt\", encoding=\"utf-8\") as f:\n",
272 | " v = f.read().split(\"\\n\")\n",
273 | " all_filelist = t[:] + v[:]\n",
274 | "with open(\"/content/allfiles.txt\", \"w\", encoding=\"utf-8\") as f:\n",
275 | " for x in all_filelist:\n",
276 | " if x.strip() == \"\":\n",
277 | " continue\n",
278 | " if x.split(\"|\")[0] not in seen_files:\n",
279 | " seen_files.append(x.split(\"|\")[0])\n",
280 | " f.write(x.strip() + \"\\n\")\n",
281 | "\n",
282 | "# Ensure audio is 22k\n",
283 | "print(\"Converting audio...\")\n",
284 | "for r, _, f in os.walk(\"/content/wavs\"):\n",
285 | " for name in tqdm(f):\n",
286 | " convert_to_22k(os.path.join(r, name))\n",
287 | "\n",
288 | "# Convert to JSON\n",
289 | "generate_json(\"trainfiles.txt\", \"trainfiles.json\")\n",
290 | "generate_json(\"valfiles.txt\", \"valfiles.json\")\n",
291 | "generate_json(\"allfiles.txt\", \"allfiles.json\")\n",
292 | "\n",
293 | "print(\"OK\")"
294 | ],
295 | "execution_count": null,
296 | "outputs": []
297 | },
298 | {
299 | "cell_type": "code",
300 | "metadata": {
301 | "id": "sos9vsxPkIN7",
302 | "cellView": "form"
303 | },
304 | "source": [
305 | "#@markdown **Step 6:** Dataset processing, part 2. This takes a while, but\n",
306 | "#@markdown you only have to run this once per dataset (results are saved to Drive).\n",
307 | "\n",
308 | "#@markdown If this step fails, try the following:\n",
309 | "#@markdown * Make sure your dataset only contains WAV files.\n",
310 | "\n",
311 | "# Extract phoneme duration\n",
312 | "\n",
313 | "import nltk\n",
314 | "nltk.download('averaged_perceptron_tagger_eng')\n",
315 | "\n",
316 | "import json\n",
317 | "from nemo.collections.asr.models import EncDecCTCModel\n",
318 | "asr_model = EncDecCTCModel.from_pretrained(model_name=\"asr_talknet_aligner\").cpu().eval()\n",
319 | "\n",
320 | "def forward_extractor(tokens, log_probs, blank):\n",
321 | " \"\"\"Computes states f and p.\"\"\"\n",
322 | " n, m = len(tokens), log_probs.shape[0]\n",
323 | " # `f[s, t]` -- max sum of log probs for `s` first codes\n",
324 | " # with `t` first timesteps with ending in `tokens[s]`.\n",
325 | " f = np.empty((n + 1, m + 1), dtype=float)\n",
326 | " f.fill(-(10 ** 9))\n",
327 | " p = np.empty((n + 1, m + 1), dtype=int)\n",
328 | " f[0, 0] = 0.0 # Start\n",
329 | " for s in range(1, n + 1):\n",
330 | " c = tokens[s - 1]\n",
331 | " for t in range((s + 1) // 2, m + 1):\n",
332 | " f[s, t] = log_probs[t - 1, c]\n",
333 | " # Option #1: prev char is equal to current one.\n",
334 | " if s == 1 or c == blank or c == tokens[s - 3]:\n",
335 | " options = f[s : (s - 2 if s > 1 else None) : -1, t - 1]\n",
336 | " else: # Is not equal to current one.\n",
337 | " options = f[s : (s - 3 if s > 2 else None) : -1, t - 1]\n",
338 | " f[s, t] += np.max(options)\n",
339 | " p[s, t] = np.argmax(options)\n",
340 | " return f, p\n",
341 | "\n",
342 | "\n",
343 | "def backward_extractor(f, p):\n",
344 | " \"\"\"Computes durs from f and p.\"\"\"\n",
345 | " n, m = f.shape\n",
346 | " n -= 1\n",
347 | " m -= 1\n",
348 | " durs = np.zeros(n, dtype=int)\n",
349 | " if f[-1, -1] >= f[-2, -1]:\n",
350 | " s, t = n, m\n",
351 | " else:\n",
352 | " s, t = n - 1, m\n",
353 | " while s > 0:\n",
354 | " durs[s - 1] += 1\n",
355 | " s -= p[s, t]\n",
356 | " t -= 1\n",
357 | " assert durs.shape[0] == n\n",
358 | " assert np.sum(durs) == m\n",
359 | " assert np.all(durs[1::2] > 0)\n",
360 | " return durs\n",
361 | "\n",
362 | "def preprocess_tokens(tokens, blank):\n",
363 | " new_tokens = [blank]\n",
364 | " for c in tokens:\n",
365 | " new_tokens.extend([c, blank])\n",
366 | " tokens = new_tokens\n",
367 | " return tokens\n",
368 | "\n",
369 | "data_config = {\n",
370 | " 'manifest_filepath': \"allfiles.json\",\n",
371 | " 'sample_rate': 22050,\n",
372 | " 'labels': asr_model.decoder.vocabulary,\n",
373 | " 'batch_size': 1,\n",
374 | "}\n",
375 | "\n",
376 | "parser = nemo.collections.asr.data.audio_to_text.AudioToCharWithDursF0Dataset.make_vocab(\n",
377 | " notation='phonemes', punct=True, spaces=True, stresses=False, add_blank_at=\"last\"\n",
378 | ")\n",
379 | "\n",
380 | "dataset = nemo.collections.asr.data.audio_to_text._AudioTextDataset(\n",
381 | " manifest_filepath=data_config['manifest_filepath'], sample_rate=data_config['sample_rate'], parser=parser,\n",
382 | ")\n",
383 | "\n",
384 | "dl = torch.utils.data.DataLoader(\n",
385 | " dataset=dataset, batch_size=data_config['batch_size'], collate_fn=dataset.collate_fn, shuffle=False,\n",
386 | ")\n",
387 | "\n",
388 | "blank_id = asr_model.decoder.num_classes_with_blank - 1\n",
389 | "\n",
390 | "if os.path.exists(os.path.join(output_dir, \"durations.pt\")):\n",
391 | " print(\"durations.pt already exists; skipping\")\n",
392 | "else:\n",
393 | " dur_data = {}\n",
394 | " for sample_idx, test_sample in tqdm(enumerate(dl), total=len(dl)):\n",
395 | " log_probs, _, greedy_predictions = asr_model(\n",
396 | " input_signal=test_sample[0], input_signal_length=test_sample[1]\n",
397 | " )\n",
398 | "\n",
399 | " log_probs = log_probs[0].cpu().detach().numpy()\n",
400 | " seq_ids = test_sample[2][0].cpu().detach().numpy()\n",
401 | "\n",
402 | " target_tokens = preprocess_tokens(seq_ids, blank_id)\n",
403 | "\n",
404 | " f, p = forward_extractor(target_tokens, log_probs, blank_id)\n",
405 | " durs = backward_extractor(f, p)\n",
406 | "\n",
407 | " dur_key = Path(dl.dataset.collection[sample_idx].audio_file).stem\n",
408 | " dur_data[dur_key] = {\n",
409 | " 'blanks': torch.tensor(durs[::2], dtype=torch.long).cpu().detach(),\n",
410 | " 'tokens': torch.tensor(durs[1::2], dtype=torch.long).cpu().detach()\n",
411 | " }\n",
412 | "\n",
413 | " del test_sample\n",
414 | "\n",
415 | " torch.save(dur_data, os.path.join(output_dir, \"durations.pt\"))\n",
416 | "\n",
417 | "#Extract F0 (pitch)\n",
418 | "import crepe\n",
419 | "from scipy.io import wavfile\n",
420 | "\n",
421 | "def crepe_f0(audio_file, hop_length=256):\n",
422 | " sr, audio = wavfile.read(audio_file)\n",
423 | " audio_x = np.arange(0, len(audio)) / 22050.0\n",
424 | " time, frequency, confidence, activation = crepe.predict(audio, sr, viterbi=True)\n",
425 | "\n",
426 | " x = np.arange(0, len(audio), hop_length) / 22050.0\n",
427 | " freq_interp = np.interp(x, time, frequency)\n",
428 | " conf_interp = np.interp(x, time, confidence)\n",
429 | " audio_interp = np.interp(x, audio_x, np.absolute(audio)) / 32768.0\n",
430 | " weights = [0.5, 0.25, 0.25]\n",
431 | " audio_smooth = np.convolve(audio_interp, np.array(weights)[::-1], \"same\")\n",
432 | "\n",
433 | " conf_threshold = 0.25\n",
434 | " audio_threshold = 0.0005\n",
435 | " for i in range(len(freq_interp)):\n",
436 | " if conf_interp[i] < conf_threshold:\n",
437 | " freq_interp[i] = 0.0\n",
438 | " if audio_smooth[i] < audio_threshold:\n",
439 | " freq_interp[i] = 0.0\n",
440 | "\n",
441 | " # Hack to make f0 and mel lengths equal\n",
442 | " if len(audio) % hop_length == 0:\n",
443 | " freq_interp = np.pad(freq_interp, pad_width=[0, 1])\n",
444 | " return torch.from_numpy(freq_interp.astype(np.float32))\n",
445 | "\n",
446 | "if os.path.exists(os.path.join(output_dir, \"f0s.pt\")):\n",
447 | " print(\"f0s.pt already exists; skipping\")\n",
448 | "else:\n",
449 | " f0_data = {}\n",
450 | " with open(\"allfiles.json\") as f:\n",
451 | " for i, l in enumerate(f.readlines()):\n",
452 | " print(str(i))\n",
453 | " audio_path = json.loads(l)[\"audio_filepath\"]\n",
454 | " f0_data[Path(audio_path).stem] = crepe_f0(audio_path)\n",
455 | "\n",
456 | " # calculate f0 stats (mean & std) only for train set\n",
457 | " with open(\"trainfiles.json\") as f:\n",
458 | " train_ids = {Path(json.loads(l)[\"audio_filepath\"]).stem for l in f}\n",
459 | " all_f0 = torch.cat([f0[f0 >= 1e-5] for f0_id, f0 in f0_data.items() if f0_id in train_ids])\n",
460 | "\n",
461 | " F0_MEAN, F0_STD = all_f0.mean().item(), all_f0.std().item()\n",
462 | " print(\"F0_MEAN: \" + str(F0_MEAN) + \", F0_STD: \" + str(F0_STD))\n",
463 | " torch.save(f0_data, os.path.join(output_dir, \"f0s.pt\"))\n",
464 | " with open(os.path.join(output_dir, \"f0_info.json\"), \"w\") as f:\n",
465 | " f.write(json.dumps({\"FO_MEAN\": F0_MEAN, \"F0_STD\": F0_STD}))"
466 | ],
467 | "execution_count": null,
468 | "outputs": []
469 | },
470 | {
471 | "cell_type": "code",
472 | "metadata": {
473 | "id": "nM7-bMpKO7U2",
474 | "cellView": "form"
475 | },
476 | "source": [
477 | "#@markdown **Step 7:** Train duration predictor.\n",
478 | "\n",
479 | "#@markdown If CUDA runs out of memory, try the following:\n",
480 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
481 | "#@markdown * If that doesn't help, reduce the batch size (default 64).\n",
482 | "batch_size = 64 #@param {type:\"integer\"}\n",
483 | "\n",
484 | "epochs = 20\n",
485 | "learning_rate = 1e-3\n",
486 | "min_learning_rate = 3e-6\n",
487 | "load_checkpoints = True\n",
488 | "\n",
489 | "import os\n",
490 | "from hydra.experimental import compose, initialize\n",
491 | "from hydra.core.global_hydra import GlobalHydra\n",
492 | "from omegaconf import OmegaConf\n",
493 | "import pytorch_lightning as pl\n",
494 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n",
495 | "from nemo.collections.tts.models import TalkNetDursModel\n",
496 | "from nemo.core.config import hydra_runner\n",
497 | "from nemo.utils.exp_manager import exp_manager\n",
498 | "\n",
499 | "def train(cfg):\n",
500 | " cfg.sample_rate = 22050\n",
501 | " cfg.train_dataset = \"trainfiles.json\"\n",
502 | " cfg.validation_datasets = \"valfiles.json\"\n",
503 | " cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n",
504 | " cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n",
505 | " cfg.trainer.accelerator = \"dp\"\n",
506 | " cfg.trainer.max_epochs = epochs\n",
507 | " cfg.trainer.check_val_every_n_epoch = 5\n",
508 | " cfg.model.train_ds.dataloader_params.batch_size = batch_size\n",
509 | " cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n",
510 | " cfg.model.optim.lr = learning_rate\n",
511 | " cfg.model.optim.sched.min_lr = min_learning_rate\n",
512 | " cfg.exp_manager.exp_dir = output_dir\n",
513 | "\n",
514 | " # Find checkpoints\n",
515 | " ckpt_path = \"\"\n",
516 | " if load_checkpoints:\n",
517 | " path0 = os.path.join(output_dir, \"TalkNetDurs\")\n",
518 | " if os.path.exists(path0):\n",
519 | " path1 = sorted(os.listdir(path0))\n",
520 | " for i in range(len(path1)):\n",
521 | " path2 = os.path.join(path0, path1[-(1+i)], \"checkpoints\")\n",
522 | " if os.path.exists(path2):\n",
523 | " match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
524 | " if len(match) > 0:\n",
525 | " ckpt_path = os.path.join(path2, match[0])\n",
526 | " print(\"Resuming training from \" + match[0])\n",
527 | " break\n",
528 | "\n",
529 | " if ckpt_path != \"\":\n",
530 | " trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
531 | " model = TalkNetDursModel(cfg=cfg.model, trainer=trainer)\n",
532 | " else:\n",
533 | " warmstart_path = \"/content/talknet_durs.nemo\"\n",
534 | " trainer = pl.Trainer(**cfg.trainer)\n",
535 | " model = TalkNetDursModel.restore_from(warmstart_path, override_config_path=cfg)\n",
536 | " model.set_trainer(trainer)\n",
537 | " model.setup_training_data(cfg.model.train_ds)\n",
538 | " model.setup_validation_data(cfg.model.validation_ds)\n",
539 | " model.setup_optimization(cfg.model.optim)\n",
540 | " print(\"Warm-starting from \" + warmstart_path)\n",
541 | " exp_manager(trainer, cfg.get('exp_manager', None))\n",
542 | " trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()]) # noqa\n",
543 | " trainer.fit(model)\n",
544 | "\n",
545 | "GlobalHydra().clear()\n",
546 | "initialize(config_path=\"conf\")\n",
547 | "cfg = compose(config_name=\"talknet-durs\")\n",
548 | "train(cfg)\n"
549 | ],
550 | "execution_count": null,
551 | "outputs": []
552 | },
553 | {
554 | "cell_type": "code",
555 | "metadata": {
556 | "id": "JLfm00NuJfon",
557 | "cellView": "form"
558 | },
559 | "source": [
560 | "#@markdown **Step 8:** Train pitch predictor.\n",
561 | "\n",
562 | "#@markdown If CUDA runs out of memory, try the following:\n",
563 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
564 | "#@markdown * If that doesn't help, reduce the batch size (default 64).\n",
565 | "batch_size = 64 #@param {type:\"integer\"}\n",
566 | "epochs = 50\n",
567 | "\n",
568 | "import json\n",
569 | "\n",
570 | "with open(os.path.join(output_dir, \"f0_info.json\"), \"r\") as f:\n",
571 | " f0_info = json.load(f)\n",
572 | " f0_mean = f0_info[\"FO_MEAN\"]\n",
573 | " f0_std = f0_info[\"F0_STD\"]\n",
574 | "\n",
575 | "learning_rate = 1e-3\n",
576 | "min_learning_rate = 3e-6\n",
577 | "load_checkpoints = True\n",
578 | "\n",
579 | "import os\n",
580 | "from hydra.experimental import compose, initialize\n",
581 | "from hydra.core.global_hydra import GlobalHydra\n",
582 | "from omegaconf import OmegaConf\n",
583 | "import pytorch_lightning as pl\n",
584 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n",
585 | "from nemo.collections.tts.models import TalkNetPitchModel\n",
586 | "from nemo.core.config import hydra_runner\n",
587 | "from nemo.utils.exp_manager import exp_manager\n",
588 | "\n",
589 | "def train(cfg):\n",
590 | " cfg.sample_rate = 22050\n",
591 | " cfg.train_dataset = \"trainfiles.json\"\n",
592 | " cfg.validation_datasets = \"valfiles.json\"\n",
593 | " cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n",
594 | " cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n",
595 | " cfg.trainer.accelerator = \"dp\"\n",
596 | " cfg.trainer.max_epochs = epochs\n",
597 | " cfg.trainer.check_val_every_n_epoch = 5\n",
598 | " cfg.model.f0_mean=f0_mean\n",
599 | " cfg.model.f0_std=f0_std\n",
600 | " cfg.model.train_ds.dataloader_params.batch_size = batch_size\n",
601 | " cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n",
602 | " cfg.model.optim.lr = learning_rate\n",
603 | " cfg.model.optim.sched.min_lr = min_learning_rate\n",
604 | " cfg.exp_manager.exp_dir = output_dir\n",
605 | "\n",
606 | " # Find checkpoints\n",
607 | " ckpt_path = \"\"\n",
608 | " if load_checkpoints:\n",
609 | " path0 = os.path.join(output_dir, \"TalkNetPitch\")\n",
610 | " if os.path.exists(path0):\n",
611 | " path1 = sorted(os.listdir(path0))\n",
612 | " for i in range(len(path1)):\n",
613 | " path2 = os.path.join(path0, path1[-(1+i)], \"checkpoints\")\n",
614 | " if os.path.exists(path2):\n",
615 | " match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
616 | " if len(match) > 0:\n",
617 | " ckpt_path = os.path.join(path2, match[0])\n",
618 | " print(\"Resuming training from \" + match[0])\n",
619 | " break\n",
620 | "\n",
621 | " if ckpt_path != \"\":\n",
622 | " trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
623 | " model = TalkNetPitchModel(cfg=cfg.model, trainer=trainer)\n",
624 | " else:\n",
625 | " warmstart_path = \"/content/talknet_pitch.nemo\"\n",
626 | " trainer = pl.Trainer(**cfg.trainer)\n",
627 | " model = TalkNetPitchModel.restore_from(warmstart_path, override_config_path=cfg)\n",
628 | " model.set_trainer(trainer)\n",
629 | " model.setup_training_data(cfg.model.train_ds)\n",
630 | " model.setup_validation_data(cfg.model.validation_ds)\n",
631 | " model.setup_optimization(cfg.model.optim)\n",
632 | " print(\"Warm-starting from \" + warmstart_path)\n",
633 | " exp_manager(trainer, cfg.get('exp_manager', None))\n",
634 | " trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()]) # noqa\n",
635 | " trainer.fit(model)\n",
636 | "\n",
637 | "GlobalHydra().clear()\n",
638 | "initialize(config_path=\"conf\")\n",
639 | "cfg = compose(config_name=\"talknet-pitch\")\n",
640 | "train(cfg)"
641 | ],
642 | "execution_count": null,
643 | "outputs": []
644 | },
645 | {
646 | "cell_type": "code",
647 | "metadata": {
648 | "id": "N9hh4WPHbCcn",
649 | "cellView": "form"
650 | },
651 | "source": [
652 | "#@markdown **Step 9:** Train spectrogram generator. 200+ epochs are recommended.\n",
653 | "\n",
654 | "#@markdown This is the slowest of the three models to train, and the hardest to\n",
655 | "#@markdown get good results from. If your character sounds noisy or robotic,\n",
656 | "#@markdown try improving the dataset, or adjusting the epochs and learning rate.\n",
657 | "\n",
658 | "epochs = 200 #@param {type:\"integer\"}\n",
659 | "\n",
660 | "#@markdown If CUDA runs out of memory, try the following:\n",
661 | "#@markdown * Click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
662 | "#@markdown * If that doesn't help, reduce the batch size (default 32).\n",
663 | "batch_size = 32 #@param {type:\"integer\"}\n",
664 | "\n",
665 | "#@markdown Advanced settings. You can probably leave these at their defaults (1e-3, 3e-6, empty, checked).\n",
666 | "learning_rate = 1e-3 #@param {type:\"number\"}\n",
667 | "min_learning_rate = 3e-6 #@param {type:\"number\"}\n",
668 | "pretrained_path = \"\" #@param {type:\"string\"}\n",
669 | "load_checkpoints = True #@param {type:\"boolean\"}\n",
670 | "\n",
671 | "import os\n",
672 | "from hydra.experimental import compose, initialize\n",
673 | "from hydra.core.global_hydra import GlobalHydra\n",
674 | "from omegaconf import OmegaConf\n",
675 | "import pytorch_lightning as pl\n",
676 | "from nemo.collections.common.callbacks import LogEpochTimeCallback\n",
677 | "from nemo.collections.tts.models import TalkNetSpectModel\n",
678 | "from nemo.core.config import hydra_runner\n",
679 | "from nemo.utils.exp_manager import exp_manager\n",
680 | "\n",
681 | "def train(cfg):\n",
682 | " cfg.sample_rate = 22050\n",
683 | " cfg.train_dataset = \"trainfiles.json\"\n",
684 | " cfg.validation_datasets = \"valfiles.json\"\n",
685 | " cfg.durs_file = os.path.join(output_dir, \"durations.pt\")\n",
686 | " cfg.f0_file = os.path.join(output_dir, \"f0s.pt\")\n",
687 | " cfg.trainer.accelerator = \"dp\"\n",
688 | " cfg.trainer.max_epochs = epochs\n",
689 | " cfg.trainer.check_val_every_n_epoch = 5\n",
690 | " cfg.model.train_ds.dataloader_params.batch_size = batch_size\n",
691 | " cfg.model.validation_ds.dataloader_params.batch_size = batch_size\n",
692 | " cfg.model.optim.lr = learning_rate\n",
693 | " cfg.model.optim.sched.min_lr = min_learning_rate\n",
694 | " cfg.exp_manager.exp_dir = output_dir\n",
695 | "\n",
696 | " # Find checkpoints\n",
697 | " ckpt_path = \"\"\n",
698 | " if load_checkpoints:\n",
699 | " path0 = os.path.join(output_dir, \"TalkNetSpect\")\n",
700 | " if os.path.exists(path0):\n",
701 | " path1 = sorted(os.listdir(path0))\n",
702 | " for i in range(len(path1)):\n",
703 | " path2 = os.path.join(path0, path1[-(1+i)], \"checkpoints\")\n",
704 | " if os.path.exists(path2):\n",
705 | " match = [x for x in os.listdir(path2) if \"last.ckpt\" in x]\n",
706 | " if len(match) > 0:\n",
707 | " ckpt_path = os.path.join(path2, match[0])\n",
708 | " print(\"Resuming training from \" + match[0])\n",
709 | " break\n",
710 | "\n",
711 | " if ckpt_path != \"\":\n",
712 | " trainer = pl.Trainer(**cfg.trainer, resume_from_checkpoint = ckpt_path)\n",
713 | " model = TalkNetSpectModel(cfg=cfg.model, trainer=trainer)\n",
714 | " else:\n",
715 | " if pretrained_path != \"\":\n",
716 | " warmstart_path = pretrained_path\n",
717 | " else:\n",
718 | " warmstart_path = \"/content/talknet_spect.nemo\"\n",
719 | " trainer = pl.Trainer(**cfg.trainer)\n",
720 | " model = TalkNetSpectModel.restore_from(warmstart_path, override_config_path=cfg)\n",
721 | " model.set_trainer(trainer)\n",
722 | " model.setup_training_data(cfg.model.train_ds)\n",
723 | " model.setup_validation_data(cfg.model.validation_ds)\n",
724 | " model.setup_optimization(cfg.model.optim)\n",
725 | " print(\"Warm-starting from \" + warmstart_path)\n",
726 | " exp_manager(trainer, cfg.get('exp_manager', None))\n",
727 | " trainer.callbacks.extend([pl.callbacks.LearningRateMonitor(), LogEpochTimeCallback()]) # noqa\n",
728 | " trainer.fit(model)\n",
729 | "\n",
730 | "GlobalHydra().clear()\n",
731 | "initialize(config_path=\"conf\")\n",
732 | "cfg = compose(config_name=\"talknet-spect\")\n",
733 | "train(cfg)"
734 | ],
735 | "execution_count": null,
736 | "outputs": []
737 | },
738 | {
739 | "cell_type": "code",
740 | "metadata": {
741 | "id": "Ajfyfz2p9Ior",
742 | "cellView": "form"
743 | },
744 | "source": [
745 | "#@markdown **Step 10:** Generate GTA spectrograms. This will help HiFi-GAN learn what your TalkNet model sounds like.\n",
746 | "\n",
747 | "#@markdown If this step fails, make sure you've finished training the spectrogram generator.\n",
748 | "\n",
749 | "import sys\n",
750 | "import os\n",
751 | "import torch\n",
752 | "import numpy as np\n",
753 | "from tqdm import tqdm\n",
754 | "from nemo.collections.tts.models import TalkNetSpectModel\n",
755 | "import shutil\n",
756 | "\n",
757 | "def fix_paths(inpath):\n",
758 | " output = \"\"\n",
759 | " with open(inpath, \"r\", encoding=\"utf8\") as f:\n",
760 | " for l in f.readlines():\n",
761 | " if l[:5].lower() != \"wavs/\":\n",
762 | " output += \"wavs/\" + l\n",
763 | " else:\n",
764 | " output += l\n",
765 | " with open(inpath, \"w\", encoding=\"utf8\") as w:\n",
766 | " w.write(output)\n",
767 | "\n",
768 | "shutil.copyfile(train_filelist, \"/content/hifi-gan/training.txt\")\n",
769 | "shutil.copyfile(val_filelist, \"/content/hifi-gan/validation.txt\")\n",
770 | "fix_paths(\"/content/hifi-gan/training.txt\")\n",
771 | "fix_paths(\"/content/hifi-gan/validation.txt\")\n",
772 | "fix_paths(\"/content/allfiles.txt\")\n",
773 | "\n",
774 | "os.chdir('/content')\n",
775 | "indir = \"wavs\"\n",
776 | "outdir = \"hifi-gan/wavs\"\n",
777 | "if not os.path.exists(outdir):\n",
778 | " os.mkdir(outdir)\n",
779 | "\n",
780 | "model_path = \"\"\n",
781 | "path0 = os.path.join(output_dir, \"TalkNetSpect\")\n",
782 | "if os.path.exists(path0):\n",
783 | " path1 = sorted(os.listdir(path0))\n",
784 | " for i in range(len(path1)):\n",
785 | " path2 = os.path.join(path0, path1[-(1+i)], \"checkpoints\")\n",
786 | " if os.path.exists(path2):\n",
787 | " match = [x for x in os.listdir(path2) if \"TalkNetSpect.nemo\" in x]\n",
788 | " if len(match) > 0:\n",
789 | " model_path = os.path.join(path2, match[0])\n",
790 | " break\n",
791 | "assert model_path != \"\", \"TalkNetSpect.nemo not found\"\n",
792 | "\n",
793 | "dur_path = os.path.join(output_dir, \"durations.pt\")\n",
794 | "f0_path = os.path.join(output_dir, \"f0s.pt\")\n",
795 | "\n",
796 | "model = TalkNetSpectModel.restore_from(model_path)\n",
797 | "model.eval()\n",
798 | "with open(\"allfiles.txt\", \"r\", encoding=\"utf-8\") as f:\n",
799 | " dataset = f.readlines()\n",
800 | "durs = torch.load(dur_path)\n",
801 | "f0s = torch.load(f0_path)\n",
802 | "\n",
803 | "for x in tqdm(dataset):\n",
804 | " x_name = os.path.splitext(os.path.basename(x.split(\"|\")[0].strip()))[0]\n",
805 | " x_tokens = model.parse(text=x.split(\"|\")[1].strip())\n",
806 | " x_durs = (\n",
807 | " torch.stack(\n",
808 | " (\n",
809 | " durs[x_name][\"blanks\"],\n",
810 | " torch.cat((durs[x_name][\"tokens\"], torch.zeros(1).int())),\n",
811 | " ),\n",
812 | " dim=1,\n",
813 | " )\n",
814 | " .view(-1)[:-1]\n",
815 | " .view(1, -1)\n",
816 | " .to(\"cuda:0\")\n",
817 | " )\n",
818 | " x_f0s = f0s[x_name].view(1, -1).to(\"cuda:0\")\n",
819 | " x_spect = model.force_spectrogram(tokens=x_tokens, durs=x_durs, f0=x_f0s)\n",
820 | " rel_path = os.path.splitext(x.split(\"|\")[0].strip())[0][5:]\n",
821 | " abs_dir = os.path.join(outdir, os.path.dirname(rel_path))\n",
822 | " if abs_dir != \"\" and not os.path.exists(abs_dir):\n",
823 | " os.makedirs(abs_dir, exist_ok=True)\n",
824 | " np.save(os.path.join(outdir, rel_path + \".npy\"), x_spect.detach().cpu().numpy())\n"
825 | ],
826 | "execution_count": null,
827 | "outputs": []
828 | },
829 | {
830 | "cell_type": "code",
831 | "metadata": {
832 | "id": "yVBjGhRB9hUJ",
833 | "cellView": "form"
834 | },
835 | "source": [
836 | "#@markdown **Step 11:** Train HiFi-GAN. 2,000+ steps are recommended.\n",
837 | "#@markdown Stop this cell to finish training the model.\n",
838 | "\n",
839 | "#@markdown If CUDA runs out of memory, click on Runtime -> Restart runtime, re-run step 3, and try again.\n",
840 | "#@markdown If this step still fails to start, make sure step 10 finished successfully.\n",
841 | "\n",
842 | "#@markdown Note: If the training process starts at step 2500000, delete the HiFiGAN folder and try again.\n",
843 | "\n",
844 | "import gdown\n",
845 | "d = 'https://drive.google.com/uc?id='\n",
846 | "\n",
847 | "os.chdir('/content/hifi-gan')\n",
848 | "assert os.path.exists(\"wavs\"), \"Spectrogram folder not found\"\n",
849 | "\n",
850 | "if not os.path.exists(os.path.join(output_dir, \"HiFiGAN\")):\n",
851 | " os.makedirs(os.path.join(output_dir, \"HiFiGAN\"))\n",
852 | "if not os.path.exists(os.path.join(output_dir, \"HiFiGAN\", \"do_00000000\")):\n",
853 | " print(\"Downloading universal model...\")\n",
854 | " gdown.download(d+\"1qpgI41wNXFcH-iKq1Y42JlBC9j0je8PW\", os.path.join(output_dir, \"HiFiGAN\", \"g_00000000\"), quiet=False)\n",
855 | " gdown.download(d+\"1O63eHZR9t1haCdRHQcEgMfMNxiOciSru\", os.path.join(output_dir, \"HiFiGAN\", \"do_00000000\"), quiet=False)\n",
856 | " start_from_universal = \"--warm_start True \"\n",
857 | "else:\n",
858 | " start_from_universal = \"\"\n",
859 | "\n",
860 | "!python train.py --fine_tuning True --config config_v1b.json \\\n",
861 | "{start_from_universal} \\\n",
862 | "--checkpoint_interval 250 --checkpoint_path \"{os.path.join(output_dir, 'HiFiGAN')}\" \\\n",
863 | "--input_training_file \"/content/hifi-gan/training.txt\" \\\n",
864 | "--input_validation_file \"/content/hifi-gan/validation.txt\" \\\n",
865 | "--input_wavs_dir \"..\" --input_mels_dir \"wavs\"\n"
866 | ],
867 | "execution_count": null,
868 | "outputs": []
869 | },
870 | {
871 | "cell_type": "code",
872 | "metadata": {
873 | "id": "5OtwfrTT-blU",
874 | "cellView": "form"
875 | },
876 | "source": [
877 | "#@markdown **Step 12:** Package the models. They'll be saved to the output directory as [character_name]_TalkNet.zip.\n",
878 | "\n",
879 | "character_name = \"Character\" #@param {type:\"string\"}\n",
880 | "\n",
881 | "#@markdown When done, generate a Drive share link, with permissions set to \"Anyone with the link\".\n",
882 | "#@markdown You can then use it with the [Controllable TalkNet notebook](https://colab.research.google.com/drive/1aj6Jk8cpRw7SsN3JSYCv57CrR6s0gYPB)\n",
883 | "#@markdown by selecting \"Custom model\" as your character.\n",
884 | "\n",
885 | "#@markdown This cell will also move the training checkpoints and logs to the trash.\n",
886 | "#@markdown That should free up roughly 2 GB of space on your Drive (remember to empty your trash).\n",
887 | "#@markdown If you wish to keep them, uncheck this box.\n",
888 | "\n",
889 | "delete_checkpoints = True #@param {type:\"boolean\"}\n",
890 | "\n",
891 | "import shutil\n",
892 | "from zipfile import ZipFile\n",
893 | "\n",
894 | "def find_talknet(model_dir):\n",
895 | " ckpt_path = \"\"\n",
896 | " path0 = os.path.join(output_dir, model_dir)\n",
897 | " if os.path.exists(path0):\n",
898 | " path1 = sorted(os.listdir(path0))\n",
899 | " for i in range(len(path1)):\n",
900 | " path2 = os.path.join(path0, path1[-(1+i)], \"checkpoints\")\n",
901 | " if os.path.exists(path2):\n",
902 | " match = [x for x in os.listdir(path2) if \".nemo\" in x]\n",
903 | " if len(match) > 0:\n",
904 | " ckpt_path = os.path.join(path2, match[0])\n",
905 | " break\n",
906 | " assert ckpt_path != \"\", \"Couldn't find \" + model_dir\n",
907 | " return ckpt_path\n",
908 | "\n",
909 | "durs_path = find_talknet(\"TalkNetDurs\")\n",
910 | "pitch_path = find_talknet(\"TalkNetPitch\")\n",
911 | "spect_path = find_talknet(\"TalkNetSpect\")\n",
912 | "assert os.path.exists(os.path.join(output_dir, \"HiFiGAN\", \"g_00000000\")), \"Couldn't find HiFi-GAN\"\n",
913 | "\n",
914 | "zip = ZipFile(os.path.join(output_dir, character_name + \"_TalkNet.zip\"), 'w')\n",
915 | "zip.write(durs_path, \"TalkNetDurs.nemo\")\n",
916 | "zip.write(pitch_path, \"TalkNetPitch.nemo\")\n",
917 | "zip.write(spect_path, \"TalkNetSpect.nemo\")\n",
918 | "zip.write(os.path.join(output_dir, \"HiFiGAN\", \"g_00000000\"), \"hifiganmodel\")\n",
919 | "zip.write(os.path.join(output_dir, \"HiFiGAN\", \"config.json\"), \"config.json\")\n",
920 | "zip.write(os.path.join(output_dir, \"f0_info.json\"), \"f0_info.json\")\n",
921 | "zip.close()\n",
922 | "print(\"Archived model to \" + os.path.join(output_dir, character_name + \"_TalkNet.zip\"))\n",
923 | "\n",
924 | "if delete_checkpoints:\n",
925 | " shutil.rmtree((os.path.join(output_dir, \"TalkNetDurs\")))\n",
926 | " shutil.rmtree((os.path.join(output_dir, \"TalkNetPitch\")))\n",
927 | " shutil.rmtree((os.path.join(output_dir, \"TalkNetSpect\")))\n",
928 | " shutil.rmtree((os.path.join(output_dir, \"HiFiGAN\")))\n"
929 | ],
930 | "execution_count": null,
931 | "outputs": []
932 | }
933 | ]
934 | }
--------------------------------------------------------------------------------