├── content-encoder.png
├── hubert
│   ├── __init__.py
│   ├── utils.py
│   ├── dataset.py
│   └── model.py
├── LICENSE
├── encode.py
├── cluster.py
├── .gitignore
├── hubconf.py
├── Untitled.ipynb
├── README.md
└── train.py
/content-encoder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bshall/hubert/HEAD/content-encoder.png -------------------------------------------------------------------------------- /hubert/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ( 2 | Hubert, 3 | HubertDiscrete, 4 | HubertSoft, 5 | ) 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Benjamin van Niekerk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /hubert/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Metric: 5 | def __init__(self): 6 | self.steps = 0 7 | self.value = 0 8 | 9 | def update(self, value): 10 | self.steps += 1 11 | self.value += (value - self.value) / self.steps  # incremental running mean: m_n = m_{n-1} + (x_n - m_{n-1}) / n 12 | return self.value 13 | 14 | def reset(self): 15 | self.steps = 0 16 | self.value = 0 17 | 18 | 19 | def save_checkpoint( 20 | checkpoint_dir, 21 | hubert, 22 | optimizer, 23 | scaler, 24 | step, 25 | loss, 26 | best, 27 | logger, 28 | ): 29 | state = { 30 | "hubert": hubert.state_dict(), 31 | "optimizer": optimizer.state_dict(), 32 | "scaler": scaler.state_dict(), 33 | "step": step, 34 | "loss": loss, 35 | } 36 | checkpoint_dir.mkdir(exist_ok=True, parents=True) 37 | checkpoint_path = checkpoint_dir / f"model-{step}.pt" 38 | torch.save(state, checkpoint_path) 39 | if best: 40 | best_path = checkpoint_dir / "model-best.pt" 41 | torch.save(state, best_path) 42 | logger.info(f"Saved checkpoint: {checkpoint_path.stem}") 43 | 44 | 45 | def load_checkpoint( 46 | load_path, 47 | hubert, 48 | optimizer, 49 | scaler, 50 | rank, 51 | logger, 52 | ): 53 | logger.info(f"Loading checkpoint from {load_path}") 54 | checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"}) 55 | hubert.load_state_dict(checkpoint["hubert"]) 56 | if "scaler" in checkpoint: 57 | scaler.load_state_dict(checkpoint["scaler"]) 58 | if "optimizer" in checkpoint: 59 | optimizer.load_state_dict(checkpoint["optimizer"]) 60 | step, loss = checkpoint.get("step", 0), checkpoint.get("loss", float("inf")) 61 | return step, loss 62 | -------------------------------------------------------------------------------- /encode.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import numpy as np 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torchaudio 9 | from torchaudio.functional import resample 10 | 11 | 12 | def encode_dataset(args): 13 | print("Loading HuBERT checkpoint") 14 | hubert = torch.hub.load( 15 | "bshall/hubert:main", 16 | f"hubert_{args.model}", 17 | trust_repo=True, 18 | ).cuda() 19 | 20 | print(f"Encoding dataset at {args.in_dir}") 21 | for in_path in tqdm(list(args.in_dir.rglob(f"*{args.extension}"))): 22 | wav, sr = torchaudio.load(in_path) 23 | wav = resample(wav, sr, 16000) 24 | wav = wav.unsqueeze(0).cuda() 25 | 26 | with torch.inference_mode(): 27 | units = hubert.units(wav) 28 | 29 | out_path = args.out_dir / in_path.relative_to(args.in_dir) 30 | out_path.parent.mkdir(parents=True, exist_ok=True) 31 | np.save(out_path.with_suffix(".npy"), units.squeeze().cpu().numpy()) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser(description="Encode an audio dataset.") 36 | parser.add_argument( 37 | "model", 38 | help="available models (HuBERT-Soft or HuBERT-Discrete)", 39 | choices=["soft", "discrete"], 40 | ) 41 | parser.add_argument( 42 | "in_dir", 43 | metavar="in-dir", 44 | help="path to the dataset directory.", 45 | type=Path, 46 | ) 47 | parser.add_argument( 48 | "out_dir", 49 | metavar="out-dir", 50 | help="path to the output directory.", 51 | type=Path, 52 | ) 53 | parser.add_argument( 54 | "--extension", 55 | help="extension of the audio files (defaults to .flac).", 56 | default=".flac", 57 | type=str, 58 | ) 59 | args = parser.parse_args() 60 | encode_dataset(args) 61 | --------------------------------------------------------------------------------
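A note on what `encode.py` writes out: after the `.squeeze()`, HuBERT-Soft yields a float32 array of shape (N, 256) and HuBERT-Discrete an int64 array of shape (N,), where N is the number of roughly 50 Hz frames. A minimal sketch of reading the saved units back (the path is hypothetical):

```python
import numpy as np

# load one encoded utterance written by encode.py (hypothetical path)
units = np.load("path/to/out-dir/84/121123/84-121123-0000.npy")

# HuBERT-Soft: float32, shape (N, 256); HuBERT-Discrete: int64, shape (N,)
print(units.shape, units.dtype)
```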
/cluster.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import logging 3 | import argparse 4 | 5 | import torch 6 | import numpy as np 7 | from sklearn.cluster import KMeans 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def cluster(args): 14 | with open(args.subset) as file: 15 | subset = [line.strip() for line in file] 16 | 17 | logger.info(f"Loading features from {args.in_dir}") 18 | features = [] 19 | for path in subset: 20 | in_path = args.in_dir / path 21 | features.append(np.load(in_path.with_suffix(".npy"))) 22 | features = np.concatenate(features, axis=0) 23 | 24 | logger.info(f"Clustering features of shape: {features.shape}") 25 | kmeans = KMeans(n_clusters=args.n_clusters).fit(features) 26 | 27 | checkpoint_path = args.checkpoint_dir / f"kmeans_{args.n_clusters}.pt" 28 | checkpoint_path.parent.mkdir(exist_ok=True, parents=True) 29 | torch.save(  # torch.save takes the object first, then the path 30 | { 31 | "n_features_in_": kmeans.n_features_in_, 32 | "_n_threads": kmeans._n_threads, 33 | "cluster_centers_": torch.from_numpy(kmeans.cluster_centers_),  # stored as a tensor; hubconf calls .numpy() on it when loading 34 | }, 35 | checkpoint_path, 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | parser = argparse.ArgumentParser(description="Cluster speech features.") 41 | parser.add_argument( 42 | "in_dir", 43 | metavar="in-dir", 44 | help="path to the encoded dataset", 45 | type=Path, 46 | ) 47 | parser.add_argument( 48 | "subset", 49 | metavar="subset", 50 | help="path to the .txt file containing the list of files to cluster", 51 | type=Path, 52 | ) 53 | parser.add_argument( 54 | "checkpoint_dir", 55 | metavar="checkpoint-dir", 56 | help="path to the checkpoint directory", 57 | type=Path, 58 | ) 59 | parser.add_argument( 60 | "--n-clusters", 61 | help="number of clusters", 62 | type=int, 63 | default=100, 64 | ) 65 | args = parser.parse_args() 66 | cluster(args) 67 | --------------------------------------------------------------------------------
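The README does not document `cluster.py`; going by its argument parser alone, an invocation looks like the sketch below. The paths and the subset file are hypothetical, and each line of the subset file is a path, relative to `in-dir`, whose `.npy` feature file gets loaded:

```
python cluster.py path/to/features path/to/subset.txt path/to/checkpoints --n-clusters 100
```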
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # VSCode project settings 114 | .vscode 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /hubert/dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | import numpy as np 4 | import json 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.utils.data import Dataset 9 | import torchaudio 10 | 11 | 12 | class AcousticUnitsDataset(Dataset): 13 | def __init__( 14 | self, 15 | root: Path, 16 | sample_rate: int = 16000, 17 | label_rate: int = 50, 18 | min_samples: int = 32000, 19 | max_samples: int = 250000, 20 | train: bool = True, 21 | ): 22 | self.wavs_dir = root / "wavs" 23 | self.units_dir = root / "discrete" 24 | 25 | with open(root / "lengths.json") as file: 26 | self.lengths = json.load(file) 27 | 28 | pattern = "train-*/**/*.flac" if train else "dev-*/**/*.flac" 29 | metadata = ( 30 | (path, path.relative_to(self.wavs_dir).with_suffix("").as_posix()) 31 | for path in self.wavs_dir.rglob(pattern) 32 | ) 33 | metadata = ((path, key) for path, key in metadata if key in self.lengths) 34 | self.metadata = [ 35 | path for path, key in metadata if self.lengths[key] > min_samples 36 | ] 37 | 38 | self.sample_rate = sample_rate 39 | self.label_rate = label_rate 40 | self.min_samples = min_samples 41 | self.max_samples = max_samples 42 | self.train = train 43 | 44 | def __len__(self): 45 | return len(self.metadata) 46 | 47 | def __getitem__(self, index): 48 | wav_path = self.metadata[index] 49 | units_path = self.units_dir / wav_path.relative_to(self.wavs_dir) 50 | 51 | wav, _ = torchaudio.load(wav_path) 52 | wav = F.pad(wav, ((400 - 320) // 
2)) 53 | codes = np.load(units_path.with_suffix(".npy")) 54 | 55 | return wav, torch.from_numpy(codes).long() 56 | 57 | def collate(self, batch): 58 | wavs, codes = zip(*batch) 59 | wavs, codes = list(wavs), list(codes) 60 | 61 | wav_lengths = [wav.size(-1) for wav in wavs] 62 | code_lengths = [code.size(-1) for code in codes] 63 | 64 | wav_frames = min(self.max_samples, *wav_lengths) 65 | 66 | collated_wavs, wav_offsets = [], [] 67 | for wav in wavs: 68 | wav_diff = wav.size(-1) - wav_frames 69 | wav_offset = random.randint(0, wav_diff) 70 | wav = wav[:, wav_offset : wav_offset + wav_frames] 71 | 72 | collated_wavs.append(wav) 73 | wav_offsets.append(wav_offset) 74 | 75 | rate = self.label_rate / self.sample_rate 76 | code_offsets = [round(wav_offset * rate) for wav_offset in wav_offsets] 77 | code_frames = round(wav_frames * rate) 78 | remaining_code_frames = [ 79 | length - offset for length, offset in zip(code_lengths, code_offsets) 80 | ] 81 | code_frames = min(code_frames, *remaining_code_frames) 82 | 83 | collated_codes = [] 84 | for code, code_offset in zip(codes, code_offsets): 85 | code = code[code_offset : code_offset + code_frames] 86 | collated_codes.append(code) 87 | 88 | wavs = torch.stack(collated_wavs, dim=0) 89 | codes = torch.stack(collated_codes, dim=0) 90 | 91 | return wavs, codes 92 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch", "torchaudio", "sklearn"] 2 | 3 | URLS = { 4 | "hubert-discrete": "https://github.com/bshall/hubert/releases/download/v0.2/hubert-discrete-96b248c5.pt", 5 | "hubert-soft": "https://github.com/bshall/hubert/releases/download/v0.2/hubert-soft-35d9f29f.pt", 6 | "kmeans100": "https://github.com/bshall/hubert/releases/download/v0.2/kmeans100-50f36a95.pt", 7 | } 8 | 9 | import torch 10 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 11 | 12 | from sklearn.cluster import KMeans 13 | 14 | from hubert import Hubert, HubertDiscrete, HubertSoft 15 | 16 | 17 | def hubert() -> Hubert: 18 | r"""Randomly initialized HuBERT from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`""" 19 | hubert = Hubert(504) 20 | hubert.eval() 21 | return hubert 22 | 23 | 24 | def hubert_discrete( 25 | pretrained: bool = True, 26 | progress: bool = True, 27 | ) -> HubertDiscrete: 28 | r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 29 | Args: 30 | pretrained (bool): load pretrained weights into the model 31 | progress (bool): show progress bar when downloading model 32 | """ 33 | kmeans = kmeans100(pretrained=pretrained, progress=progress) 34 | hubert = HubertDiscrete(kmeans) 35 | if pretrained: 36 | checkpoint = torch.hub.load_state_dict_from_url( 37 | URLS["hubert-discrete"], progress=progress 38 | ) 39 | consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.") 40 | hubert.load_state_dict(checkpoint["hubert"]) 41 | hubert.eval() 42 | return hubert 43 | 44 | 45 | def hubert_soft( 46 | pretrained: bool = True, 47 | progress: bool = True, 48 | ) -> HubertSoft: 49 | r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 50 | Args: 51 | pretrained (bool): load pretrained weights into the model. 52 | progress (bool): show progress bar when downloading model. 
53 | """ 54 | hubert = HubertSoft() 55 | if pretrained: 56 | checkpoint = torch.hub.load_state_dict_from_url( 57 | URLS["hubert-soft"], 58 | progress=progress, 59 | ) 60 | consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.") 61 | hubert.load_state_dict(checkpoint["hubert"]) 62 | hubert.eval() 63 | return hubert 64 | 65 | 66 | def _kmeans( 67 | num_clusters: int, pretrained: bool = True, progress: bool = True 68 | ) -> KMeans: 69 | kmeans = KMeans(num_clusters) 70 | if pretrained: 71 | checkpoint = torch.hub.load_state_dict_from_url( 72 | URLS[f"kmeans{num_clusters}"], progress=progress 73 | ) 74 | kmeans.__dict__["n_features_in_"] = checkpoint["n_features_in_"] 75 | kmeans.__dict__["_n_threads"] = checkpoint["_n_threads"] 76 | kmeans.__dict__["cluster_centers_"] = checkpoint["cluster_centers_"].numpy() 77 | return kmeans 78 | 79 | 80 | def kmeans100(pretrained: bool = True, progress: bool = True) -> KMeans: 81 | r""" 82 | k-means checkpoint for HuBERT-Discrete with 100 clusters. 83 | Args: 84 | pretrained (bool): load pretrained weights into the model 85 | progress (bool): show progress bar when downloading model 86 | """ 87 | return _kmeans(100, pretrained, progress) 88 | -------------------------------------------------------------------------------- /Untitled.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "1350f00d-5ed6-4218-8434-cfd715b4881d", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "from hubconf import hubert" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "d54c757a-f990-4a0a-80e7-adb54f40bdfe", 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "text/plain": [ 22 | "Hubert(\n", 23 | " (feature_extractor): FeatureExtractor(\n", 24 | " (conv0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)\n", 25 | " (norm0): GroupNorm(512, 512, eps=1e-05, affine=True)\n", 26 | " (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)\n", 27 | " (conv2): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)\n", 28 | " (conv3): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)\n", 29 | " (conv4): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)\n", 30 | " (conv5): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)\n", 31 | " (conv6): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)\n", 32 | " )\n", 33 | " (feature_projection): FeatureProjection(\n", 34 | " (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", 35 | " (projection): Linear(in_features=512, out_features=768, bias=True)\n", 36 | " (dropout): Dropout(p=0.1, inplace=False)\n", 37 | " )\n", 38 | " (positional_embedding): PositionalConvEmbedding(\n", 39 | " (conv): ParametrizedConv1d(\n", 40 | " 768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16\n", 41 | " (parametrizations): ModuleDict(\n", 42 | " (weight): ParametrizationList(\n", 43 | " (0): _WeightNorm()\n", 44 | " )\n", 45 | " )\n", 46 | " )\n", 47 | " )\n", 48 | " (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 49 | " (dropout): Dropout(p=0.1, inplace=False)\n", 50 | " (encoder): TransformerEncoder(\n", 51 | " (layers): ModuleList(\n", 52 | " (0-11): 12 x TransformerEncoderLayer(\n", 53 | " (self_attn): MultiheadAttention(\n", 54 | " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n", 55 | " )\n", 56 | " 
(linear1): Linear(in_features=768, out_features=3072, bias=True)\n", 57 | " (dropout): Dropout(p=0.1, inplace=False)\n", 58 | " (linear2): Linear(in_features=3072, out_features=768, bias=True)\n", 59 | " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 60 | " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n", 61 | " (dropout1): Dropout(p=0.1, inplace=False)\n", 62 | " (dropout2): Dropout(p=0.1, inplace=False)\n", 63 | " )\n", 64 | " )\n", 65 | " )\n", 66 | " (proj): Linear(in_features=768, out_features=256, bias=True)\n", 67 | " (label_embedding): Embedding(504, 256)\n", 68 | ")" 69 | ] 70 | }, 71 | "execution_count": 2, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "hubert()" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "ae32f397-9476-4a46-b3eb-7898ff73c799", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [] 87 | } 88 | ], 89 | "metadata": { 90 | "kernelspec": { 91 | "display_name": "Python 3 (ipykernel)", 92 | "language": "python", 93 | "name": "python3" 94 | }, 95 | "language_info": { 96 | "codemirror_mode": { 97 | "name": "ipython", 98 | "version": 3 99 | }, 100 | "file_extension": ".py", 101 | "mimetype": "text/x-python", 102 | "name": "python", 103 | "nbconvert_exporter": "python", 104 | "pygments_lexer": "ipython3", 105 | "version": "3.12.2" 106 | } 107 | }, 108 | "nbformat": 4, 109 | "nbformat_minor": 5 110 | } 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HuBERT 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2111.02392) 4 | [![demo](https://img.shields.io/static/v1?message=Audio%20Samples&logo=Github&labelColor=grey&color=blue&logoColor=white&label=%20&style=flat)](https://bshall.github.io/soft-vc/) 5 | [![colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bshall/soft-vc/blob/main/soft-vc-demo.ipynb) 6 | 7 | Training and inference scripts for the HuBERT content encoders in [A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion](https://ieeexplore.ieee.org/abstract/document/9746484). 8 | For more details see [soft-vc](https://github.com/bshall/soft-vc). Audio samples can be found [here](https://bshall.github.io/soft-vc/). Colab demo can be found [here](https://colab.research.google.com/github/bshall/soft-vc/blob/main/soft-vc-demo.ipynb). 9 | 10 |
11 | <div align="center"> 12 | <img width="100%" alt="Soft-VC" src="https://raw.githubusercontent.com/bshall/hubert/HEAD/content-encoder.png"> 13 | </div> 14 | 15 | 16 | Fig 1: Architecture of the voice conversion system. a) The discrete content encoder clusters audio features to produce a sequence of discrete speech units. b) The soft content encoder is trained to predict the discrete units. The acoustic model transforms the discrete/soft speech units into a target spectrogram. The vocoder converts the spectrogram into an audio waveform. 17 | 18 | 
19 | 20 | ## Example Usage 21 | 22 | ### Programmatic Usage 23 | 24 | ```python 25 | import torch, torchaudio 26 | 27 | # Load checkpoint (either hubert_soft or hubert_discrete) 28 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft", trust_repo=True).cuda() 29 | 30 | # Load audio 31 | wav, sr = torchaudio.load("path/to/wav") 32 | assert sr == 16000 33 | wav = wav.unsqueeze(0).cuda() 34 | 35 | # Extract speech units 36 | units = hubert.units(wav) 37 | ``` 38 | 39 | ### Script-Based Usage 40 | 41 | ``` 42 | usage: encode.py [-h] [--extension EXTENSION] {soft,discrete} in-dir out-dir 43 | 44 | Encode an audio dataset. 45 | 46 | positional arguments: 47 | {soft,discrete} available models (HuBERT-Soft or HuBERT-Discrete) 48 | in-dir path to the dataset directory. 49 | out-dir path to the output directory. 50 | 51 | optional arguments: 52 | -h, --help show this help message and exit 53 | --extension EXTENSION 54 | extension of the audio files (defaults to .flac). 55 | ``` 56 | 57 | ## Training 58 | 59 | ### Step 1: Dataset Preparation 60 | 61 | Download and extract the [LibriSpeech](https://www.openslr.org/12) corpus. The training script expects the following tree structure for the dataset directory: 62 | 63 | ``` 64 | │ lengths.json 65 | │ 66 | └───wavs 67 | ├───dev-* 68 | │ ├───84 69 | │ ├───... 70 | │ └───8842 71 | └───train-* 72 | ├───19 73 | ├───... 74 | └───8975 75 | ``` 76 | 77 | The `train-*` and `dev-*` directories should contain the training and validation splits respectively. Note that there can be multiple `train` and `dev` folders, e.g., `train-clean-100`, `train-other-500`, etc. Finally, the `lengths.json` file should contain key-value pairs mapping each file path (relative to `wavs`, without the extension) to its number of samples: 78 | 79 | ```json 80 | { 81 | "dev-clean/1272/128104/1272-128104-0000": 93680, 82 | "dev-clean/1272/128104/1272-128104-0001": 77040 83 | } 84 | ``` 85 | 
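LibriSpeech does not ship with `lengths.json`, so it has to be generated. A minimal sketch, assuming `torchaudio` is installed and the `wavs` layout above (the keys are extension-free paths relative to `wavs`, matching what `hubert/dataset.py` looks up; the values are sample counts):

```python
import json
from pathlib import Path

import torchaudio

root = Path("path/to/LibriSpeech")  # hypothetical dataset root
lengths = {}
for path in sorted((root / "wavs").rglob("*.flac")):
    key = path.relative_to(root / "wavs").with_suffix("").as_posix()
    lengths[key] = torchaudio.info(str(path)).num_frames  # number of samples

with open(root / "lengths.json", "w") as file:
    json.dump(lengths, file, indent=4)
```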
86 | ### Step 2: Extract Discrete Speech Units 87 | 88 | Encode LibriSpeech using the HuBERT-Discrete model and the `encode.py` script: 89 | 90 | ``` 91 | usage: encode.py [-h] [--extension EXTENSION] {soft,discrete} in-dir out-dir 92 | 93 | Encode an audio dataset. 94 | 95 | positional arguments: 96 | {soft,discrete} available models (HuBERT-Soft or HuBERT-Discrete) 97 | in-dir path to the dataset directory. 98 | out-dir path to the output directory. 99 | 100 | optional arguments: 101 | -h, --help show this help message and exit 102 | --extension EXTENSION 103 | extension of the audio files (defaults to .flac). 104 | ``` 105 | 106 | for example: 107 | 108 | ``` 109 | python encode.py discrete path/to/LibriSpeech/wavs path/to/LibriSpeech/discrete 110 | ``` 111 | 112 | At this point the directory tree should look like: 113 | 114 | ``` 115 | │ lengths.json 116 | │ 117 | ├───discrete 118 | │ ├───... 119 | └───wavs 120 | ├───... 121 | ``` 122 | 123 | ### Step 3: Train the HuBERT-Soft Content Encoder 124 | 125 | ``` 126 | usage: train.py [-h] [--resume RESUME] [--warmstart] [--mask] [--alpha ALPHA] dataset-dir checkpoint-dir 127 | 128 | Train HuBERT soft content encoder. 129 | 130 | positional arguments: 131 | dataset-dir path to the data directory. 132 | checkpoint-dir path to the checkpoint directory. 133 | 134 | optional arguments: 135 | -h, --help show this help message and exit 136 | --resume RESUME path to the checkpoint to resume from. 137 | --warmstart whether to initialize from the fairseq HuBERT checkpoint. 138 | --mask whether to use input masking. 139 | --alpha ALPHA weight for the masked loss. 140 | ``` 
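for example (the flag combination below is illustrative; `--warmstart` and `--mask` match the HuBERT-Soft recipe sketched above, but adjust them to your configuration):

```
python train.py --warmstart --mask path/to/LibriSpeech path/to/checkpoints
```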
80 | """ 81 | wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) 82 | x, _ = self.encode(wav) 83 | return self.proj(x) 84 | 85 | 86 | class HubertDiscrete(Hubert): 87 | """HuBERT-Discrete content encoder from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.""" 88 | 89 | def __init__(self, kmeans: KMeans): 90 | super().__init__(504) 91 | self.kmeans = kmeans 92 | 93 | @torch.inference_mode() 94 | def units(self, wav: torch.Tensor) -> torch.LongTensor: 95 | """Extract discrete speech units. 96 | 97 | Args: 98 | wav (Tensor): an audio waveform of shape (1, 1, T), where T is the number of samples. 99 | 100 | Returns: 101 | LongTensor: soft speech units of shape (N,), where N is the number of frames. 102 | """ 103 | wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) 104 | x, _ = self.encode(wav, layer=7) 105 | x = self.kmeans.predict(x.squeeze().cpu().numpy()) 106 | return torch.tensor(x, dtype=torch.long, device=wav.device) 107 | 108 | 109 | class FeatureExtractor(nn.Module): 110 | def __init__(self): 111 | super().__init__() 112 | self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) 113 | self.norm0 = nn.GroupNorm(512, 512) 114 | self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) 115 | self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) 116 | self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) 117 | self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False) 118 | self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) 119 | self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) 120 | 121 | def forward(self, x: torch.Tensor) -> torch.Tensor: 122 | x = F.gelu(self.norm0(self.conv0(x))) 123 | x = F.gelu(self.conv1(x)) 124 | x = F.gelu(self.conv2(x)) 125 | x = F.gelu(self.conv3(x)) 126 | x = F.gelu(self.conv4(x)) 127 | x = F.gelu(self.conv5(x)) 128 | x = F.gelu(self.conv6(x)) 129 | return x 130 | 131 | 132 | class FeatureProjection(nn.Module): 133 | def __init__(self): 134 | super().__init__() 135 | self.norm = nn.LayerNorm(512) 136 | self.projection = nn.Linear(512, 768) 137 | self.dropout = nn.Dropout(0.1) 138 | 139 | def forward(self, x: torch.Tensor) -> torch.Tensor: 140 | x = self.norm(x) 141 | x = self.projection(x) 142 | x = self.dropout(x) 143 | return x 144 | 145 | 146 | class PositionalConvEmbedding(nn.Module): 147 | def __init__(self): 148 | super().__init__() 149 | self.conv = nn.Conv1d( 150 | 768, 151 | 768, 152 | kernel_size=128, 153 | padding=128 // 2, 154 | groups=16, 155 | ) 156 | self.conv = nn.utils.parametrizations.weight_norm( 157 | self.conv, name="weight", dim=2 158 | ) 159 | 160 | def forward(self, x: torch.Tensor) -> torch.Tensor: 161 | x = self.conv(x.transpose(1, 2)) 162 | x = F.gelu(x[:, :, :-1]) 163 | return x.transpose(1, 2) 164 | 165 | 166 | class TransformerEncoder(nn.Module): 167 | def __init__( 168 | self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int 169 | ) -> None: 170 | super(TransformerEncoder, self).__init__() 171 | self.layers = nn.ModuleList( 172 | [copy.deepcopy(encoder_layer) for _ in range(num_layers)] 173 | ) 174 | self.num_layers = num_layers 175 | 176 | def forward( 177 | self, 178 | src: torch.Tensor, 179 | mask: torch.Tensor = None, 180 | src_key_padding_mask: torch.Tensor = None, 181 | output_layer: Optional[int] = None, 182 | ) -> torch.Tensor: 183 | output = src 184 | for layer in self.layers[:output_layer]: 185 | output = layer( 186 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask 187 | ) 188 | return output 189 | 190 | 191 | def _compute_mask( 192 | shape: Tuple[int, int], 193 | mask_prob: 
/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.cuda.amp as amp 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | import torch.distributed as dist 13 | from torch.utils.data.distributed import DistributedSampler 14 | import torch.multiprocessing as mp 15 | from torch.nn.parallel import DistributedDataParallel as DDP 16 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 17 | 18 | from hubert.model import Hubert 19 | from hubert.dataset import AcousticUnitsDataset 20 | from hubert.utils import Metric, save_checkpoint, load_checkpoint 21 | from hubconf import URLS  # the checkpoint URLs dict lives in hubconf.py, not in hubert.model 22 | logging.basicConfig(level=logging.INFO) 23 | logger = logging.getLogger(__name__) 24 | 25 | ######################################################################################## 26 | # Define hyperparameters for training: 27 | ######################################################################################## 28 | 29 | BATCH_SIZE = 32 30 | LEARNING_RATE = 2e-5 31 | BETAS = (0.9, 0.98) 32 | EPS = 1e-06 33 | WEIGHT_DECAY = 1e-2 34 | MAX_NORM = 10 35 | STEPS = 25000 36 | LOG_INTERVAL = 5 37 | VALIDATION_INTERVAL = 1000 38 | CHECKPOINT_INTERVAL = 5000 39 | BACKEND = "nccl" 40 | INIT_METHOD = "tcp://localhost:54321" 41 | 42 | 43 | def train(rank, world_size, args): 44 | 
dist.init_process_group( 45 | BACKEND, 46 | rank=rank, 47 | world_size=world_size, 48 | init_method=INIT_METHOD, 49 | ) 50 | 51 | #################################################################################### 52 | # Setup logging utilities: 53 | #################################################################################### 54 | 55 | log_dir = args.checkpoint_dir / "logs" 56 | log_dir.mkdir(exist_ok=True, parents=True) 57 | 58 | if rank == 0: 59 | logger.setLevel(logging.INFO) 60 | handler = logging.FileHandler(log_dir / f"{args.checkpoint_dir.stem}.log") 61 | handler.setLevel(logging.INFO) 62 | formatter = logging.Formatter( 63 | "%(asctime)s [%(levelname)s] %(message)s", datefmt="%m/%d/%Y %I:%M:%S" 64 | ) 65 | handler.setFormatter(formatter) 66 | logger.addHandler(handler) 67 | else: 68 | logger.setLevel(logging.ERROR) 69 | 70 | writer = SummaryWriter(log_dir) if rank == 0 else None 71 | 72 | #################################################################################### 73 | # Initialize models 74 | #################################################################################### 75 | 76 | hubert = Hubert(mask=args.mask).to(rank) 77 | 78 | if args.warmstart: 79 | checkpoint = torch.hub.load_state_dict_from_url( 80 | URLS["hubert-discrete"], map_location={"cuda:0": f"cuda:{rank}"} 81 | ) 82 | consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.") 83 | 84 | # don't use warmstart weights for label embeddings and proj layer 85 | del checkpoint["hubert"]["label_embedding.weight"] 86 | del checkpoint["hubert"]["proj.weight"] 87 | del checkpoint["hubert"]["proj.bias"] 88 | 89 | hubert.load_state_dict(checkpoint["hubert"], strict=False) 90 | 91 | hubert = DDP(hubert, device_ids=[rank]) 92 | 93 | #################################################################################### 94 | # Initialize optimizer and grad scaler 95 | #################################################################################### 96 | 97 | optimizer = optim.AdamW( 98 | hubert.parameters(), 99 | lr=LEARNING_RATE, 100 | betas=BETAS, 101 | eps=EPS, 102 | weight_decay=WEIGHT_DECAY, 103 | ) 104 | scaler = amp.GradScaler() 105 | 106 | #################################################################################### 107 | # Initialize datasets and dataloaders 108 | #################################################################################### 109 | 110 | train_dataset = AcousticUnitsDataset( 111 | root=args.dataset_dir, 112 | train=True, 113 | ) 114 | train_sampler = DistributedSampler(train_dataset, drop_last=True) 115 | train_loader = DataLoader( 116 | train_dataset, 117 | collate_fn=train_dataset.collate, 118 | batch_size=BATCH_SIZE, 119 | sampler=train_sampler, 120 | num_workers=8, 121 | pin_memory=True, 122 | shuffle=False, 123 | drop_last=True, 124 | ) 125 | 126 | validation_dataset = AcousticUnitsDataset( 127 | root=args.dataset_dir, 128 | train=False, 129 | ) 130 | validation_loader = DataLoader( 131 | validation_dataset, 132 | batch_size=1, 133 | shuffle=False, 134 | num_workers=8, 135 | pin_memory=True, 136 | ) 137 | 138 | #################################################################################### 139 | # Load checkpoint if args.resume is set 140 | #################################################################################### 141 | 142 | if args.resume is not None: 143 | global_step, best_loss = load_checkpoint( 144 | load_path=args.resume, 145 | hubert=hubert, 146 | optimizer=optimizer, 147 | scaler=scaler, 148 | rank=rank, 149 | logger=logger, 150 | 
) 151 | else: 152 | global_step, best_loss = 0, float("inf") 153 | 154 | # =================================================================================# 155 | # Start training loop 156 | # =================================================================================# 157 | 158 | n_epochs = STEPS // len(train_loader) + 1 159 | start_epoch = global_step // len(train_loader) + 1 160 | 161 | logger.info("**" * 40) 162 | logger.info(f"PyTorch version: {torch.__version__}") 163 | logger.info(f"CUDA version: {torch.version.cuda}") 164 | logger.info(f"CUDNN version: {torch.backends.cudnn.version()}") 165 | logger.info(f"CUDNN enabled: {torch.backends.cudnn.enabled}") 166 | logger.info(f"CUDNN deterministic: {torch.backends.cudnn.deterministic}") 167 | logger.info(f"CUDNN benchmark: {torch.backends.cudnn.benchmark}") 168 | logger.info(f"# of GPUS: {torch.cuda.device_count()}") 169 | logger.info(f"batch size: {BATCH_SIZE}") 170 | logger.info(f"iterations per epoch: {len(train_loader)}") 171 | logger.info(f"# of epochs: {n_epochs}") 172 | logger.info(f"started at epoch: {start_epoch}") 173 | logger.info("**" * 40 + "\n") 174 | 175 | if args.mask: 176 | average_masked_loss = Metric() 177 | average_unmasked_loss = Metric() 178 | average_masked_accuracy = Metric() 179 | average_unmasked_accuracy = Metric() 180 | 181 | epoch_masked_loss = Metric() 182 | epoch_unmasked_loss = Metric() 183 | epoch_masked_accuracy = Metric() 184 | epoch_unmasked_accuracy = Metric() 185 | else: 186 | average_loss = Metric() 187 | average_accuracy = Metric() 188 | 189 | epoch_loss = Metric() 190 | epoch_accuracy = Metric() 191 | 192 | validation_loss = Metric() 193 | validation_accuracy = Metric() 194 | 195 | for epoch in range(start_epoch, n_epochs + 1): 196 | train_sampler.set_epoch(epoch) 197 | 198 | hubert.train() 199 | if args.mask: 200 | epoch_masked_loss.reset() 201 | epoch_unmasked_loss.reset() 202 | epoch_masked_accuracy.reset() 203 | epoch_unmasked_accuracy.reset() 204 | else: 205 | epoch_loss.reset() 206 | epoch_accuracy.reset() 207 | 208 | for wavs, codes in train_loader: 209 | global_step += 1 210 | wavs, codes = wavs.to(rank), codes.to(rank) 211 | 212 | ############################################################################ 213 | # Compute training loss 214 | ############################################################################ 215 | 216 | optimizer.zero_grad() 217 | 218 | with amp.autocast(): 219 | logits, mask = hubert(wavs) 220 | length = min( 221 | mask.size(-1) if args.mask else float("inf"), codes.size(-1) 222 | ) 223 | logits = logits[:, :length, :] 224 | codes = codes[:, :length] 225 | if args.mask: 226 | mask = mask[:, :length] 227 | 228 | if args.mask: 229 | masked_loss = F.cross_entropy(logits[mask], codes[mask]) 230 | unmasked_loss = F.cross_entropy(logits[~mask], codes[~mask]) 231 | loss = args.alpha * masked_loss + (1 - args.alpha) * unmasked_loss 232 | else: 233 | loss = F.cross_entropy(logits.transpose(1, 2), codes) 234 | 235 | scaler.scale(loss).backward() 236 | scaler.unscale_(optimizer) 237 | 238 | nn.utils.clip_grad_norm_(hubert.parameters(), MAX_NORM) 239 | 240 | scaler.step(optimizer) 241 | scaler.update() 242 | 243 | if args.mask: 244 | masked_accuracy = logits[mask].argmax(dim=-1) == codes[mask] 245 | masked_accuracy = torch.mean(masked_accuracy.float()) 246 | 247 | unmasked_accuracy = logits[~mask].argmax(dim=-1) == codes[~mask] 248 | unmasked_accuracy = torch.mean(unmasked_accuracy.float()) 249 | else: 250 | accuracy = logits.argmax(dim=-1) == codes 251 | accuracy = 
torch.mean(accuracy.float()) 252 | 253 | ############################################################################ 254 | # Update and log training metrics 255 | ############################################################################ 256 | 257 | if args.mask: 258 | average_masked_loss.update(masked_loss.item()) 259 | average_unmasked_loss.update(unmasked_loss.item()) 260 | average_masked_accuracy.update(masked_accuracy.item()) 261 | average_unmasked_accuracy.update(unmasked_accuracy.item()) 262 | 263 | epoch_masked_loss.update(masked_loss.item()) 264 | epoch_unmasked_loss.update(unmasked_loss.item()) 265 | epoch_masked_accuracy.update(masked_accuracy.item()) 266 | epoch_unmasked_accuracy.update(unmasked_accuracy.item()) 267 | else: 268 | average_loss.update(loss.item()) 269 | average_accuracy.update(accuracy.item()) 270 | 271 | epoch_loss.update(loss.item()) 272 | epoch_accuracy.update(accuracy.item()) 273 | 274 | if rank == 0 and global_step % LOG_INTERVAL == 0: 275 | if args.mask: 276 | writer.add_scalar( 277 | "train/masked_loss", 278 | average_masked_loss.value, 279 | global_step, 280 | ) 281 | writer.add_scalar( 282 | "train/unmasked_loss", 283 | average_unmasked_loss.value, 284 | global_step, 285 | ) 286 | writer.add_scalar( 287 | "train/masked_accuracy", 288 | average_masked_accuracy.value * 100, 289 | global_step, 290 | ) 291 | writer.add_scalar( 292 | "train/unmasked_accuracy", 293 | average_unmasked_accuracy.value * 100, 294 | global_step, 295 | ) 296 | average_masked_loss.reset() 297 | average_unmasked_loss.reset() 298 | average_masked_accuracy.reset() 299 | average_unmasked_accuracy.reset() 300 | else: 301 | writer.add_scalar( 302 | "train/loss", 303 | average_loss.value, 304 | global_step, 305 | ) 306 | writer.add_scalar( 307 | "train/accuracy", 308 | average_accuracy.value, 309 | global_step, 310 | ) 311 | average_loss.reset() 312 | average_accuracy.reset() 313 | 314 | # --------------------------------------------------------------------------# 315 | # Start validation loop 316 | # --------------------------------------------------------------------------# 317 | 318 | if global_step % VALIDATION_INTERVAL == 0: 319 | hubert.eval() 320 | validation_loss.reset() 321 | validation_accuracy.reset() 322 | for wavs, codes in validation_loader: 323 | wavs, codes = wavs.to(rank), codes.to(rank) 324 | 325 | with torch.no_grad(): 326 | logits, _ = hubert(wavs) 327 | logits = logits.transpose(1, 2) 328 | 329 | loss = F.cross_entropy(logits, codes) 330 | 331 | accuracy = logits.argmax(dim=1) == codes 332 | accuracy = torch.mean(accuracy.float()) 333 | 334 | #################################################################### 335 | # Update validation metrics 336 | #################################################################### 337 | 338 | validation_loss.update(loss.item()) 339 | validation_accuracy.update(accuracy.item()) 340 | 341 | hubert.train() 342 | 343 | ############################################################################ 344 | # Log validation metrics 345 | ############################################################################ 346 | 347 | if rank == 0: 348 | writer.add_scalar( 349 | "validation/unit_loss", 350 | validation_loss.value, 351 | global_step, 352 | ) 353 | writer.add_scalar( 354 | "validation/unit_accuracy", 355 | validation_accuracy.value * 100, 356 | global_step, 357 | ) 358 | logger.info( 359 | f"valid -- epoch: {epoch}, loss: {validation_loss.value:.4f}, accuracy: {validation_accuracy.value * 100:.2f}" 360 | ) 361 | 362 | 
#################################################################################### 363 | # Save model checkpoint 364 | #################################################################################### 365 | 366 | new_best = best_loss > validation_loss.value 367 | if new_best or global_step % CHECKPOINT_INTERVAL == 0: 368 | if new_best: 369 | logger.info("-------- new best model found!") 370 | best_loss = validation_loss.value 371 | 372 | if rank == 0: 373 | save_checkpoint( 374 | checkpoint_dir=args.checkpoint_dir, 375 | hubert=hubert, 376 | optimizer=optimizer, 377 | scaler=scaler, 378 | step=global_step, 379 | loss=validation_loss.value, 380 | best=new_best, 381 | logger=logger, 382 | ) 383 | 384 | # -----------------------------------------------------------------------------# 385 | # End validation loop 386 | # -----------------------------------------------------------------------------# 387 | 388 | #################################################################################### 389 | # Log training metrics 390 | #################################################################################### 391 | 392 | if args.mask: 393 | logger.info( 394 | f"train -- epoch: {epoch}, masked loss: {epoch_masked_loss.value:.4f}, unmasked loss: {epoch_unmasked_loss.value:.4f}, " 395 | f"masked accuracy: {epoch_masked_accuracy.value * 100:.2f}, unmasked accuracy: {epoch_unmasked_accuracy.value * 100:.2f}" 396 | ) 397 | else: 398 | logger.info(f"train -- epoch: {epoch}, loss: {epoch_loss.value:.4f}, accuracy: {epoch_accuracy.value * 100:.2f}") 399 | 400 | # ==================================================================================# 401 | # End training loop 402 | # ==================================================================================# 403 | 404 | dist.destroy_process_group() 405 | 406 | 407 | def train_hubert(args): 408 | world_size = torch.cuda.device_count() 409 | mp.spawn( 410 | train, 411 | args=(world_size, args), 412 | nprocs=world_size, 413 | join=True, 414 | ) 415 | 416 | 417 | if __name__ == "__main__": 418 | parser = argparse.ArgumentParser(description="Train HuBERT soft content encoder.") 419 | parser.add_argument( 420 | "dataset_dir", 421 | metavar="dataset-dir", 422 | help="path to the data directory.", 423 | type=Path, 424 | ) 425 | parser.add_argument( 426 | "checkpoint_dir", 427 | metavar="checkpoint-dir", 428 | help="path to the checkpoint directory.", 429 | type=Path, 430 | ) 431 | parser.add_argument( 432 | "--resume", 433 | help="path to the checkpoint to resume from.", 434 | type=Path, 435 | ) 436 | parser.add_argument( 437 | "--warmstart", 438 | help="whether to initialize from the fairseq HuBERT checkpoint.", 439 | action="store_true", 440 | ) 441 | parser.add_argument( 442 | "--mask", 443 | help="whether to use input masking.", 444 | action="store_true", 445 | ) 446 | parser.add_argument( 447 | "--alpha", 448 | help="weight for the masked loss.", 449 | default=1, 450 | type=float, 451 | ) 452 | args = parser.parse_args() 453 | 454 | train_hubert(args)  # spawns one training process per available GPU 455 | --------------------------------------------------------------------------------