├── .github └── workflows │ └── code_lint.yml ├── .gitignore ├── ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf ├── ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf ├── Datasets ├── __init__.py └── datamodule.py ├── README.md ├── agents ├── BaseTrainer.py └── __init__.py ├── assets ├── .keep ├── img_readme │ ├── ASR_language family.png │ ├── Network.drawio.png │ ├── PER test vs Training data.png │ ├── PER validation vs Training data.png │ ├── hubert.jpeg │ ├── parrot.png │ ├── wav2vec2.png │ └── wavlm.png └── phoible.csv ├── conda_environment.yml ├── config ├── __init__.py └── hparams.py ├── main.py ├── models ├── BaseModule.py ├── __init__.py └── models.py ├── requirements.txt ├── requirements_cuda11-3.txt ├── test.sh ├── train_notebook.ipynb └── utils ├── __init__.py ├── agent_utils.py ├── callbacks.py ├── constant.py ├── dataset_utils.py ├── logger.py ├── metrics.py └── per.py /.github/workflows/code_lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | # Trigger the workflow on push or pull request, 5 | # but only for the main branch 6 | push: 7 | branches: 8 | - main 9 | pull_request: 10 | branches: 11 | - main 12 | 13 | jobs: 14 | run-linters: 15 | name: Run linters 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - name: Check out Git repository 20 | uses: actions/checkout@v2 21 | 22 | - name: Set up Python 23 | uses: actions/setup-python@v1 24 | with: 25 | python-version: 3.8 26 | 27 | - name: Install Python dependencies 28 | run: pip install black flake8 29 | 30 | - name: Run linters 31 | uses: wearerequired/lint-action@v1 32 | with: 33 | black: true 34 | auto_fix: true 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # vs code 2 | .vscode/* 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # PyCharm 151 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 153 | # and can be added to the global gitignore or merged into this file. For a more nuclear 154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
155 | #.idea/ 156 | 157 | wandb/* 158 | *.ckpt 159 | assets/* 160 | artifacts/* 161 | *.json -------------------------------------------------------------------------------- /ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf -------------------------------------------------------------------------------- /ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf -------------------------------------------------------------------------------- /Datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/Datasets/__init__.py -------------------------------------------------------------------------------- /Datasets/datamodule.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | import re 4 | import shutil 5 | 6 | import numpy as np 7 | import utils.agent_utils as ag_u 8 | import wandb 9 | from datasets import Audio, load_dataset 10 | from librosa.effects import trim 11 | from phonemizer.backend import EspeakBackend 12 | from phonemizer.separator import Separator 13 | from pytorch_lightning import LightningDataModule 14 | from torch.utils.data import DataLoader 15 | from utils.constant import CHARS_TO_REMOVE_REGEX 16 | from utils.dataset_utils import coll_fn 17 | from utils.logger import init_logger 18 | 19 | 20 | class BaseDataModule(LightningDataModule): 21 | def __init__(self, dataset_param): 22 | super().__init__() 23 | 24 | self.config = dataset_param 25 | self.logger = init_logger("BaseDataModule", "INFO") 26 | self.logger.info( 27 | f"Loading Dataset : {self.config.dataset_name}, language : {self.config.subset}" 28 | ) 29 | 30 | def prepare_data(self) -> None: 31 | return super().prepare_data() 32 | 33 | def load_data(self, split) -> None: 34 | """ 35 | Function to load dataset 36 | """ 37 | 38 | self.logger.info(f"Loading the dataset in load_data: {split}") 39 | 40 | setattr( 41 | self, 42 | f"{split}_save_data_path", 43 | osp.join( 44 | "assets", 45 | "datasets", 46 | f"{split}_{self.config.dataset_name}-{self.config.subset}", 47 | ), 48 | ) 49 | 50 | save_path = getattr(self, f"{split}_save_data_path") 51 | name_file = getattr(self, f"{split}_save_data_path").split("/")[-1] 52 | name_file_path = osp.join(save_path, name_file) 53 | name_dataset = f"{split}_dataset" 54 | 55 | ag_u.create_directory(save_path) 56 | 57 | if not osp.exists(name_file_path) or self.config.create_dataset: 58 | # if not self.config.create_dataset: 59 | # # try 60 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest" 61 | # self.logger.info(f"Try loading {path} in artifacts ...") 62 | 63 | # file = ag_u.get_artifact(path, type="dataset") 64 | 65 | # shutil.copy2(file, save_path) 66 | 67 | # self.logger.info(f"Load {path} in artifacts OK") 68 | 69 | # file = open(name_file_path, "rb") 70 | # setattr(self, name_dataset, pickle.load(file)) 71 | # self.logger.info( 72 | # f"Loaded {split} 
dataset : {name_file_path}") 73 | # else: 74 | # except 75 | setattr( 76 | self, 77 | name_dataset, 78 | load_dataset( 79 | self.config.dataset_name, 80 | self.config.subset, 81 | split=split if split != "val" else "validation", 82 | use_auth_token=self.config.use_auth_token, 83 | download_mode=self.config.download_mode, 84 | cache_dir=self.config.cache_dir, 85 | ), 86 | ) 87 | 88 | setattr( 89 | self, 90 | name_dataset, 91 | getattr(self, name_dataset).remove_columns( 92 | [ 93 | "accent", 94 | "age", 95 | "client_id", 96 | "down_votes", 97 | "gender", 98 | "locale", 99 | "segment", 100 | "up_votes", 101 | ] 102 | ), 103 | ) 104 | 105 | setattr( 106 | self, 107 | name_dataset, 108 | getattr(self, name_dataset).cast_column( 109 | "audio", Audio(sampling_rate=16000) 110 | ), 111 | ) 112 | 113 | metadata_artifact = { 114 | "dataset_name": self.config.dataset_name, 115 | "subset": self.config.subset, 116 | "split": split, 117 | "sampling_rate": 16000, 118 | } 119 | 120 | self._save_dataset( 121 | split, name_file_path, metadata_artifact, f"{split} dataset" 122 | ) 123 | else: 124 | file = open(name_file_path, "rb") 125 | setattr(self, name_dataset, pickle.load(file)) 126 | 127 | self.sampling_rate = 16000 128 | 129 | self.logger.info(f"Done prepare_data {split}") 130 | 131 | def process_dataset(self, split, processor, batch_size=512): 132 | """ 133 | Function to process data of a dataset (remove indesirable characters and process audio with processor) 134 | """ 135 | 136 | save_path = getattr(self, f"{split}_save_data_path") 137 | name_file = save_path.split("/")[-1] + "_process" 138 | name_file_path = osp.join(save_path, name_file) 139 | name_dataset = f"{split}_dataset" 140 | 141 | if not osp.exists(name_file_path) or self.config.create_dataset: 142 | # if not self.config.create_dataset: 143 | # # try 144 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest" 145 | # self.logger.info(f"Try loading {path} in artifacts ...") 146 | 147 | # file = ag_u.get_artifact(path, type="dataset") 148 | 149 | # shutil.copy2(file, save_path) 150 | 151 | # self.logger.info(f"Load {path} in artifacts OK") 152 | 153 | # file = open(name_file_path, "rb") 154 | # setattr(self, name_dataset, pickle.load(file)) 155 | # self.logger.info( 156 | # f"Loaded processed {split} dataset : {name_file_path}") 157 | # else: 158 | # except 159 | self.logger.info(f"Processing {split} dataset ...") 160 | 161 | setattr( 162 | self, 163 | name_dataset, 164 | getattr(self, name_dataset).map( 165 | lambda x: { 166 | "sentence": re.sub( 167 | CHARS_TO_REMOVE_REGEX, "", x["sentence"] 168 | ).lower() 169 | }, 170 | num_proc=self.config.num_proc, 171 | load_from_cache_file=False, 172 | ), 173 | ) 174 | setattr( 175 | self, 176 | name_dataset, 177 | getattr(self, name_dataset).map( 178 | lambda batch: { 179 | "audio": processor( 180 | [ad["array"] for ad in batch["audio"]], sampling_rate=16000 181 | ).input_values 182 | }, 183 | batched=True, 184 | batch_size=batch_size, 185 | num_proc=self.config.num_proc, 186 | load_from_cache_file=False, 187 | ), 188 | ) 189 | 190 | self.logger.info(f"Saving {split} dataset ...") 191 | 192 | metadata_artifact = { 193 | "dataset_name": self.config.dataset_name, 194 | "subset": self.config.subset, 195 | "split": split, 196 | "sampling_rate": self.sampling_rate, 197 | } 198 | 199 | self._save_dataset( 200 | split, name_file_path, metadata_artifact, f"{split} dataset processed" 201 | ) 202 | else: 203 | self.logger.info( 204 | f"{split} dataset already exists no processing necessary 
..." 205 | ) 206 | 207 | file = open(name_file_path, "rb") 208 | setattr(self, name_dataset, pickle.load(file)) 209 | self.logger.info(f"Loaded processed {split} dataset : {name_file_path}") 210 | 211 | def filtered_data(self, split, top_db=15) -> None: 212 | """ 213 | Function to filter dataset (remove silence and remove long audio ) 214 | """ 215 | 216 | self.logger.info(f"Filtering {split} dataset ...") 217 | 218 | save_path = getattr(self, f"{split}_save_data_path") 219 | name_file = f"{save_path.split('/')[-1]}_filter_{top_db}_{self.config.max_input_length_in_sec}" 220 | name_file_path = osp.join(save_path, name_file) 221 | name_dataset = f"{split}_dataset" 222 | 223 | if not osp.exists(name_file_path) or self.config.create_dataset: 224 | # if not self.config.create_dataset: 225 | # # try 226 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest" 227 | # self.logger.info(f"Try loading {path} in artifacts ...") 228 | 229 | # file = ag_u.get_artifact(path, type="dataset") 230 | 231 | # shutil.copy2(file, getattr(self, f'{split}_save_data_path')) 232 | 233 | # file = open(name_file_path, "rb") 234 | # setattr(self, name_dataset, pickle.load(file)) 235 | # self.logger.info( 236 | # f"Loaded filtered {split} dataset : {name_file_path}") 237 | # else: 238 | # # except 239 | self.logger.info( 240 | f"Length {split} dataset before filter {len(getattr(self, name_dataset))}" 241 | ) 242 | 243 | setattr( 244 | self, 245 | name_dataset, 246 | getattr(self, name_dataset).map( 247 | lambda x: {"audio": trim(np.array(x["audio"]), top_db=top_db)[0]}, 248 | num_proc=self.config.num_proc, 249 | load_from_cache_file=False, 250 | ), 251 | ) 252 | setattr( 253 | self, 254 | name_dataset, 255 | getattr(self, name_dataset).filter( 256 | lambda x: len(x["audio"]) 257 | < self.config.max_input_length_in_sec * self.sampling_rate, 258 | num_proc=self.config.num_proc, 259 | load_from_cache_file=False, 260 | ), 261 | ) 262 | 263 | self.logger.info( 264 | f"Length {split} dataset after filter {len(getattr(self, name_dataset))}" 265 | ) 266 | 267 | metadata_artifact = { 268 | "dataset_name": self.config.dataset_name, 269 | "subset": self.config.subset, 270 | "split": split, 271 | "sampling_rate": self.sampling_rate, 272 | "top_db": top_db, 273 | "max_input_length_in_sec": self.config.max_input_length_in_sec, 274 | } 275 | 276 | self._save_dataset( 277 | split, 278 | name_file_path, 279 | metadata_artifact, 280 | f"{split} dataset processed and filtered", 281 | ) 282 | 283 | else: 284 | file = open(name_file_path, "rb") 285 | setattr(self, name_dataset, pickle.load(file)) 286 | self.logger.info(f"Loaded filtered {split} dataset : {name_file_path}") 287 | 288 | self.logger.info(f"Length {split} dataset : {len(getattr(self, name_dataset))}") 289 | 290 | def create_phonemes(self, split) -> None: 291 | """ 292 | Function to phonemize all sentence of the dataset 293 | """ 294 | 295 | language = ( 296 | self.config.language[:2] if self.config.language[:2] != "zh" else "cmn" 297 | ) 298 | self.logger.info(f"Creating {split} phonemes language {language}...") 299 | backend = EspeakBackend(language) 300 | separator = Separator(phone=" ", word="| ", syllable="") 301 | 302 | name_dataset = f"{split}_dataset" 303 | 304 | setattr( 305 | self, 306 | name_dataset, 307 | getattr(self, name_dataset).add_column( 308 | "phonemes", 309 | backend.phonemize( 310 | getattr(self, name_dataset)["sentence"], 311 | njobs=self.config.num_proc, 312 | separator=separator, 313 | ), 314 | ), 315 | ) 316 | 317 | def _save_dataset( 
318 | self, split, name_file_path, metadata_artifact, description_artifact 319 | ): 320 | 321 | file = open(name_file_path, "wb") 322 | pickle.dump(getattr(self, f"{split}_dataset"), file) 323 | 324 | self.logger.info(f"Saved to {name_file_path}") 325 | 326 | self.push_artefact(name_file_path, metadata_artifact, description_artifact) 327 | 328 | def push_artefact(self, path_artifact, metadata, description): 329 | artifact = wandb.Artifact( 330 | name=osp.basename(path_artifact), 331 | type="dataset", 332 | metadata=metadata, 333 | description=description, 334 | ) 335 | artifact.add_file(path_artifact) 336 | wandb.log_artifact(artifact, aliases=["latest"]) 337 | 338 | def setup(self, stage=None): 339 | # Build dataset 340 | if stage in (None, "fit"): 341 | 342 | self.filtered_data("train") 343 | self.filtered_data("val") 344 | 345 | self.create_phonemes("train") 346 | self.create_phonemes("val") 347 | 348 | if stage == "test": 349 | self.create_phonemes("test") 350 | 351 | if stage == "predict": 352 | self.dataset = load_dataset( 353 | self.config.dataset_name, 354 | self.config.subset, 355 | split="other", 356 | use_auth_token=self.config.use_auth_token, 357 | download_mode=self.config.download_mode, 358 | cache_dir=self.config.cache_dir, 359 | ) 360 | 361 | def train_dataloader(self): 362 | train_loader = DataLoader( 363 | self.train_dataset, 364 | shuffle=True, 365 | batch_size=self.config.batch_size, 366 | num_workers=self.config.num_workers, 367 | collate_fn=coll_fn, 368 | ) 369 | return train_loader 370 | 371 | def val_dataloader(self): 372 | val_loader = DataLoader( 373 | self.val_dataset, 374 | shuffle=False, 375 | batch_size=self.config.batch_size, 376 | num_workers=self.config.num_workers, 377 | collate_fn=coll_fn, 378 | ) 379 | return val_loader 380 | 381 | def test_dataloader(self): 382 | val_loader = DataLoader( 383 | self.test_dataset, 384 | shuffle=False, 385 | batch_size=self.config.batch_size, 386 | num_workers=self.config.num_workers, 387 | collate_fn=coll_fn, 388 | ) 389 | return val_loader 390 | 391 | def predict_dataloader(self): 392 | predict_loader = DataLoader( 393 | self.dataset, 394 | batch_size=self.config.batch_size, 395 | num_workers=self.config.num_workers, 396 | shuffle=False, 397 | pin_memory=True, 398 | collate_fn=coll_fn, 399 | ) 400 | return predict_loader 401 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multilingual-PR 2 | 3 | Implementation of the project ```Self-supervised pretraining for phoneme recognition, and generalization on foreign languages``` 4 | 5 | > Authors: [Apavou Clément](https://github.com/clementapa) & [Belkada Younes](https://github.com/younesbelkada) & [Leo Tronchon](https://github.com/leot13) & [Arthur Zucker](https://github.com/ArthurZucker) 6 | 7 | ![Python](https://img.shields.io/badge/Python-green.svg?style=plastic) 8 | ![PyTorch](https://img.shields.io/badge/PyTorch-orange.svg?style=plastic) 9 | ![PyTorch Lightning](https://img.shields.io/badge/PyTorch-Lightning-blueviolet.svg?style=plastic) 10 | 11 |

12 | 13 |

14 | 
15 | This repository is powered by Hugging Face :hugs:, PyTorch Lightning and Weights & Biases. 
16 | 
17 | ## :bird: Introduction 
18 | 
19 | The scarcity of annotated data, and the heavy cost of producing it, limit our ability to train deep neural networks for audio processing tasks. The speech community has therefore developed feature learning methods that need minimal annotated data, which mostly fall under unsupervised and self-supervised techniques. 
20 | 
21 | Recently, self-supervised learning methods for the textual modality have outperformed state-of-the-art approaches on downstream tasks, by fine-tuning pretrained models on a relatively small amount of data. These approaches have since been tested on other modalities such as images and audio. 
22 | 
23 | Phoneme recognition is an exciting challenge that involves processing a raw audio recording and predicting the corresponding sequence of phonemes pronounced by the speaker. Throughout this project, we compare three self-supervised models, Wav2vec (2019, 2020), HuBERT (2021) and WavLM (2022), pretrained on a corpus of English speech, which we use in various ways to perform phoneme recognition for different languages with a network trained with the Connectionist Temporal Classification (CTC) algorithm. Different questions will be addressed: 
24 | 
25 | + *What is the impact of choosing English as a pretrained language, especially for languages that are very different from English? Which method(s) works best for transferring knowledge from English to other languages?* 
26 | + *Which method extracts the best features for phoneme recognition?* 
27 | + *What is the influence of the abundance of training data on the performance of models?* 
28 | 
29 | In this project, we address these questions by drawing conclusions from our experiments. 
30 | 
31 | ## :sparkles: Main features 
32 | 
33 | + Modularity between SOTA self-supervised speech models 
34 | + Freedom to select any language available on CommonVoice, hosted at [HuggingFace](https://huggingface.co/datasets/common_voice). 
35 | + Nice visualization through wandb. 
36 | 
37 | ## :pencil2: Network Architecture for phoneme recognition 
38 | 
39 | 

40 | 41 |

42 |

43 | Diagram of the models used for the experiments. N=22 and h=1024 for HuBERT Large and WavLM Large, and N=11 and h=768 for Wav2vec2 Base and WavLM Base. Made by us. 44 |

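As a complement to the diagram, here is a minimal, hypothetical sketch of the pipeline it describes: a pretrained self-supervised backbone loaded from Hugging Face, a single linear layer mapping each output frame to phoneme logits, and a CTC loss. The actual implementation lives in `models/models.py` and `models/BaseModule.py` and may differ in its details; `num_phonemes`, the blank index and the dummy tensors below are placeholders.

```
# Sketch only: pretrained SSL backbone + linear phoneme head, trained with CTC.
import torch
from torch import nn
from transformers import AutoModel


class PhonemeRecognizer(nn.Module):
    def __init__(self, pretrained_name="microsoft/wavlm-base", num_phonemes=50, freeze_backbone=True):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(pretrained_name)
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False
        hidden = self.backbone.config.hidden_size  # h=768 (Base) or h=1024 (Large)
        self.head = nn.Linear(hidden, num_phonemes)

    def forward(self, input_values):
        # input_values: raw 16 kHz waveforms, shape (batch, samples)
        features = self.backbone(input_values).last_hidden_state  # (batch, frames, h)
        return self.head(features)                                # (batch, frames, num_phonemes)


# CTC training step (shapes only; real lengths come from the data collator)
model = PhonemeRecognizer()
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)
waveform = torch.randn(2, 16000 * 3)                  # 2 dummy utterances of 3 seconds
logits = model(waveform)
log_probs = logits.log_softmax(-1).transpose(0, 1)    # (frames, batch, classes)
targets = torch.randint(1, 50, (2, 20))               # dummy phoneme ids
input_lengths = torch.full((2,), log_probs.size(0), dtype=torch.long)
target_lengths = torch.full((2,), 20, dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
```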
45 | 46 | ## :books: Languages for which phoneme dictionaries are available 47 | Dutch (du), Spanish (es), French (fr), Italian (it), Kyrgyz (ky), Russian (ru), Sweedish 48 | (sv), Turkish (tr), Tatar (tt) and Mandarin (zh). From https://github.com/facebookresearch/CPC_audio. 49 | 50 | ## :star2: Usage 51 | 52 | Please refer to our [example notebook](https://github.com/ASR-project/Multilingual-PR/blob/main/train_notebook.ipynb) if you want to train or test a model. To understand the command line arguments that you can use, run: 53 | ``` 54 | Hparams ['parameters.hparams']: 55 | Hyperparameters of for the run 56 | 57 | --wandb_entity str wandb (default: asr-project) 58 | --debug bool (default: False) 59 | --test bool test code before running, if testing, no checkpoints are written (default: True) 60 | --wandb_project str (default: test-asr) 61 | --root_dir str root_dir (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR) 62 | --seed_everything [int] 63 | basic params (default: None) 64 | --gpu int number or gpu (default: 1) 65 | --hparams.max_epochs int 66 | maximum number of epochs (default: 100) 67 | --weights_path str (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/weights) 68 | --tune_lr bool modes (default: False) 69 | --dev_run bool (default: False) 70 | --train bool (default: True) 71 | --best_model str (default: ) 72 | --log_freq_audio int (default: 10) 73 | --log_nb_audio int (default: 2) 74 | --val_check_interval float 75 | trainer params (default: 1.0) 76 | --limit_train_batches float 77 | 1.0 (default: 1.0) 78 | --limit_val_batches float 79 | 1.0 (default: 1.0) 80 | --enable_progress_bar bool 81 | (default: True) 82 | --best_model_run str testing params (default: WavLM_sv) 83 | --early_stopping bool 84 | Early Stopping (default: True) 85 | --early_stopping_params typing.Dict[str, typing.Any] 86 | (default: {'monitor': 'val/per', 'patience': 10, 'mode': 'min', 'verbose': True}) 87 | 88 | DatasetParams ['parameters.data_param']: 89 | Dataset Parameters 90 | ! 
The batch_size and number of crops should be defined here 91 | 92 | 93 | --dataset_name str Hugging Face datasets parameters (default: common_voice) 94 | --use_auth_token bool 95 | True if use mozilla-foundation datasets (default: False) 96 | --subset str (default: sv-SE) 97 | --download_mode str chosen language (see https://huggingface.co/datasets/common_voice) (default: reuse_dataset_if_exists) 98 | --cache_dir str (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets) 99 | --language str to create vocabulary of phonemes (default: sv) 100 | --root_path_annotation str 101 | (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets/common_voices_splits) 102 | --phoible_csv_path str 103 | (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets) 104 | --num_workers int Dataloader parameters (default: 20) 105 | --batch_size int (default: 2) 106 | --max_input_length_in_sec float 107 | Dataset processing parameters (default: 5) 108 | --num_proc int (default: 4) 109 | --create_dataset bool 110 | (default: False) 111 | 112 | NetworkParams ['parameters.network_param']: 113 | NetworkParams(network_name: str = 'WavLM', pretrained_name: Union[str, NoneType] = '', freeze: bool = True, freeze_transformer: bool = True, eos_token: str = '', bos_token: str = '', unk_token: str = '', pad_token: str = '', word_delimiter_token: str = '|') 114 | 115 | --network_name str Hubert, Wav2Vec2, WavLM (default: WavLM) 116 | --pretrained_name [str] 117 | (default: ) 118 | --freeze bool (default: True) 119 | --freeze_transformer bool 120 | (default: True) 121 | --eos_token str Phoneme Tokenizer (default: ) 122 | --bos_token str (default: ) 123 | --unk_token str (default: ) 124 | --pad_token str (default: ) 125 | --word_delimiter_token str 126 | (default: |) 127 | 128 | OptimizerParams ['parameters.optim_param']: 129 | Optimization parameters 130 | 131 | --optimizer str (default: AdamW) 132 | --lr float (default: 0.02) 133 | --weight_decay float (default: 1e-08) 134 | --accumulate_grad_batches int 135 | 1 for no accumulation (default: 16) 136 | --scheduler [str] Scheduler parameters (default: None) 137 | --optim_param.max_epochs int 138 | Cosine, ReduceLROnPlateau, MultiStepLR, StepLR or None Cosine scheduler (default: 10) 139 | --warmup_epochs int (default: 1) 140 | --warmup_start_lr float 141 | (default: 0.0006) 142 | --eta_min float (default: 5e-06) 143 | --step_size int Step LR scheduler (default: 2) 144 | --gamma float also for multi step lr (default: 0.1) 145 | --milestones str MultiStepLR scheduler (default: [8, 10, 15]) 146 | --min_lr float ReduceLROnPlateau scheduler (default: 5e-09) 147 | --patience int (default: 10) 148 | ``` 149 | 150 | 151 | 152 | ## :sound: Dataset 153 | 154 | The project is based on [Mozilla CommonVoice dataset](https://commonvoice.mozilla.org/fr) available on [HuggingFace](https://huggingface.co/datasets/common_voice). 155 | When the script is launched, the program will automatically download the correct dataset and transform ground truth sentences to phonemes using [phonemizer](https://github.com/bootphon/phonemizer). You are free to chose any dataset available on HuggingFace with phonemes dictionaries previously cited to run your models. For our experiments we use: 156 | ``` 157 | it, nl, tr, ru, sv-SE 158 | ``` 159 | Feel free to try any other languages and submit a Pull Request :electric_plug:. 160 | 161 | ## :paperclip: Pre-trained models 162 | 163 |

164 | 165 | 166 | 167 |

168 |

169 | Schema of Wav2vec2, HuBERT and WavLM. 170 |

171 | 
172 | For our experiments, we used models hosted on the Hugging Face Hub that are pre-trained on 960 hours of **English** audio data from the LibriSpeech dataset, sampled at 16 kHz. The following pre-trained models were used: 
173 | - Wav2vec2 *Base*: [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) 
174 | - WavLM *Base*: [microsoft/wavlm-base](https://huggingface.co/microsoft/wavlm-base) 
175 | - WavLM *Large*: [microsoft/wavlm-large](https://huggingface.co/microsoft/wavlm-large) 
176 | - HuBERT *Large*: [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft) 
177 | 
178 | 
179 | 
180 | ## :family: Language Family 
181 | 
182 | 
183 | 
184 | The language family tree can be found in the following figure. This gives insight into the genetic proximity of each language. 
185 | 
186 | 

187 | 188 |

189 | 190 | | Language | Family | Proximity with English | 191 | |----------|--------|------------------------| 192 | | Italian :it: | *Romance* | 47.8 | 193 | | Russian :ru: | *East Slavic* | 60.3 | 194 | | Dutch 🇳🇱 | *West Germanic* | 27.2 | 195 | | Swedish 🇸🇪 | *North Germanic* | 26.7 | 196 | | Turkish :tr: | *Turkic* | 92.0 | 197 | 198 | 199 |
200 |

201 | Genetic proximity between the languages studied and English, computed [here](http://www.elinguistics.net/Compare_Languages.aspx). [1, 30]: Highly related languages, [30, 50]: Related languages, [50, 70]: Remotely related languages, [70, 78]: Very remotely related languages, [78, 100]: No recognizable relationship. 
202 | 

203 | 204 | **English** is a part of the *West Germanic* family.\ 205 | Source: https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md and http://www.elinguistics.net/Compare_Languages.aspx 206 | 207 | ## :chart_with_upwards_trend: Main results 208 | 209 | dataset: Common Voice Corpus 6.1 : https://commonvoice.mozilla.org/fr/datasets 210 | 211 | Pretrained English models to other languages 212 | 213 | ### 🚀 Fine-tuning 214 | 215 | 216 | 217 | 218 | 219 | | Language | Training data (in hours) | Model| PER validation | PER test | Runs| 220 | |-|-|-|-|-|-| 221 | | Italian :it: | 62\.34| Wav2Vec2 *Base* | 19\.05| 17\.95 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 222 | | | | Hubert *Large* | **14\.05** | **12\.67** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 223 | | | | WavLM *Base* | 19\.83 | 25\.60 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 224 | | Russian :ru: | 15\.55 | Wav2Vec2 *Base* | 32\.16 | 31\.66 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 225 | | | | Hubert *Large* | 25\.10 | 24\.09 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 226 | | | | WavLM *Base* | **20\.25** | **18\.88** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 227 | | Dutch 🇳🇱 | 12\.78 | Wav2Vec2 *Base* | 16\.18 | 20\.83 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 228 | | | | Hubert *Large* | **12\.77** | **16\.49** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 229 | | | | WavLM *Base* | 15\.96 | 19\.91 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 230 | | Swedish 🇸🇪 | 3\.22 | Wav2Vec2 *Base* | 26\.50 | 24\.16 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 231 | | | | Hubert *Large* | **21\.77** | **19\.38** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 232 | | | | WavLM *Base* | 26\.86 | 24\.61 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 233 | | Turkish :tr: | 2\.52 | Wav2Vec2 *Base* | 19\.62 | 19\.03 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 234 | | | | Hubert *Large* | **15\.51** | **14\.19** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 235 | | | | WavLM *Base* | 19\.85 | 18\.95 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 236 | | Average | \- | Wav2Vec2 *Base* | 22\.70 | 22\.73 | | 237 | | | | Hubert *Large* | **17\.84** | **17\.36** | | 238 | | | | WavLM *Base* | 20\.55 | 21\.59 | | 239 | 240 |

241 | Table of experiments when models are **fine-tuned**. Here, we compare 3 different pretrained models. The models were fine-tuned on the phoneme recognition task with different languages and a varying amount of training data. 
242 | 

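For reference, the metric reported in all of these tables is the phoneme error rate (PER): the edit (Levenshtein) distance between the predicted and reference phoneme sequences, divided by the reference length, in percent. The repository computes it as a metric in `utils/per.py`; the standalone sketch below only illustrates the standard formula and is not necessarily the exact implementation.

```
# Standard PER definition: edit distance / reference length, in percent.
def per(reference: list, prediction: list) -> float:
    m, n = len(reference), len(prediction)
    # dp[i][j] = edit distance between reference[:i] and prediction[:j]
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            cost = 0 if reference[i - 1] == prediction[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return 100.0 * dp[m][n] / max(m, 1)

# Example: one substitution and one deletion over 5 reference phonemes -> 40.0
print(per("h ɛ l oʊ z".split(), "h e l oʊ".split()))
```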
243 | 244 | ### 🧊 Frozen Features 245 | 246 | | Language | Training data (in hours) | Model | PER validation | PER test | Runs | 247 | |-|-|-|-|-|-| 248 | | Italian :it: | 62\.34 | Wav2Vec2 *Base* | 38\.94 | 36\.84 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 249 | | | | WavLM *Base* | **27\.29** | **25\.98** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 250 | | | | Hubert *Large* | 23\.85 | 21\.15 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 251 | | | | WavLM *Large* | **21\.02** | **18\.80** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 252 | | Russian :ru: | 15\.55 | Wav2Vec2 *Base* | 50\.11 | 48\.69 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 253 | | | | WavLM *Base* | **40\.66** | **38\.76** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 254 | | | | Hubert *Large* | 38\.36 | 36\.18 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 255 | | | | WavLM *Large* | **34\.48** | **32\.26** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 256 | | Dutch 🇳🇱| 12\.78 | Wav2Vec2 *Base* | 40\.15 | 39\.23 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 257 | | | | WavLM *Base* | **34\.94** | **35\.67** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 258 | | | | Hubert *Large* | **27\.62** | **26\.68** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 259 | | | | WavLM *Large* | 27\.71 | 27\.19 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 260 | | Swedish 🇸🇪 | 3\.22 | Wav2Vec2 *Base* | 50\.30 | 45\.23 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 261 | | | | WavLM *Base* | **43\.65** | **40\.55** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 262 | | | | Hubert *Large* | 37\.34 | **32\.68** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 263 | | | | WavLM *Large* | **37\.25** | 33\.14 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 264 | | Turkish :tr: | 2\.52 | Wav2Vec2 *Base* | 53\.92 | 52\.08 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 265 | | | | WavLM *Base* | **47\.18** | **45\.53** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 266 | | | | Hubert *Large* | 39\.55 | 37\.08 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 267 | | | | WavLM *Large* | **30\.66** | **30\.14** | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 268 | | Average | \- | Wav2Vec2 *Base* | 46\.68 | 44\.41 | | 269 | | | | WavLM *Base* | **38\.74** | **37\.30** | | 270 | | | | Hubert *Large* | 33\.34 | 30\.75 | | 271 | | | | WavLM *Large* | **30\.22** | **28\.31** | | 272 | 273 |

274 | Table of experiments using **frozen features**. Here, we compare 4 different pretrained models. The objective was to train a linear layer, using pretrained models' frozen features, on the phoneme recognition task with different languages and a varying amount of training data. 275 |

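In the frozen-features setting, the pretrained encoder is not updated: its parameters receive no gradients and are excluded from the optimizer, so only the linear head is trained. In this repository this behaviour is controlled by the `freeze` and `freeze_transformer` flags in `config/hparams.py`; the snippet below is a simplified illustration of the idea, not the exact training code (the vocabulary size of 50 is a placeholder, the learning rate and weight decay are the `OptimizerParams` defaults).

```
# Sketch of the frozen-features setup: train only the linear head.
from torch import nn
from torch.optim import AdamW
from transformers import AutoModel

backbone = AutoModel.from_pretrained("microsoft/wavlm-base")
head = nn.Linear(backbone.config.hidden_size, 50)

# Freeze the encoder so only the head's parameters are optimized.
backbone.requires_grad_(False)
optimizer = AdamW(head.parameters(), lr=2e-2, weight_decay=1e-8)
```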
276 | 277 | ### ⌚ Training data 278 | 279 | | Training set | Training data | Model | PER validation | PER test | Runs | 280 | |-|-|-|-|-|-| 281 | | 5% | \~ 10 min | Wav2Vec2 *Base* | 55\.35 | 50\.91 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 282 | | || Hubert *Large* | 44\.96 | 39\.38 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 283 | | || WavLM *Base* | 56\.22 | 51\.25 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 284 | | 10% | \~ 20 min | Wav2Vec2 *Base* | 52\.97 | 49\.01 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 285 | | || Hubert *Large* | 42\.61 | 37\.50 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 286 | | || WavLM *Base* | 46\.54 | 43\.64 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 287 | | 50% | \~ 2 h | Wav2Vec2 *Base* | 51\.23 | 46\.24 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 288 | | || Hubert *Large* | 39\.91 | 35\.27 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 289 | | || WavLM *Base* | 44\.57 | 42\.33 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 290 | | 100% | \~ 3 h | Wav2Vec2 *Base* | 50\.30 | 45\.23 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 291 | | || Hubert *Large* | 37\.34 | 32\.68 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 292 | | || WavLM *Base* | 43\.65 | 40\.55 | [![](https://github.com/wandb/assets/blob/main/wandb-github-badge-gradient.svg)]() | 293 | 294 | 295 |

296 | Effect of varying the amount of training data, using frozen features from models pre-trained with the 3 different methods. Language: Swedish 🇸🇪. 
297 | 

298 | 299 |

300 | 301 | 302 |

303 |

304 | PER on the test and validation sets vs. the amount of training data for the Swedish language, with frozen features. 
305 | 

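The training-data fractions above appear to correspond to PyTorch Lightning's `limit_train_batches` option, which is exposed as a command-line flag (see the Usage section and the run tags built in `main.py`). Assuming `main.py` is the entry point, a hypothetical invocation for the 5% Swedish run with a frozen WavLM backbone might look like:

```
python main.py --network_name WavLM --subset sv-SE --language sv --limit_train_batches 0.05
```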
306 | 
307 | ## :pushpin: Project structure 
308 | 
309 | ``` 
310 | ├── agents 
311 | | ├── BaseTrainer.py 
312 | | 
313 | ├── assets # datasets and phoneme vocabularies are put here 
314 | | 
315 | ├── config 
316 | | ├── hparams.py # configuration file 
317 | | 
318 | ├── Datasets 
319 | | | 
320 | | ├── datamodule.py # PyTorch Lightning datamodule for the CommonVoice dataset 
321 | | 
322 | ├── models 
323 | | ├── BaseModule.py # lightning module 
324 | | ├── models.py # Wav2vec2, WavLM and Hubert using the Hugging Face library 
325 | | 
326 | ├── utils # utility functions 
327 | | ├── agent_utils.py 
328 | | ├── callbacks.py 
329 | | ├── dataset_utils.py 
330 | | ├── logger.py 
331 | | ├── metrics.py 
332 | | ├── per.py # torchmetrics implementation of the phoneme error rate 
333 | | 
334 | ├── hparams.py # configuration file 
335 | | 
336 | ├── main.py # main script to launch for training or inference 
337 | | 
338 | └── README.md 
339 | ``` 
340 | 
341 | ## ⚡ Powered by 
342 | 
343 | 

344 | 345 | logo hugging face 346 | 347 | logo wandb 348 | 349 | logo pytorch lightning 350 |

351 | -------------------------------------------------------------------------------- /agents/BaseTrainer.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import wandb 4 | from models.BaseModule import BaseModule 5 | from pytorch_lightning.callbacks import ( 6 | LearningRateMonitor, 7 | RichProgressBar, 8 | EarlyStopping, 9 | ) 10 | from utils.agent_utils import get_artifact, get_datamodule 11 | from utils.callbacks import ( 12 | AutoSaveModelCheckpoint, 13 | LogMetricsCallback, 14 | LogAudioPrediction, 15 | ) 16 | from utils.logger import init_logger 17 | 18 | from utils.dataset_utils import create_vocabulary, create_vocabulary2 19 | 20 | 21 | class BaseTrainer: 22 | def __init__(self, config, run=None) -> None: 23 | self.config = config.hparams 24 | self.wb_run = run 25 | self.network_param = config.network_param 26 | 27 | self.logger = init_logger("BaseTrainer", "INFO") 28 | 29 | self.logger.info( 30 | f"Create vocabulary language : {config.data_param.language} ..." 31 | ) 32 | 33 | if config.data_param.subset == "en": 34 | ( 35 | config.network_param.vocab_file, 36 | config.network_param.len_vocab, 37 | ) = create_vocabulary( 38 | config.data_param.language, 39 | config.data_param.phoible_csv_path, 40 | eos_token=config.network_param.eos_token, 41 | bos_token=config.network_param.bos_token, 42 | unk_token=config.network_param.unk_token, 43 | pad_token=config.network_param.pad_token, 44 | word_delimiter_token=config.network_param.word_delimiter_token, 45 | ) 46 | else: 47 | ( 48 | config.network_param.vocab_file, 49 | config.network_param.len_vocab, 50 | ) = create_vocabulary2( 51 | config.data_param.language, 52 | config.data_param.root_path_annotation, 53 | eos_token=config.network_param.eos_token, 54 | bos_token=config.network_param.bos_token, 55 | unk_token=config.network_param.unk_token, 56 | pad_token=config.network_param.pad_token, 57 | word_delimiter_token=config.network_param.word_delimiter_token, 58 | ) 59 | 60 | self.logger.info(f"Vocabulary file : {config.network_param.vocab_file}") 61 | 62 | self.logger.info("Loading Data module...") 63 | self.datamodule = get_datamodule(config.data_param) 64 | 65 | self.logger.info("Loading Model module...") 66 | self.pl_model = BaseModule(config.network_param, config.optim_param) 67 | 68 | self.wb_run.watch(self.pl_model.model) 69 | 70 | def run(self): 71 | if self.config.tune_lr: 72 | tune_lr_trainer = pl.Trainer( 73 | logger=self.wb_run, 74 | gpus=self.config.gpu, 75 | auto_lr_find=True, 76 | accelerator="auto", 77 | default_root_dir=self.wb_run.save_dir, 78 | ) 79 | tune_lr_trainer.logger = self.wb_run 80 | 81 | if not self.config.debug: 82 | torch.autograd.set_detect_anomaly(False) 83 | torch.autograd.profiler.profile(False) 84 | torch.autograd.profiler.emit_nvtx(False) 85 | torch.backends.cudnn.benchmark = True 86 | 87 | trainer = pl.Trainer( 88 | logger=self.wb_run, # W&B integration 89 | callbacks=self.get_callbacks(), 90 | gpus=self.config.gpu, # use all available GPU's 91 | max_epochs=self.config.max_epochs, # number of epochs 92 | log_every_n_steps=1, 93 | fast_dev_run=self.config.dev_run, 94 | amp_backend="apex", 95 | enable_progress_bar=self.config.enable_progress_bar, 96 | val_check_interval=self.config.val_check_interval, 97 | limit_train_batches=self.config.limit_train_batches, 98 | limit_val_batches=self.config.limit_val_batches, 99 | accumulate_grad_batches=self.config.accumulate_grad_batches, 100 | ) 101 | 102 | trainer.logger = 
self.wb_run 103 | 104 | self.datamodule.load_data("train") 105 | self.datamodule.process_dataset("train", self.pl_model.processor) 106 | 107 | self.datamodule.load_data("val") 108 | self.datamodule.process_dataset("val", self.pl_model.processor) 109 | 110 | if self.config.tune_lr: 111 | tune_lr_trainer.tune(self.pl_model, datamodule=self.datamodule) 112 | 113 | trainer.fit(self.pl_model, datamodule=self.datamodule) 114 | 115 | @torch.no_grad() 116 | def predict(self): 117 | if not self.config.debug: 118 | torch.autograd.set_detect_anomaly(False) 119 | torch.autograd.profiler.profile(False) 120 | torch.autograd.profiler.emit_nvtx(False) 121 | torch.backends.cudnn.benchmark = True 122 | 123 | trainer = pl.Trainer( 124 | logger=self.wb_run, # W&B integration 125 | callbacks=self.get_callbacks(), 126 | gpus=self.config.gpu, # use all available GPU's 127 | log_every_n_steps=1, 128 | fast_dev_run=self.config.dev_run, 129 | amp_backend="apex", 130 | enable_progress_bar=self.config.enable_progress_bar, 131 | ) 132 | 133 | trainer.logger = self.wb_run 134 | 135 | self.datamodule.load_data("test") 136 | self.datamodule.process_dataset("test", self.pl_model.processor) 137 | 138 | path_model = f"{self.config.wandb_entity}/{self.config.wandb_project}/{self.config.best_model_run}:top-1" 139 | best_model_path = get_artifact(path_model, type="model") 140 | 141 | trainer.test(self.pl_model, self.datamodule, ckpt_path=best_model_path) 142 | 143 | return 144 | 145 | def get_callbacks(self): 146 | callbacks = [ 147 | LearningRateMonitor(), 148 | LogMetricsCallback(), 149 | LogAudioPrediction(self.config.log_freq_audio, self.config.log_nb_audio), 150 | ] 151 | 152 | if self.config.enable_progress_bar: 153 | callbacks += [RichProgressBar()] 154 | 155 | if self.config.early_stopping: 156 | callbacks += [EarlyStopping(**self.config.early_stopping_params)] 157 | 158 | monitor = "val/per" 159 | mode = "min" 160 | wandb.define_metric(monitor, summary=mode) 161 | save_top_k = 1 162 | every_n_epochs = 1 163 | callbacks += [ 164 | AutoSaveModelCheckpoint( # ModelCheckpoint 165 | config=(self.network_param).__dict__, 166 | project=self.config.wandb_project, 167 | entity=self.config.wandb_entity, 168 | monitor=monitor, 169 | mode=mode, 170 | filename="epoch-{epoch:02d}-val_per={val/per:.2f}", 171 | verbose=True, 172 | dirpath=self.config.weights_path + f"/{str(wandb.run.name)}", 173 | save_top_k=save_top_k, 174 | every_n_epochs=every_n_epochs, 175 | auto_insert_metric_name=False, 176 | ) 177 | ] # our model checkpoint callback 178 | 179 | return callbacks 180 | -------------------------------------------------------------------------------- /agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/agents/__init__.py -------------------------------------------------------------------------------- /assets/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/.keep -------------------------------------------------------------------------------- /assets/img_readme/ASR_language family.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/ASR_language family.png 
-------------------------------------------------------------------------------- /assets/img_readme/Network.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/Network.drawio.png -------------------------------------------------------------------------------- /assets/img_readme/PER test vs Training data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/PER test vs Training data.png -------------------------------------------------------------------------------- /assets/img_readme/PER validation vs Training data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/PER validation vs Training data.png -------------------------------------------------------------------------------- /assets/img_readme/hubert.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/hubert.jpeg -------------------------------------------------------------------------------- /assets/img_readme/parrot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/parrot.png -------------------------------------------------------------------------------- /assets/img_readme/wav2vec2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/wav2vec2.png -------------------------------------------------------------------------------- /assets/img_readme/wavlm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/wavlm.png -------------------------------------------------------------------------------- /conda_environment.yml: -------------------------------------------------------------------------------- 1 | name: speech-project 2 | channels: 3 | - pytorch 4 | - huggingface 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_llvm 10 | - abseil-cpp=20210324.2=h2531618_0 11 | - absl-py=0.15.0=pyhd3eb1b0_0 12 | - aiohttp=3.8.1=py38h7f8727e_0 13 | - aiosignal=1.2.0=pyhd3eb1b0_0 14 | - appdirs=1.4.4=pyh9f0ad1d_0 15 | - arrow-cpp=6.0.1=py38h9de81b1_5_cpu 16 | - async-timeout=4.0.1=pyhd3eb1b0_0 17 | - attrs=21.4.0=pyhd3eb1b0_0 18 | - audioread=2.1.9=py38h578d9bd_2 19 | - aws-c-auth=0.6.8=hadad3cd_1 20 | - aws-c-cal=0.5.12=h70efedd_7 21 | - aws-c-common=0.6.17=h7f98852_0 22 | - aws-c-compression=0.2.14=h7c7754b_7 23 | - aws-c-event-stream=0.2.7=hd2be095_32 24 | - aws-c-http=0.6.10=h416565a_3 25 | - aws-c-io=0.10.14=he836878_0 26 | - aws-c-mqtt=0.7.10=h885097b_0 27 | - aws-c-s3=0.1.29=h8d70ed6_0 28 | - aws-c-sdkutils=0.1.1=h7c7754b_4 29 | - aws-checksums=0.1.12=h7c7754b_6 30 | - aws-crt-cpp=0.17.10=h6ab17b9_5 31 | - 
aws-sdk-cpp=1.9.160=h36ff4c5_0 32 | - blas=1.0=mkl 33 | - blinker=1.4=py38h06a4308_0 34 | - bottleneck=1.3.2=py38heb32a55_1 35 | - brotli=1.0.9=h7f98852_6 36 | - brotli-bin=1.0.9=h7f98852_6 37 | - brotlipy=0.7.0=py38h27cfd23_1003 38 | - bzip2=1.0.8=h7b6447c_0 39 | - c-ares=1.18.1=h7f8727e_0 40 | - ca-certificates=2022.3.18=h06a4308_0 41 | - cachetools=4.2.2=pyhd3eb1b0_0 42 | - certifi=2021.10.8=py38h06a4308_2 43 | - cffi=1.15.0=py38hd667e15_1 44 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 45 | - click=7.1.2=pyhd3eb1b0_0 46 | - colorama=0.4.4=pyhd3eb1b0_0 47 | - commonmark=0.9.1=pyhd3eb1b0_0 48 | - configparser=5.0.2=pyhd3eb1b0_0 49 | - cryptography=36.0.0=py38h9ce1e76_0 50 | - cudatoolkit=10.2.89=hfd86e86_1 51 | - cycler=0.11.0=pyhd8ed1ab_0 52 | - dataclasses=0.8=pyh6d0b6a4_7 53 | - datasets=1.18.3=py_0 54 | - decorator=5.1.1=pyhd8ed1ab_0 55 | - dill=0.3.4=pyhd3eb1b0_0 56 | - docker-pycreds=0.4.0=pyhd3eb1b0_0 57 | - ffmpeg=4.2.2=h20bf706_0 58 | - filelock=3.4.2=pyhd3eb1b0_0 59 | - fonttools=4.29.1=py38h497a2fe_0 60 | - freetype=2.11.0=h70c0345_0 61 | - frozenlist=1.2.0=py38h7f8727e_0 62 | - fsspec=2022.1.0=pyhd3eb1b0_0 63 | - future=0.18.2=py38_1 64 | - gettext=0.21.0=hf68c758_0 65 | - gflags=2.2.2=he6710b0_0 66 | - giflib=5.2.1=h7b6447c_0 67 | - gitdb=4.0.7=pyhd3eb1b0_0 68 | - gitpython=3.1.18=pyhd3eb1b0_1 69 | - glog=0.5.0=h2531618_0 70 | - gmp=6.2.1=h2531618_2 71 | - gnutls=3.6.15=he1e5248_0 72 | - google-auth=1.33.0=pyhd3eb1b0_0 73 | - google-auth-oauthlib=0.4.1=py_2 74 | - grpc-cpp=1.42.0=ha1441d3_1 75 | - grpcio=1.42.0=py38hce63b2e_0 76 | - huggingface_hub=0.2.1=pyhd3eb1b0_0 77 | - icu=58.2=he6710b0_3 78 | - idna=3.3=pyhd3eb1b0_0 79 | - importlib-metadata=4.8.2=py38h06a4308_0 80 | - importlib_metadata=4.8.2=hd3eb1b0_0 81 | - intel-openmp=2021.4.0=h06a4308_3561 82 | - jbig=2.1=hdba287a_0 83 | - joblib=1.1.0=pyhd3eb1b0_0 84 | - jpeg=9d=h7f8727e_0 85 | - kiwisolver=1.3.2=py38h1fd1430_1 86 | - krb5=1.19.2=hac12032_0 87 | - lame=3.100=h7b6447c_0 88 | - lcms2=2.12=h3be6417_0 89 | - ld_impl_linux-64=2.35.1=h7274673_9 90 | - lerc=3.0=h295c915_0 91 | - libbrotlicommon=1.0.9=h7f98852_6 92 | - libbrotlidec=1.0.9=h7f98852_6 93 | - libbrotlienc=1.0.9=h7f98852_6 94 | - libcurl=7.80.0=h0b77cf5_0 95 | - libdeflate=1.8=h7f8727e_5 96 | - libedit=3.1.20210910=h7f8727e_0 97 | - libev=4.33=h7f8727e_1 98 | - libevent=2.1.12=h8f2d780_0 99 | - libffi=3.3=he6710b0_2 100 | - libflac=1.3.4=h27087fc_0 101 | - libgcc-ng=11.2.0=h1d223b6_12 102 | - libgfortran-ng=7.5.0=h14aa051_20 103 | - libgfortran4=7.5.0=h14aa051_20 104 | - libidn2=2.3.2=h7f8727e_0 105 | - libllvm11=11.1.0=hf817b99_3 106 | - libnghttp2=1.46.0=hce63b2e_0 107 | - libogg=1.3.5=h27cfd23_1 108 | - libopus=1.3.1=h7b6447c_0 109 | - libpng=1.6.37=hbc83047_0 110 | - libprotobuf=3.19.1=h4ff587b_0 111 | - librosa=0.9.1=pyhd8ed1ab_0 112 | - libsndfile=1.0.31=h9c3ff4c_1 113 | - libssh2=1.9.0=h1ba5d50_1 114 | - libstdcxx-ng=11.2.0=he4da1e4_12 115 | - libtasn1=4.16.0=h27cfd23_0 116 | - libthrift=0.15.0=hcc01f38_0 117 | - libtiff=4.3.0=h6f004c6_2 118 | - libunistring=0.9.10=h27cfd23_0 119 | - libutf8proc=2.6.1=h27cfd23_0 120 | - libuv=1.40.0=h7b6447c_0 121 | - libvorbis=1.3.7=h7b6447c_0 122 | - libvpx=1.7.0=h439df22_0 123 | - libwebp=1.2.2=h55f646e_0 124 | - libwebp-base=1.2.2=h7f8727e_0 125 | - libxml2=2.9.12=h03d6c58_0 126 | - libzlib=1.2.11=h36c2ea0_1013 127 | - llvm-openmp=13.0.1=hf817b99_0 128 | - llvmlite=0.38.0=py38h4630a5e_0 129 | - lz4-c=1.9.3=h295c915_1 130 | - markdown=3.3.4=py38h06a4308_0 131 | - matplotlib-base=3.5.1=py38hf4fb855_0 132 | - 
mkl=2021.4.0=h06a4308_640 133 | - mkl-service=2.4.0=py38h7f8727e_0 134 | - mkl_fft=1.3.1=py38hd3c417c_0 135 | - mkl_random=1.2.2=py38h51133e4_0 136 | - multidict=5.2.0=py38h7f8727e_2 137 | - multiprocess=0.70.12.2=py38h7f8727e_0 138 | - munkres=1.1.4=pyh9f0ad1d_0 139 | - mypy_extensions=0.4.3=py38h06a4308_1 140 | - ncurses=6.3=h7f8727e_2 141 | - nettle=3.7.3=hbbd107a_1 142 | - numba=0.55.1=py38h4bf6c61_0 143 | - numexpr=2.8.1=py38h6abb31d_0 144 | - numpy=1.21.2=py38h20f2e39_0 145 | - numpy-base=1.21.2=py38h79a1101_0 146 | - oauthlib=3.2.0=pyhd8ed1ab_0 147 | - olefile=0.46=pyhd3eb1b0_0 148 | - openh264=2.1.1=h4ff587b_0 149 | - openssl=1.1.1n=h7f8727e_0 150 | - orc=1.7.1=h1be678f_1 151 | - packaging=21.3=pyhd3eb1b0_0 152 | - pandas=1.4.1=py38h295c915_0 153 | - parquet-cpp=1.5.1=h34088ae_4 154 | - pathtools=0.1.2=pyhd3eb1b0_1 155 | - pillow=8.4.0=py38h5aabda8_0 156 | - pip=21.2.4=py38h06a4308_0 157 | - pooch=1.6.0=pyhd8ed1ab_0 158 | - promise=2.3=py38h06a4308_0 159 | - protobuf=3.19.1=py38h295c915_0 160 | - psutil=5.8.0=py38h27cfd23_1 161 | - pyarrow=6.0.1=py38he7e5f7d_5_cpu 162 | - pyasn1=0.4.8=pyhd3eb1b0_0 163 | - pyasn1-modules=0.2.8=py_0 164 | - pycparser=2.21=pyhd3eb1b0_0 165 | - pydeprecate=0.3.2=pyhd8ed1ab_0 166 | - pygments=2.11.2=pyhd3eb1b0_0 167 | - pyjwt=2.3.0=pyhd8ed1ab_1 168 | - pyopenssl=22.0.0=pyhd3eb1b0_0 169 | - pyparsing=3.0.4=pyhd3eb1b0_0 170 | - pysocks=1.7.1=py38h06a4308_0 171 | - pysoundfile=0.10.3.post1=pyhd3deb0d_0 172 | - python=3.8.12=h12debd9_0 173 | - python-dateutil=2.8.2=pyhd3eb1b0_0 174 | - python-xxhash=2.0.2=py38h7f8727e_0 175 | - python_abi=3.8=1_cp38 176 | - pytorch=1.10.2=py3.8_cuda10.2_cudnn7.6.5_0 177 | - pytorch-lightning=1.5.10=pyhd8ed1ab_0 178 | - pytorch-mutex=1.0=cuda 179 | - pytz=2021.3=pyhd3eb1b0_0 180 | - pyyaml=6.0=py38h7f8727e_1 181 | - re2=2021.11.01=h9c3ff4c_0 182 | - readline=8.1.2=h7f8727e_1 183 | - regex=2021.11.2=py38h7f8727e_0 184 | - requests=2.27.1=pyhd3eb1b0_0 185 | - requests-oauthlib=1.3.0=py_0 186 | - resampy=0.2.2=py_0 187 | - rich=11.2.0=pyhd8ed1ab_0 188 | - rsa=4.7.2=pyhd3eb1b0_1 189 | - s2n=1.3.0=h9b69904_0 190 | - sacremoses=0.0.43=pyhd3eb1b0_0 191 | - scikit-learn=1.0.2=py38h51133e4_1 192 | - scipy=1.7.3=py38hc147768_0 193 | - sentry-sdk=1.5.6=pyhd8ed1ab_0 194 | - setuptools=58.0.4=py38h06a4308_0 195 | - shortuuid=1.0.8=py38h578d9bd_0 196 | - simple-parsing=0.0.18=pyhd8ed1ab_0 197 | - six=1.16.0=pyhd3eb1b0_1 198 | - smmap=4.0.0=pyhd3eb1b0_0 199 | - snappy=1.1.8=he6710b0_0 200 | - sqlite=3.37.2=hc218d9a_0 201 | - subprocess32=3.5.4=py_1 202 | - tensorboard=2.6.0=py_1 203 | - tensorboard-data-server=0.6.0=py38hca6d32c_0 204 | - tensorboard-plugin-wit=1.6.0=py_0 205 | - termcolor=1.1.0=py38h06a4308_1 206 | - threadpoolctl=3.1.0=pyh8a188c0_0 207 | - tk=8.6.11=h1ccaba5_0 208 | - tokenizers=0.10.3=py38hb317417_1 209 | - torchaudio=0.10.2=py38_cu102 210 | - torchmetrics=0.7.2=pyhd8ed1ab_0 211 | - torchvision=0.11.3=py38_cu102 212 | - tqdm=4.62.3=pyhd3eb1b0_1 213 | - typing-extensions=3.10.0.2=hd3eb1b0_0 214 | - typing_extensions=3.10.0.2=pyh06a4308_0 215 | - typing_inspect=0.7.1=pyhd3eb1b0_0 216 | - unicodedata2=14.0.0=py38h497a2fe_0 217 | - urllib3=1.26.8=pyhd3eb1b0_0 218 | - wandb=0.12.10=pyhd8ed1ab_0 219 | - werkzeug=2.0.3=pyhd3eb1b0_0 220 | - wheel=0.37.1=pyhd3eb1b0_0 221 | - x264=1!157.20191217=h7b6447c_0 222 | - xxhash=0.8.0=h7f8727e_3 223 | - xz=5.2.5=h7b6447c_0 224 | - yaml=0.2.5=h7b6447c_0 225 | - yarl=1.6.3=py38h27cfd23_0 226 | - yaspin=2.1.0=pyhd8ed1ab_0 227 | - zipp=3.7.0=pyhd3eb1b0_0 228 | - zlib=1.2.11=h36c2ea0_1013 229 
| - zstd=1.5.0=ha4553b6_1 230 | - pip: 231 | - clldutils==3.10.1 232 | - colorlog==6.6.0 233 | - csvw==2.0.0 234 | - dlinfo==1.2.1 235 | - isodate==0.6.1 236 | - phonemizer==3.0.1 237 | - python-graphviz==0.19.1 238 | - rfc3986==1.5.0 239 | - segments==2.2.0 240 | - tabulate==0.8.9 241 | - torchviz==0.0.2 242 | - transformers==4.16.0 243 | - uritemplate==4.1.1 244 | - wget==3.2 245 | prefix: /home/clement/anaconda3/envs/speech-project 246 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/config/__init__.py -------------------------------------------------------------------------------- /config/hparams.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import random 4 | from dataclasses import dataclass 5 | from os import path as osp 6 | from typing import Any, ClassVar, Dict, List, Optional 7 | from simple_parsing.helpers import Serializable, choice, dict_field, list_field 8 | 9 | import pytorch_lightning as pl 10 | import simple_parsing 11 | import torch 12 | import torch.optim 13 | 14 | ################################## Global parameters ################################## 15 | 16 | 17 | @dataclass 18 | class Hparams: 19 | """Hyperparameters for the run""" 20 | 21 | # wandb 22 | wandb_entity: str = "asr-project" # wandb entity name 23 | debug: bool = ( 24 | False # test code before running, if testing, no checkpoints are written 25 | ) 26 | test: bool = True 27 | wandb_project: str = f"{'test-'*test}asr" 28 | root_dir: str = os.getcwd() # root_dir 29 | 30 | # basic params 31 | seed_everything: Optional[int] = None # seed for the whole run 32 | gpu: int = 1 # number of gpus 33 | max_epochs: int = 100 # maximum number of epochs 34 | weights_path: str = osp.join(os.getcwd(), "weights") 35 | 36 | # modes 37 | tune_lr: bool = False # tune the model on first run 38 | dev_run: bool = False 39 | train: bool = True 40 | 41 | best_model: str = "" 42 | 43 | log_freq_audio: int = 10 44 | log_nb_audio: int = 2 45 | 46 | # trainer params 47 | val_check_interval: float = 1.0 # 1.0 (at the end of the epoch) 48 | limit_train_batches: float = 1.0 # 1.0 49 | limit_val_batches: float = 1.0 # 1.0 50 | enable_progress_bar: bool = True 51 | 52 | # testing params 53 | best_model_run: str = "WavLM_sv" 54 | 55 | # Early Stopping 56 | early_stopping: bool = True 57 | early_stopping_params: Dict[str, Any] = dict_field( 58 | dict(monitor="val/per", patience=10, mode="min", verbose=True) 59 | ) 60 | 61 | 62 | @dataclass 63 | class NetworkParams: 64 | network_name: str = "WavLM" # Hubert, Wav2Vec2, WavLM 65 | pretrained_name: Optional[str] = "" 66 | 67 | freeze: bool = True 68 | freeze_transformer: bool = True 69 | 70 | # Phoneme Tokenizer 71 | eos_token: str = "</s>"
72 | bos_token: str = "<s>" 73 | unk_token: str = "<unk>" 74 | pad_token: str = "<pad>" 75 | word_delimiter_token: str = "|" 76 | 77 | 78 | @dataclass 79 | class DatasetParams: 80 | """Dataset Parameters 81 | ! The batch_size and number of crops should be defined here 82 | """ 83 | 84 | # Hugging Face datasets parameters 85 | dataset_name: str = "common_voice" # https://huggingface.co/mozilla-foundation or https://huggingface.co/datasets/common_voice # dataset, use Eval for FT 86 | use_auth_token: bool = False # True if use mozilla-foundation datasets 87 | subset: str = ( 88 | "sv-SE" # chosen language (see https://huggingface.co/datasets/common_voice) 89 | ) 90 | download_mode: str = "reuse_dataset_if_exists" 91 | cache_dir: str = osp.join(os.getcwd(), "assets") 92 | 93 | # to create vocabulary of phonemes 94 | language: str = "sv" 95 | root_path_annotation: str = osp.join(os.getcwd(), "assets", "common_voices_splits") 96 | phoible_csv_path: str = osp.join(os.getcwd(), "assets") 97 | 98 | # Dataloader parameters 99 | num_workers: int = 20 # number of workers for dataloaders 100 | batch_size: int = 2 101 | 102 | # Dataset processing parameters 103 | max_input_length_in_sec: float = 5 104 | num_proc: int = 4 105 | 106 | create_dataset: bool = False 107 | 108 | 109 | @dataclass 110 | class OptimizerParams: 111 | """Optimization parameters""" 112 | 113 | optimizer: str = "AdamW" 114 | lr: float = 2e-2 115 | weight_decay: float = 1e-8 116 | 117 | accumulate_grad_batches: int = 16 # 1 for no accumulation 118 | 119 | # Scheduler parameters 120 | scheduler: Optional[ 121 | str 122 | ] = None # Cosine, ReduceLROnPlateau, MultiStepLR, StepLR or None 123 | 124 | # Cosine scheduler 125 | max_epochs: int = 10 126 | warmup_epochs: int = 1 127 | warmup_start_lr: float = 6e-4 128 | eta_min: float = 5e-6 129 | 130 | # Step LR scheduler 131 | step_size: int = 2 132 | gamma: float = 0.1 # also for multi step lr 133 | 134 | # MultiStepLR scheduler 135 | milestones: List[Any] = list_field(8, 10, 15) 136 | 137 | # ReduceLROnPlateau scheduler 138 | min_lr: float = 5e-9 139 | patience: int = 10 140 | 141 | 142 | @dataclass 143 | class Parameters: 144 | """base options.""" 145 | 146 | hparams: Hparams = Hparams() 147 | data_param: DatasetParams = DatasetParams() 148 | network_param: NetworkParams = NetworkParams() 149 | optim_param: OptimizerParams = OptimizerParams() 150 | 151 | def __post_init__(self): 152 | """Post-initialization code""" 153 | if self.hparams.seed_everything is None: 154 | self.hparams.seed_everything = random.randint(1, 10000) 155 | 156 | self.hparams.wandb_project = f"{'test-'*self.hparams.test}asr" 157 | 158 | self.network_param.phonemizer_lang = self.data_param.language 159 | print(f"Phonemizer language : {self.network_param.phonemizer_lang}") 160 | 161 | random.seed(self.hparams.seed_everything) 162 | torch.manual_seed(self.hparams.seed_everything) 163 | pl.seed_everything(self.hparams.seed_everything) 164 | 165 | if self.network_param.pretrained_name == "": 166 | if self.network_param.network_name == "Wav2Vec2": 167 | # self.network_param.pretrained_name = "facebook/wav2vec2-xlsr-53-espeak-cv-ft" 168 | self.network_param.pretrained_name = "facebook/wav2vec2-base-960h" 169 | elif self.network_param.network_name == "WavLM": 170 | # self.network_param.pretrained_name = "microsoft/wavlm-base" 171 | self.network_param.pretrained_name = "microsoft/wavlm-large" 172 | elif self.network_param.network_name == "Hubert": 173 | self.network_param.pretrained_name = "facebook/hubert-large-ls960-ft" 174 | else:
175 | raise NotImplementedError( 176 | "Only Wav2Vec2, WavLM and Hubert are available !" 177 | ) 178 | print(f"Pretrained model: {self.network_param.pretrained_name}") 179 | 180 | self.data_param.wandb_project = self.hparams.wandb_project 181 | self.hparams.accumulate_grad_batches = self.optim_param.accumulate_grad_batches 182 | 183 | @classmethod 184 | def parse(cls): 185 | parser = simple_parsing.ArgumentParser() 186 | parser.add_arguments(cls, dest="parameters") 187 | args = parser.parse_args() 188 | instance: Parameters = args.parameters 189 | return instance 190 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import faulthandler 2 | from pytest import param 3 | 4 | faulthandler.enable() 5 | 6 | from pytorch_lightning.loggers import WandbLogger 7 | 8 | # Standard libraries 9 | import wandb 10 | from agents.BaseTrainer import BaseTrainer 11 | from config.hparams import Parameters 12 | from utils.agent_utils import parse_params 13 | 14 | 15 | def main(): 16 | parameters = Parameters.parse() 17 | 18 | # initialize wandb instance 19 | wdb_config = parse_params(parameters) 20 | 21 | if parameters.hparams.train: 22 | tags = [ 23 | parameters.data_param.dataset_name, 24 | parameters.data_param.subset, 25 | parameters.optim_param.optimizer, 26 | parameters.network_param.network_name, 27 | f"{'not'*(not parameters.network_param.freeze)} freezed", 28 | parameters.network_param.pretrained_name, 29 | ] 30 | 31 | if parameters.hparams.limit_train_batches != 1.0: 32 | tags += [f"{parameters.hparams.limit_train_batches}_train"] 33 | if parameters.network_param.freeze_transformer: 34 | tags += ["transformer_freezed"] 35 | 36 | wandb.init( 37 | name=f"{parameters.network_param.network_name}_{parameters.data_param.language}{'_CNN_not_freezed'*(not parameters.network_param.freeze)}{f'_{parameters.hparams.limit_train_batches}_train'*(parameters.hparams.limit_train_batches!=1.0)}{'_tf_freezed'*(parameters.network_param.freeze_transformer)}", 38 | config=wdb_config, 39 | project=parameters.hparams.wandb_project, 40 | entity=parameters.hparams.wandb_entity, 41 | allow_val_change=True, 42 | job_type="train", 43 | tags=tags, 44 | ) 45 | 46 | wandb_run = WandbLogger( 47 | config=wdb_config, 48 | project=parameters.hparams.wandb_project, 49 | entity=parameters.hparams.wandb_entity, 50 | allow_val_change=True, 51 | ) 52 | 53 | agent = BaseTrainer(parameters, wandb_run) 54 | agent.run() 55 | else: 56 | tags = [ 57 | parameters.data_param.dataset_name, 58 | parameters.data_param.subset, 59 | parameters.data_param.language, 60 | parameters.network_param.network_name, 61 | f"{'not'*(not parameters.network_param.freeze)} freezed", 62 | parameters.network_param.pretrained_name, 63 | "test", 64 | ] 65 | if parameters.hparams.limit_train_batches != 1.0: 66 | tags += [f"{parameters.hparams.limit_train_batches}_train"] 67 | if parameters.network_param.freeze_transformer: 68 | tags += ["transformer_freezed"] 69 | 70 | wandb_run = wandb.init( 71 | name=parameters.hparams.best_model_run + "_test", 72 | config=wdb_config, 73 | project=parameters.hparams.wandb_project, 74 | entity=parameters.hparams.wandb_entity, 75 | allow_val_change=True, 76 | job_type="test", 77 | tags=tags, 78 | ) 79 | 80 | wandb_logger = WandbLogger( 81 | config=wdb_config, 82 | project=parameters.hparams.wandb_project, 83 | entity=parameters.hparams.wandb_entity, 84 | allow_val_change=True, 85 | ) 86 | 87 | agent = 
BaseTrainer(parameters, wandb_logger) 88 | agent.predict() 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /models/BaseModule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch_lightning import LightningModule 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR 6 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR 7 | from transformers import ( 8 | Wav2Vec2PhonemeCTCTokenizer, 9 | Wav2Vec2Processor, 10 | Wav2Vec2FeatureExtractor, 11 | ) 12 | from utils.agent_utils import get_model 13 | from utils.logger import init_logger 14 | 15 | from itertools import chain 16 | 17 | 18 | from torch.profiler import profile, record_function, ProfilerActivity 19 | 20 | 21 | class BaseModule(LightningModule): 22 | def __init__(self, network_param, optim_param): 23 | """ 24 | method used to define our model parameters 25 | """ 26 | super(BaseModule, self).__init__() 27 | 28 | logger = init_logger("BaseModule", "INFO") 29 | 30 | # Optimizer 31 | self.optim_param = optim_param 32 | self.lr = optim_param.lr 33 | 34 | logger.info(f"Optimizer : {optim_param.optimizer}, lr : {optim_param.lr}") 35 | 36 | # Tokenizer 37 | # https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py 38 | self.phonemes_tokenizer = Wav2Vec2PhonemeCTCTokenizer( 39 | vocab_file=network_param.vocab_file, 40 | eos_token=network_param.eos_token, 41 | bos_token=network_param.bos_token, 42 | unk_token=network_param.unk_token, 43 | pad_token=network_param.pad_token, 44 | word_delimiter_token=network_param.word_delimiter_token, 45 | do_phonemize=False, 46 | return_attention_mask=False, 47 | ) 48 | 49 | network_param.vocab_size = self.phonemes_tokenizer.vocab_size 50 | 51 | # Loss function 52 | self.loss = nn.CTCLoss( 53 | blank=self.phonemes_tokenizer.encoder[network_param.word_delimiter_token] 54 | ) 55 | 56 | # Feature_extractor 57 | feature_extractor = Wav2Vec2FeatureExtractor( 58 | feature_size=1, 59 | sampling_rate=16000, 60 | padding_value=0.0, 61 | do_normalize=True, 62 | return_attention_mask=False, 63 | ) 64 | 65 | logger.info(f"Features extractor : {network_param.network_name}") 66 | self.processor = Wav2Vec2Processor( 67 | feature_extractor=feature_extractor, tokenizer=self.phonemes_tokenizer 68 | ) 69 | 70 | # Model 71 | self.model = get_model(network_param.network_name, network_param) 72 | logger.info(f"Model: {network_param.network_name}") 73 | 74 | if network_param.freeze: 75 | self.model.model.freeze_feature_extractor() 76 | 77 | logger.info(f"Feature extactor:{' not'*(not network_param.freeze)} freezed") 78 | 79 | if network_param.freeze_transformer: 80 | self.model.model.requires_grad_(False) 81 | self.model.model.lm_head.requires_grad_(True) 82 | 83 | def forward(self, x): 84 | output = self.model(x) 85 | return output 86 | 87 | def training_step(self, batch, batch_idx): 88 | """needs to return a loss from a single batch""" 89 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx) 90 | if loss != loss: 91 | print("loss is nan, model collapse, exiting") 92 | exit(1) 93 | # Log loss 94 | self.log("train/loss", loss, batch_size=len(preds)) 95 | 96 | return { 97 | "loss": loss, 98 | "logits": logits.detach(), 99 | "preds": preds, 100 | "targets": targets, 101 | } 102 | 103 | 
def validation_step(self, batch, batch_idx): 104 | """used for logging metrics""" 105 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx) 106 | 107 | # Log loss 108 | self.log("val/loss", loss) 109 | 110 | return {"loss": loss, "logits": logits, "preds": preds, "targets": targets} 111 | 112 | def test_step(self, batch, batch_idx): 113 | """used for logging metrics""" 114 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx) 115 | 116 | # Log loss 117 | self.log("test/loss", loss) 118 | 119 | return {"loss": loss, "logits": logits, "preds": preds, "targets": targets} 120 | 121 | def configure_optimizers(self): 122 | """defines model optimizer""" 123 | optimizer = getattr(torch.optim, self.optim_param.optimizer) 124 | optimizer = optimizer( 125 | self.parameters(), lr=self.lr, weight_decay=self.optim_param.weight_decay 126 | ) 127 | 128 | if self.optim_param.scheduler != None: 129 | if self.optim_param.scheduler == "Cosine": 130 | scheduler = LinearWarmupCosineAnnealingLR( 131 | optimizer, 132 | warmup_epochs=self.optim_param.warmup_epochs, 133 | max_epochs=self.optim_param.max_epochs, 134 | warmup_start_lr=self.optim_param.warmup_start_lr, 135 | eta_min=self.optim_param.eta_min, 136 | ) 137 | elif self.optim_param.scheduler == "StepLR": 138 | scheduler = StepLR( 139 | optimizer, 140 | step_size=self.optim_param.step_size, 141 | gamma=self.optim_param.gamma, 142 | ) 143 | elif self.optim_param.scheduler == "MultiStepLR": 144 | scheduler = MultiStepLR( 145 | optimizer, 146 | milestones=self.optim_param.milestones, 147 | gamma=self.optim_param.gamma, 148 | ) 149 | else: 150 | scheduler = { 151 | "scheduler": ReduceLROnPlateau( 152 | optimizer, 153 | mode="min", 154 | patience=self.optim_param.patience, 155 | min_lr=self.optim_param.min_lr, 156 | ), 157 | "monitor": "val/loss", 158 | } 159 | 160 | return [[optimizer], [scheduler]] 161 | 162 | return optimizer 163 | 164 | def _get_outputs(self, batch, batch_idx): 165 | """convenience function since train/valid/test steps are similar""" 166 | x = batch 167 | 168 | # x['array'] gives the actual raw audio 169 | output = self(x["array"]).logits 170 | 171 | # process outputs 172 | log_probs = F.log_softmax(output, dim=-1) 173 | input_lengths = torch.LongTensor([len(b) for b in log_probs]) 174 | log_probs = log_probs.permute(1, 0, 2) 175 | 176 | # process targets 177 | # extract the indices from the dictionary 178 | with self.processor.as_target_processor(): 179 | # tokenizattion but no phonemization 180 | x["labels"] = self.processor(x["phonemes"]).input_ids 181 | 182 | target_lengths = torch.LongTensor([len(targ) for targ in x["labels"]]) 183 | targets = torch.Tensor(list(chain.from_iterable(x["labels"]))).int() 184 | 185 | loss = self.loss(log_probs, targets, input_lengths, target_lengths) 186 | 187 | # to compute metric and log samples 188 | phone_preds = self.processor.batch_decode(torch.argmax(output, dim=-1)) 189 | phone_targets = self.processor.batch_decode(x["labels"]) 190 | 191 | return loss, output, phone_preds, phone_targets 192 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/models/__init__.py -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 
1 | import torch.nn as nn 2 | from transformers import HubertForCTC, Wav2Vec2ForCTC, WavLMForCTC 3 | 4 | 5 | class BaseModel(nn.Module): 6 | """ 7 | BaseFeaturesExtractor class that will extract features according to the type of model 8 | https://huggingface.co/blog/fine-tune-wav2vec2-english 9 | """ 10 | 11 | def __init__(self, params): 12 | super().__init__() 13 | self.params = params 14 | 15 | def forward(self, x): 16 | outputs = self.model(x) 17 | return outputs 18 | 19 | 20 | class Wav2Vec2(BaseModel): 21 | """ 22 | https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/wav2vec2#transformers.Wav2Vec2ForCTC 23 | """ 24 | 25 | def __init__(self, params): 26 | super().__init__(params) 27 | 28 | self.model = Wav2Vec2ForCTC.from_pretrained(params.pretrained_name) 29 | in_features = self.model.lm_head.in_features 30 | self.model.lm_head = nn.Linear( 31 | in_features=in_features, out_features=self.params.vocab_size 32 | ) 33 | 34 | 35 | class WavLM(BaseModel): 36 | """ 37 | https://huggingface.co/docs/transformers/model_doc/wavlm#transformers.WavLMForCTC 38 | """ 39 | 40 | def __init__(self, params): 41 | super().__init__(params) 42 | self.model = WavLMForCTC.from_pretrained(params.pretrained_name) 43 | in_features = self.model.lm_head.in_features 44 | self.model.lm_head = nn.Linear( 45 | in_features=in_features, out_features=self.params.vocab_size 46 | ) 47 | 48 | 49 | class Hubert(BaseModel): 50 | """ 51 | https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/hubert#transformers.HubertForCTC 52 | """ 53 | 54 | def __init__(self, params): 55 | super().__init__(params) 56 | self.model = HubertForCTC.from_pretrained(params.pretrained_name) 57 | in_features = self.model.lm_head.in_features 58 | self.model.lm_head = nn.Linear( 59 | in_features=in_features, out_features=self.params.vocab_size 60 | ) 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb 2 | pytest 3 | transformers 4 | datasets 5 | soundfile 6 | simple-parsing 7 | torch==1.10.0 8 | pytorch-lightning 9 | torchaudio 10 | phonemizer 11 | rich 12 | librosa 13 | wget 14 | lightning-bolts -------------------------------------------------------------------------------- /requirements_cuda11-3.txt: -------------------------------------------------------------------------------- 1 | datasets==1.18.3 2 | numpy==1.21.2 3 | pandas==1.4.1 4 | phonemizer==3.0.1 5 | pytorch_lightning==1.5.10 6 | rich==11.2.0 7 | simple_parsing==0.0.18 8 | -f https://download.pytorch.org/whl/cu113/torch_stable.html 9 | torch==1.10.2+cu113 10 | tqdm==4.62.3 11 | transformers==4.16.0 12 | wandb 13 | pytest 14 | soundfile 15 | torchaudio 16 | librosa 17 | wget -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | # Testing script to reproduce our test experiments 2 | 3 | python main.py --train False --language nl --subset nl --network_name Hubert --best_model_run Hubert_nl_tf_freezed 4 | python main.py --train False --language nl --subset nl --network_name WavLM --best_model_run WavLM_nl_tf_freezed 5 | python main.py --train False --language nl --subset nl --network_name Wav2Vec2 --best_model_run Wav2Vec2_nl_tf_freezed 6 | 7 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_tf_freezed 8 | python main.py --train False --language sv 
--subset sv-SE --network_name WavLM --best_model_run WavLM_sv_tf_freezed 9 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_tf_freezed 10 | 11 | python main.py --train False --language it --subset it --network_name Hubert --best_model_run Hubert_it_tf_freezed 12 | python main.py --train False --language it --subset it --network_name WavLM --best_model_run WavLM_it_tf_freezed 13 | python main.py --train False --language it --subset it --network_name Wav2Vec2 --best_model_run Wav2Vec2_it_tf_freezed 14 | 15 | python main.py --train False --language ru --subset ru --network_name Hubert --best_model_run Hubert_ru_tf_freezed 16 | python main.py --train False --language ru --subset ru --network_name WavLM --best_model_run WavLM_ru_tf_freezed 17 | python main.py --train False --language ru --subset ru --network_name Wav2Vec2 --best_model_run Wav2Vec2_ru_tf_freezed 18 | 19 | python main.py --train False --language tr --subset tr --network_name Hubert --best_model_run Hubert_tr_tf_freezed 20 | python main.py --train False --language tr --subset tr --network_name WavLM --best_model_run WavLM_tr_tf_freezed 21 | python main.py --train False --language tr --subset tr --network_name Wav2Vec2 --best_model_run Wav2Vec2_tr_tf_freezed 22 | 23 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.05_train_tf_freezed --limit_train_batches 0.05 24 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.05_train_tf_freezed --limit_train_batches 0.05 25 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.05_train_tf_freezed --limit_train_batches 0.05 26 | 27 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.1_train_tf_freezed --limit_train_batches 0.1 28 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.1_train_tf_freezed --limit_train_batches 0.1 29 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.1_train_tf_freezed --limit_train_batches 0.1 30 | 31 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.5_train_tf_freezed --limit_train_batches 0.5 32 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.5_train_tf_freezed --limit_train_batches 0.5 33 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.5_train_tf_freezed --limit_train_batches 0.5 -------------------------------------------------------------------------------- /train_notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "id": "kyYQaEdwuQ1_" 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "!git clone https://github.com/ASR-project/Multilingual-PR.git\n", 12 | "%cd Multilingual-PR\n", 13 | "!pip install -q -r requirements.txt\n", 14 | "!pip install pytorch_lightning==1.5.10\n", 15 | "!apt-get install espeak -y" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "id": "6OwxeTlzwXAV" 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "import wandb\n", 27 | "!wandb login" 28 
| ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "id": "qLg-sGwEFSoL" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "!python main.py --gpu 1 --num_workers 2 --language ru --subset ru --network_name WavLM --train True --num_proc 1 --enable_progress_bar False --lr 2e-2" 39 | ] 40 | } 41 | ], 42 | "metadata": { 43 | "accelerator": "GPU", 44 | "colab": { 45 | "background_execution": "on", 46 | "collapsed_sections": [], 47 | "machine_shape": "hm", 48 | "name": "train_notebook.ipynb", 49 | "provenance": [] 50 | }, 51 | "kernelspec": { 52 | "display_name": "Python 3", 53 | "name": "python3" 54 | }, 55 | "language_info": { 56 | "name": "python" 57 | } 58 | }, 59 | "nbformat": 4, 60 | "nbformat_minor": 0 61 | } 62 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/utils/__init__.py -------------------------------------------------------------------------------- /utils/agent_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import errno 4 | 5 | import wandb 6 | from config.hparams import Parameters 7 | from Datasets.datamodule import BaseDataModule 8 | from rich.progress import ( 9 | BarColumn, 10 | Progress, 11 | SpinnerColumn, 12 | TextColumn, 13 | TimeElapsedColumn, 14 | TimeRemainingColumn, 15 | ) 16 | 17 | 18 | def get_net(network_name, network_param): 19 | """ 20 | Get Network Architecture based on arguments provided 21 | """ 22 | 23 | mod = importlib.import_module(f"models.{network_name}") 24 | net = getattr(mod, network_name) 25 | return net(network_param) 26 | 27 | 28 | def get_artifact(name: str, type: str) -> str: 29 | """Artifact utilities 30 | Extracts the artifact from the name by downloading it locally> 31 | Return : str = path to the artifact 32 | """ 33 | if name != "" and name is not None: 34 | artifact = wandb.run.use_artifact(name, type=type) 35 | artifact_dir = artifact.download() 36 | file_path = os.path.join(artifact_dir, os.listdir(artifact_dir)[0]) 37 | return file_path 38 | else: 39 | return None 40 | 41 | 42 | def get_datamodule(data_param): 43 | """ 44 | Fetch Datamodule Function Pointer 45 | """ 46 | return BaseDataModule(data_param) 47 | 48 | 49 | def get_model(model_name, params): 50 | """ 51 | get features extractors 52 | """ 53 | try: 54 | mod = importlib.import_module(f"models.models") 55 | net = getattr(mod, model_name) 56 | return net(params) 57 | except NotImplementedError: 58 | raise NotImplementedError(f"Not implemented only Wav2vec, WavLM and Hubert") 59 | 60 | 61 | def parse_params(parameters: Parameters) -> dict: 62 | wdb_config = {} 63 | for k, v in vars(parameters).items(): 64 | for key, value in vars(v).items(): 65 | wdb_config[f"{k}-{key}"] = value 66 | return wdb_config 67 | 68 | 69 | def get_progress_bar(): 70 | return Progress( 71 | "[progress.description]{task.description}", 72 | SpinnerColumn(), 73 | BarColumn(), 74 | "[progress.percentage]{task.percentage:>3.0f}%", 75 | TextColumn("[bold blue]{task.fields[info]}", justify="right"), 76 | TimeElapsedColumn(), 77 | TimeRemainingColumn(), 78 | "\n", 79 | ) 80 | 81 | 82 | def create_directory(dir_path): 83 | try: 84 | os.makedirs(dir_path) 85 | except OSError as e: 86 | print(e) 87 | if e.errno != errno.EEXIST: 88 | raise 89 | 
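For reference, parse_params above flattens the two-level Parameters object into a single dict keyed as "<section>-<field>", which is the config handed to wandb.init in main.py. A minimal, self-contained sketch of that behaviour; the SimpleNamespace stand-ins below are illustrative only and not part of the repository:

# Sketch of what utils/agent_utils.parse_params produces; the SimpleNamespace
# objects stand in for the nested Parameters dataclasses (illustrative only).
from types import SimpleNamespace


def parse_params(parameters) -> dict:
    # same logic as utils/agent_utils.parse_params: flatten to "<section>-<field>" keys
    wdb_config = {}
    for k, v in vars(parameters).items():
        for key, value in vars(v).items():
            wdb_config[f"{k}-{key}"] = value
    return wdb_config


dummy = SimpleNamespace(
    hparams=SimpleNamespace(seed_everything=42, max_epochs=100),
    optim_param=SimpleNamespace(optimizer="AdamW", lr=2e-2),
)
print(parse_params(dummy))
# {'hparams-seed_everything': 42, 'hparams-max_epochs': 100,
#  'optim_param-optimizer': 'AdamW', 'optim_param-lr': 0.02}

This flat layout is why the wandb run config shows keys such as hparams-max_epochs or optim_param-lr rather than nested sections.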
-------------------------------------------------------------------------------- /utils/callbacks.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Dict, Optional 3 | 4 | import torch 5 | import wandb 6 | import pytorch_lightning as pl 7 | import numpy as np 8 | 9 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback 10 | from pytorch_lightning.utilities import rank_zero_info 11 | from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT 12 | 13 | from utils.metrics import MetricsModule 14 | 15 | 16 | class AutoSaveModelCheckpoint(ModelCheckpoint): 17 | def __init__( 18 | self, 19 | config, 20 | project, 21 | entity, 22 | dirpath: Optional[_PATH] = None, 23 | filename: Optional[str] = None, 24 | monitor: Optional[str] = None, 25 | verbose: bool = False, 26 | save_last: Optional[bool] = None, 27 | save_top_k: int = 1, 28 | save_weights_only: bool = False, 29 | mode: str = "min", 30 | auto_insert_metric_name: bool = True, 31 | every_n_train_steps: Optional[int] = None, 32 | train_time_interval: Optional[timedelta] = None, 33 | every_n_epochs: Optional[int] = None, 34 | save_on_train_epoch_end: Optional[bool] = None, 35 | every_n_val_epochs: Optional[int] = None, 36 | ): 37 | super().__init__( 38 | dirpath, 39 | filename, 40 | monitor, 41 | verbose, 42 | save_last, 43 | save_top_k, 44 | save_weights_only, 45 | mode, 46 | auto_insert_metric_name, 47 | every_n_train_steps, 48 | train_time_interval, 49 | every_n_epochs, 50 | save_on_train_epoch_end, 51 | every_n_val_epochs, 52 | ) 53 | self.config = config 54 | self.project = project 55 | self.entity = entity 56 | 57 | def _update_best_and_save( 58 | self, 59 | current: torch.Tensor, 60 | trainer: "pl.Trainer", 61 | monitor_candidates: Dict[str, _METRIC], 62 | ) -> None: 63 | k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k 64 | 65 | del_filepath = None 66 | if len(self.best_k_models) == k and k > 0: 67 | del_filepath = self.kth_best_model_path 68 | self.best_k_models.pop(del_filepath) 69 | 70 | # do not save nan, replace with +/- inf 71 | if isinstance(current, torch.Tensor) and torch.isnan(current): 72 | current = torch.tensor( 73 | float("inf" if self.mode == "min" else "-inf"), device=current.device 74 | ) 75 | 76 | filepath = self._get_metric_interpolated_filepath_name( 77 | monitor_candidates, trainer, del_filepath 78 | ) 79 | 80 | # save the current score 81 | self.current_score = current 82 | self.best_k_models[filepath] = current 83 | 84 | if len(self.best_k_models) == k: 85 | # monitor dict has reached k elements 86 | _op = max if self.mode == "min" else min 87 | self.kth_best_model_path = _op( 88 | self.best_k_models, key=self.best_k_models.get 89 | ) 90 | self.kth_value = self.best_k_models[self.kth_best_model_path] 91 | 92 | _op = min if self.mode == "min" else max 93 | self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) 94 | self.best_model_score = self.best_k_models[self.best_model_path] 95 | 96 | if self.verbose: 97 | epoch = monitor_candidates.get("epoch") 98 | step = monitor_candidates.get("step") 99 | rank_zero_info( 100 | f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" 101 | f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' 102 | ) 103 | trainer.save_checkpoint(filepath, self.save_weights_only) 104 | 105 | if del_filepath is not None and filepath != del_filepath: 106 | 
trainer.training_type_plugin.remove_checkpoint(del_filepath) 107 | 108 | reverse = False if self.mode == "min" else True 109 | score = sorted(self.best_k_models.values(), reverse=reverse) 110 | # indices = [(i+1) for i, x in enumerate(score) if x == current] 111 | self.alias = f"latest" # 112 | self.name = f"{wandb.run.name}" 113 | 114 | self.filepath = filepath 115 | 116 | def log_artifact(self): 117 | rank_zero_info(f"Logging artifact") 118 | 119 | api = wandb.Api(overrides={"project": self.project, "entity": self.entity}) 120 | model_artifact = wandb.Artifact( 121 | type="model", name=self.name, metadata=self.config 122 | ) 123 | 124 | model_artifact.add_file(self.filepath) 125 | wandb.log_artifact(model_artifact, aliases=[self.alias]) 126 | model_artifact.wait() 127 | rank_zero_info(f"Done. Saved '{self.name}' weights to wandb") 128 | rank_zero_info(f"Cleaning up artifacts") 129 | artifacts = [] 130 | for art in list(api.artifact_versions("model", self.name)): 131 | try: 132 | per = art.logged_by().summary.get("val/per", 0) 133 | artifacts.append((art, per)) 134 | except: 135 | pass 136 | 137 | artifacts = sorted(artifacts, key=lambda art: art[-1]["min"], reverse=False) 138 | 139 | for i, artifact in enumerate(artifacts): 140 | artifact[0].aliases = [f"top-{i+1}"] 141 | try: 142 | artifact[0].save() 143 | except: 144 | pass 145 | 146 | rank_zero_info(f"Done") 147 | 148 | def del_artifacts(self): 149 | api = wandb.Api(overrides={"project": self.project, "entity": self.entity}) 150 | artifact_type, artifact_name = "model", f"{wandb.run.name}" 151 | try: 152 | for version in api.artifact_versions(artifact_type, artifact_name): 153 | # Clean previous versions with the same alias, to keep only the latest top k. 154 | if ( 155 | len(version.aliases) == 0 156 | ): # this means that it does not have the latest alias 157 | # either this works, or I will have to remove the model with the alias first then log the next 158 | version.delete() 159 | except: 160 | print("error in del artifact to ignore") 161 | return 162 | 163 | def on_exception( 164 | self, 165 | trainer: "pl.Trainer", 166 | pl_module: "pl.LightningModule", 167 | exception: BaseException, 168 | ) -> None: 169 | self.log_artifact() 170 | return super().on_exception(trainer, pl_module, exception) 171 | 172 | def on_train_end( 173 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 174 | ) -> None: 175 | self.log_artifact() 176 | 177 | 178 | class LogMetricsCallback(Callback): 179 | def __init__(self): 180 | super().__init__() 181 | 182 | def on_fit_start( 183 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 184 | ) -> None: 185 | device = pl_module.device 186 | 187 | self.metrics_module_train = MetricsModule("train", device) 188 | 189 | self.metrics_module_validation = MetricsModule("val", device) 190 | 191 | def on_test_start( 192 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" 193 | ) -> None: 194 | device = pl_module.device 195 | 196 | self.metrics_module_test = MetricsModule("test", device) 197 | 198 | def on_train_batch_end( 199 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 200 | ): 201 | """Called when the train batch ends.""" 202 | 203 | self.metrics_module_train.update_metrics(outputs["preds"], outputs["targets"]) 204 | 205 | def on_train_epoch_end(self, trainer, pl_module): 206 | """Called when the train epoch ends.""" 207 | 208 | self.metrics_module_train.log_metrics("train/", pl_module) 209 | 210 | def on_validation_batch_end( 211 | self, trainer, pl_module, 
outputs, batch, batch_idx, dataloader_idx 212 | ): 213 | """Called when the validation batch ends.""" 214 | 215 | self.metrics_module_validation.update_metrics( 216 | outputs["preds"], outputs["targets"] 217 | ) 218 | 219 | def on_validation_epoch_end(self, trainer, pl_module): 220 | """Called when the validation epoch ends.""" 221 | 222 | self.metrics_module_validation.log_metrics("val/", pl_module) 223 | 224 | def on_test_batch_end( 225 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 226 | ): 227 | """Called when the validation batch ends.""" 228 | 229 | self.metrics_module_test.update_metrics(outputs["preds"], outputs["targets"]) 230 | 231 | def on_test_epoch_end(self, trainer, pl_module): 232 | """Called when the validation epoch ends.""" 233 | 234 | self.metrics_module_test.log_metrics("test/", pl_module) 235 | 236 | 237 | class LogAudioPrediction(Callback): 238 | def __init__(self, log_freq_audio, log_nb_audio) -> None: 239 | super().__init__() 240 | self.log_freq_audio = log_freq_audio 241 | self.log_nb_audio = log_nb_audio 242 | 243 | def on_validation_batch_end( 244 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 245 | ): 246 | """Called when the validation batch ends.""" 247 | 248 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0: 249 | self.log_audio( 250 | pl_module, 251 | "val", 252 | batch, 253 | self.log_nb_audio, 254 | outputs, 255 | trainer.datamodule.sampling_rate, 256 | ) 257 | 258 | def on_train_batch_end( 259 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 260 | ): 261 | """Called when the training batch ends.""" 262 | 263 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0: 264 | self.log_audio( 265 | pl_module, 266 | "train", 267 | batch, 268 | self.log_nb_audio, 269 | outputs, 270 | trainer.datamodule.sampling_rate, 271 | ) 272 | 273 | def on_test_batch_end( 274 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx 275 | ): 276 | """Called when the test batch ends.""" 277 | 278 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0: 279 | self.log_audio( 280 | pl_module, 281 | "test", 282 | batch, 283 | self.log_nb_audio, 284 | outputs, 285 | trainer.datamodule.sampling_rate, 286 | ) 287 | 288 | def log_audio(self, pl_module, name, batch, n, outputs, sampling_rate): 289 | x = batch 290 | audios = x["array"][:n].detach().cpu() 291 | 292 | samples = [] 293 | for i in range(len(audios)): 294 | samples.append( 295 | [ 296 | wandb.Audio(audios[i], sample_rate=sampling_rate), 297 | x["sentence"][i], 298 | outputs["targets"][i], 299 | outputs["preds"][i], 300 | ] 301 | ) 302 | 303 | columns = ["Audio sample", "sentence", "target", "prediction"] 304 | table = wandb.Table(data=samples, columns=columns) 305 | 306 | wandb.run.log({f"{name}/predictions": table}) 307 | -------------------------------------------------------------------------------- /utils/constant.py: -------------------------------------------------------------------------------- 1 | CHARS_TO_REMOVE_REGEX = "[\,\?\.\!\-\;\:\"\“\%\‘\”\�'\。]" 2 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import tarfile 5 | 6 | import pandas as pd 7 | import torch 8 | import wget 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | from utils.logger import init_logger 12 | 13 | 
14 | def coll_fn(batch): 15 | 16 | batch_dict = {} 17 | batch_dict["array"] = pad_sequence( 18 | [torch.Tensor(b["audio"]) for b in batch], padding_value=0, batch_first=True 19 | ) 20 | batch_dict["path"] = [b["path"] for b in batch] 21 | batch_dict["sentence"] = [b["sentence"] for b in batch] 22 | batch_dict["phonemes"] = [b["phonemes"] for b in batch] 23 | 24 | return batch_dict 25 | 26 | 27 | def create_vocabulary( 28 | ISO6393, path_csv, eos_token, bos_token, unk_token, pad_token, word_delimiter_token 29 | ): 30 | 31 | logger = init_logger("create_vocabulary", "INFO") 32 | 33 | df = pd.read_csv(osp.join(path_csv, "phoible.csv")) 34 | 35 | df_phoneme_target_lang = df[df["ISO6393"] == ISO6393]["Phoneme"] 36 | df_phoneme_target_lang.drop_duplicates(keep="first", inplace=True) 37 | df_phoneme_target_lang.reset_index(drop=True, inplace=True) 38 | 39 | phoneme_vocab = dict(df_phoneme_target_lang) 40 | phoneme_vocab = {v: k for k, v in phoneme_vocab.items()} 41 | 42 | phoneme_vocab[eos_token] = len(phoneme_vocab) 43 | phoneme_vocab[bos_token] = len(phoneme_vocab) 44 | phoneme_vocab[unk_token] = len(phoneme_vocab) 45 | phoneme_vocab[pad_token] = len(phoneme_vocab) 46 | phoneme_vocab[word_delimiter_token] = len(phoneme_vocab) 47 | 48 | logger.info(f"Length vocabulary : {len(phoneme_vocab)}") 49 | 50 | vocab_path = osp.join(os.getcwd(), "assets", "vocab_phoneme") 51 | file_dict = os.path.join(vocab_path, f"vocab-phoneme-{ISO6393}.json") 52 | 53 | if not os.path.exists(vocab_path): 54 | os.makedirs(vocab_path) 55 | 56 | with open(file_dict, "w") as vocab_file: 57 | json.dump(phoneme_vocab, vocab_file) 58 | 59 | return file_dict, len(phoneme_vocab) 60 | 61 | 62 | def create_vocabulary2( 63 | language, path, eos_token, bos_token, unk_token, pad_token, word_delimiter_token 64 | ): 65 | 66 | logger = init_logger("create_vocabulary", "INFO") 67 | 68 | if not osp.exists(path): 69 | assets_path = "/".join(path.split("/")[:-1]) 70 | url = "https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz" 71 | wget.download(url, assets_path) 72 | 73 | tar = tarfile.open(osp.join(assets_path, "common_voices_splits.tar.gz"), "r:gz") 74 | tar.extractall(assets_path) 75 | tar.close() 76 | 77 | json_file = osp.join(path, language, "phonesMatches_reduced.json") 78 | 79 | with open(json_file) as file: 80 | phoneme_vocab = json.load(file) 81 | 82 | phoneme_vocab[eos_token] = len(phoneme_vocab) 83 | phoneme_vocab[bos_token] = len(phoneme_vocab) 84 | phoneme_vocab[unk_token] = len(phoneme_vocab) 85 | phoneme_vocab[pad_token] = len(phoneme_vocab) 86 | phoneme_vocab[word_delimiter_token] = len(phoneme_vocab) 87 | 88 | logger.info(f"Length vocabulary : {len(phoneme_vocab)}") 89 | 90 | vocab_path = osp.join(os.getcwd(), "assets", "vocab_phoneme") 91 | file_dict = os.path.join(vocab_path, f"vocab-phoneme-{language}.json") 92 | 93 | if not os.path.exists(vocab_path): 94 | os.makedirs(vocab_path) 95 | 96 | with open(file_dict, "w") as vocab_file: 97 | json.dump(phoneme_vocab, vocab_file) 98 | 99 | return file_dict, len(phoneme_vocab) 100 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | class CustomFormatter(logging.Formatter): 5 | """Logging Formatter to add colors and count warning / errors""" 6 | 7 | yellow = "\x1b[93;1m" 8 | red = "\x1b[31;1m" 9 | blue = "\x1b[36;1m" 10 | green = "\x1b[32;1m" 11 | reset = "\x1b[0m" 12 | orange = "\x1b[33;1m" 13 | date = 
"%(asctime)s [" 14 | level_name = "%(levelname)s" 15 | prefix = "]\t" 16 | other = "%(name)s\t%(message)s" 17 | datefmt = "%H:%M:%S" 18 | 19 | FORMATS = { 20 | logging.DEBUG: date + green + level_name + reset + prefix + other, 21 | logging.INFO: date + blue + level_name + reset + prefix + "\t" + other, 22 | logging.WARNING: date + yellow + level_name + reset + prefix + other, 23 | logging.ERROR: date + red + level_name + reset + prefix + other, 24 | logging.CRITICAL: date + orange + level_name + reset + prefix + other, 25 | } 26 | 27 | def format(self, record): 28 | log_fmt = self.FORMATS.get(record.levelno) 29 | formatter = logging.Formatter(log_fmt, "%H:%M:%S") 30 | return formatter.format(record) 31 | 32 | 33 | def init_logger(name, log_level): 34 | 35 | logger = logging.getLogger(name) 36 | logger.setLevel(getattr(logging, log_level)) 37 | ch = logging.StreamHandler() 38 | ch.setLevel(logging.DEBUG) 39 | ch.setFormatter(CustomFormatter()) 40 | logger.addHandler(ch) 41 | return logger 42 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | from utils.per import PhonemeErrorRate 2 | 3 | 4 | class MetricsModule: 5 | def __init__(self, set_name, device) -> None: 6 | """ 7 | set_name: val/train/test 8 | """ 9 | self.device = device 10 | dict_metrics = {} 11 | dict_metrics["per"] = PhonemeErrorRate(compute_on_step=False).to(device) 12 | 13 | self.dict_metrics = dict_metrics 14 | 15 | def update_metrics(self, x, y): 16 | 17 | for _, m in self.dict_metrics.items(): 18 | # metric on current batch 19 | m(x, y) # update metrics (torchmetrics method) 20 | 21 | def log_metrics(self, name, pl_module): 22 | 23 | for k, m in self.dict_metrics.items(): 24 | 25 | # metric on all batches using custom accumulation 26 | metric = m.compute() 27 | pl_module.log(name + k, metric) 28 | 29 | # Reseting internal state such that metric ready for new data 30 | m.reset() 31 | m.to(self.device) 32 | -------------------------------------------------------------------------------- /utils/per.py: -------------------------------------------------------------------------------- 1 | from torchmetrics import Metric 2 | import torch 3 | from torch import Tensor, tensor 4 | from typing import Any, Dict, List, Optional, Union, Tuple 5 | 6 | 7 | class PhonemeErrorRate(Metric): 8 | """ 9 | https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/text/wer.py#L23-L93 10 | """ 11 | 12 | def __init__(self, compute_on_step=False): 13 | super().__init__(compute_on_step=compute_on_step) 14 | self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum") 15 | self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum") 16 | 17 | def update(self, preds, targets): 18 | """ 19 | preds : list of sentence phoneme 20 | targets : list of sentence phoneme 21 | """ 22 | errors, total = _per_update(preds, targets) 23 | 24 | self.errors += errors 25 | self.total += total 26 | 27 | def compute(self): 28 | return _per_compute(self.errors, self.total) 29 | 30 | 31 | def _per_update( 32 | preds: Union[str, List[str]], 33 | target: Union[str, List[str]], 34 | ) -> Tuple[Tensor, Tensor]: 35 | """Update the wer score with the current set of references and predictions. 
36 | Args: 37 | preds: Predicted phoneme sequence(s) to score as a string or list of strings 38 | target: Reference(s) for each speech input as a string or list of strings 39 | Returns: 40 | Number of edit operations to get from the reference to the prediction, summed over all samples 41 | Total number of phonemes over all references 42 | """ 43 | if isinstance(preds, str): 44 | preds = [preds] 45 | if isinstance(target, str): 46 | target = [target] 47 | errors = tensor(0, dtype=torch.float) 48 | total = tensor(0, dtype=torch.float) 49 | for pred, tgt in zip(preds, target): 50 | pred_tokens = pred.split() 51 | tgt_tokens = tgt.split() 52 | errors += _edit_distance(pred_tokens, tgt_tokens) 53 | total += len(tgt_tokens) 54 | return errors, total 55 | 56 | 57 | def _per_compute(errors: Tensor, total: Tensor) -> Tensor: 58 | """Compute the phoneme error rate. 59 | Args: 60 | errors: Number of edit operations to get from the reference to the prediction, summed over all samples 61 | total: Total number of phonemes over all references 62 | Returns: 63 | Phoneme error rate score 64 | """ 65 | return errors / total 66 | 67 | 68 | def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int: 69 | """Standard dynamic programming algorithm to compute the edit distance. 70 | Args: 71 | prediction_tokens: A tokenized predicted sentence 72 | reference_tokens: A tokenized reference sentence 73 | Returns: 74 | Edit distance between the predicted sentence and the reference sentence 75 | """ 76 | dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)] 77 | for i in range(len(prediction_tokens) + 1): 78 | dp[i][0] = i 79 | for j in range(len(reference_tokens) + 1): 80 | dp[0][j] = j 81 | for i in range(1, len(prediction_tokens) + 1): 82 | for j in range(1, len(reference_tokens) + 1): 83 | if prediction_tokens[i - 1] == reference_tokens[j - 1]: 84 | dp[i][j] = dp[i - 1][j - 1] 85 | else: 86 | dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1 87 | return dp[-1][-1] 88 | --------------------------------------------------------------------------------
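A minimal usage sketch of the PhonemeErrorRate metric defined above in utils/per.py, assuming predictions and targets are whitespace-separated phoneme strings as produced by processor.batch_decode in models/BaseModule.py; the phoneme sequences below are made up for illustration:

# Usage sketch of utils/per.py (hypothetical phoneme strings).
from utils.per import PhonemeErrorRate

per = PhonemeErrorRate(compute_on_step=False)

preds = ["p a r o t", "h e l o"]        # decoded phoneme sequences (hypothetical)
targets = ["p a r r o t", "h e l l o"]  # reference phoneme sequences (hypothetical)

per.update(preds, targets)
print(per.compute())  # 2 edits / 11 reference phonemes, i.e. roughly 0.18

Because errors and totals are accumulated as metric states, compute() returns the corpus-level PER over everything seen since the last reset(), not an average of per-batch scores.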