├── .github
│   └── workflows
│       └── code_lint.yml
├── .gitignore
├── ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf
├── ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf
├── Datasets
│   ├── __init__.py
│   └── datamodule.py
├── README.md
├── agents
│   ├── BaseTrainer.py
│   └── __init__.py
├── assets
│   ├── .keep
│   ├── img_readme
│   │   ├── ASR_language family.png
│   │   ├── Network.drawio.png
│   │   ├── PER test vs Training data.png
│   │   ├── PER validation vs Training data.png
│   │   ├── hubert.jpeg
│   │   ├── parrot.png
│   │   ├── wav2vec2.png
│   │   └── wavlm.png
│   └── phoible.csv
├── conda_environment.yml
├── config
│   ├── __init__.py
│   └── hparams.py
├── main.py
├── models
│   ├── BaseModule.py
│   ├── __init__.py
│   └── models.py
├── requirements.txt
├── requirements_cuda11-3.txt
├── test.sh
├── train_notebook.ipynb
└── utils
    ├── __init__.py
    ├── agent_utils.py
    ├── callbacks.py
    ├── constant.py
    ├── dataset_utils.py
    ├── logger.py
    ├── metrics.py
    └── per.py
/.github/workflows/code_lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | # Trigger the workflow on push or pull request,
5 | # but only for the main branch
6 | push:
7 | branches:
8 | - main
9 | pull_request:
10 | branches:
11 | - main
12 |
13 | jobs:
14 | run-linters:
15 | name: Run linters
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - name: Check out Git repository
20 | uses: actions/checkout@v2
21 |
22 | - name: Set up Python
23 | uses: actions/setup-python@v1
24 | with:
25 | python-version: 3.8
26 |
27 | - name: Install Python dependencies
28 | run: pip install black flake8
29 |
30 | - name: Run linters
31 | uses: wearerequired/lint-action@v1
32 | with:
33 | black: true
34 | auto_fix: true
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # vs code
2 | .vscode/*
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
150 | # PyCharm
151 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
153 | # and can be added to the global gitignore or merged into this file. For a more nuclear
154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
155 | #.idea/
156 |
157 | wandb/*
158 | *.ckpt
159 | assets/*
160 | artifacts/*
161 | *.json
--------------------------------------------------------------------------------
/ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/ASR_presentation_project_Apavou_Belkada_Tronchon_Zucker.pdf
--------------------------------------------------------------------------------
/ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/ASR_report_project_Apavou_Belkada_Tronchon_Zucker.pdf
--------------------------------------------------------------------------------
/Datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/Datasets/__init__.py
--------------------------------------------------------------------------------
/Datasets/datamodule.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import pickle
3 | import re
4 | import shutil
5 |
6 | import numpy as np
7 | import utils.agent_utils as ag_u
8 | import wandb
9 | from datasets import Audio, load_dataset
10 | from librosa.effects import trim
11 | from phonemizer.backend import EspeakBackend
12 | from phonemizer.separator import Separator
13 | from pytorch_lightning import LightningDataModule
14 | from torch.utils.data import DataLoader
15 | from utils.constant import CHARS_TO_REMOVE_REGEX
16 | from utils.dataset_utils import coll_fn
17 | from utils.logger import init_logger
18 |
19 |
20 | class BaseDataModule(LightningDataModule):
21 | def __init__(self, dataset_param):
22 | super().__init__()
23 |
24 | self.config = dataset_param
25 | self.logger = init_logger("BaseDataModule", "INFO")
26 | self.logger.info(
27 | f"Loading Dataset : {self.config.dataset_name}, language : {self.config.subset}"
28 | )
29 |
30 | def prepare_data(self) -> None:
31 | return super().prepare_data()
32 |
33 | def load_data(self, split) -> None:
34 | """
35 |         Function to load a dataset split
36 | """
37 |
38 | self.logger.info(f"Loading the dataset in load_data: {split}")
39 |
40 | setattr(
41 | self,
42 | f"{split}_save_data_path",
43 | osp.join(
44 | "assets",
45 | "datasets",
46 | f"{split}_{self.config.dataset_name}-{self.config.subset}",
47 | ),
48 | )
49 |
50 | save_path = getattr(self, f"{split}_save_data_path")
51 | name_file = getattr(self, f"{split}_save_data_path").split("/")[-1]
52 | name_file_path = osp.join(save_path, name_file)
53 | name_dataset = f"{split}_dataset"
54 |
55 | ag_u.create_directory(save_path)
56 |
57 | if not osp.exists(name_file_path) or self.config.create_dataset:
58 | # if not self.config.create_dataset:
59 | # # try
60 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest"
61 | # self.logger.info(f"Try loading {path} in artifacts ...")
62 |
63 | # file = ag_u.get_artifact(path, type="dataset")
64 |
65 | # shutil.copy2(file, save_path)
66 |
67 | # self.logger.info(f"Load {path} in artifacts OK")
68 |
69 | # file = open(name_file_path, "rb")
70 | # setattr(self, name_dataset, pickle.load(file))
71 | # self.logger.info(
72 | # f"Loaded {split} dataset : {name_file_path}")
73 | # else:
74 | # except
75 | setattr(
76 | self,
77 | name_dataset,
78 | load_dataset(
79 | self.config.dataset_name,
80 | self.config.subset,
81 | split=split if split != "val" else "validation",
82 | use_auth_token=self.config.use_auth_token,
83 | download_mode=self.config.download_mode,
84 | cache_dir=self.config.cache_dir,
85 | ),
86 | )
87 |
88 | setattr(
89 | self,
90 | name_dataset,
91 | getattr(self, name_dataset).remove_columns(
92 | [
93 | "accent",
94 | "age",
95 | "client_id",
96 | "down_votes",
97 | "gender",
98 | "locale",
99 | "segment",
100 | "up_votes",
101 | ]
102 | ),
103 | )
104 |
105 | setattr(
106 | self,
107 | name_dataset,
108 | getattr(self, name_dataset).cast_column(
109 | "audio", Audio(sampling_rate=16000)
110 | ),
111 | )
112 |
113 | metadata_artifact = {
114 | "dataset_name": self.config.dataset_name,
115 | "subset": self.config.subset,
116 | "split": split,
117 | "sampling_rate": 16000,
118 | }
119 |
120 | self._save_dataset(
121 | split, name_file_path, metadata_artifact, f"{split} dataset"
122 | )
123 | else:
124 |             with open(name_file_path, "rb") as file:
125 |                 setattr(self, name_dataset, pickle.load(file))
126 |
127 | self.sampling_rate = 16000
128 |
129 | self.logger.info(f"Done prepare_data {split}")
130 |
131 | def process_dataset(self, split, processor, batch_size=512):
132 | """
133 |         Function to process a dataset split (remove undesirable characters and process the audio with the processor)
134 | """
135 |
136 | save_path = getattr(self, f"{split}_save_data_path")
137 | name_file = save_path.split("/")[-1] + "_process"
138 | name_file_path = osp.join(save_path, name_file)
139 | name_dataset = f"{split}_dataset"
140 |
141 | if not osp.exists(name_file_path) or self.config.create_dataset:
142 | # if not self.config.create_dataset:
143 | # # try
144 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest"
145 | # self.logger.info(f"Try loading {path} in artifacts ...")
146 |
147 | # file = ag_u.get_artifact(path, type="dataset")
148 |
149 | # shutil.copy2(file, save_path)
150 |
151 | # self.logger.info(f"Load {path} in artifacts OK")
152 |
153 | # file = open(name_file_path, "rb")
154 | # setattr(self, name_dataset, pickle.load(file))
155 | # self.logger.info(
156 | # f"Loaded processed {split} dataset : {name_file_path}")
157 | # else:
158 | # except
159 | self.logger.info(f"Processing {split} dataset ...")
160 |
161 | setattr(
162 | self,
163 | name_dataset,
164 | getattr(self, name_dataset).map(
165 | lambda x: {
166 | "sentence": re.sub(
167 | CHARS_TO_REMOVE_REGEX, "", x["sentence"]
168 | ).lower()
169 | },
170 | num_proc=self.config.num_proc,
171 | load_from_cache_file=False,
172 | ),
173 | )
174 | setattr(
175 | self,
176 | name_dataset,
177 | getattr(self, name_dataset).map(
178 | lambda batch: {
179 | "audio": processor(
180 | [ad["array"] for ad in batch["audio"]], sampling_rate=16000
181 | ).input_values
182 | },
183 | batched=True,
184 | batch_size=batch_size,
185 | num_proc=self.config.num_proc,
186 | load_from_cache_file=False,
187 | ),
188 | )
189 |
190 | self.logger.info(f"Saving {split} dataset ...")
191 |
192 | metadata_artifact = {
193 | "dataset_name": self.config.dataset_name,
194 | "subset": self.config.subset,
195 | "split": split,
196 | "sampling_rate": self.sampling_rate,
197 | }
198 |
199 | self._save_dataset(
200 | split, name_file_path, metadata_artifact, f"{split} dataset processed"
201 | )
202 | else:
203 | self.logger.info(
204 | f"{split} dataset already exists no processing necessary ..."
205 | )
206 |
207 |             with open(name_file_path, "rb") as file:
208 |                 setattr(self, name_dataset, pickle.load(file))
209 | self.logger.info(f"Loaded processed {split} dataset : {name_file_path}")
210 |
211 | def filtered_data(self, split, top_db=15) -> None:
212 | """
213 |         Function to filter the dataset (trim silence and drop audio clips that are too long)
214 | """
215 |
216 | self.logger.info(f"Filtering {split} dataset ...")
217 |
218 | save_path = getattr(self, f"{split}_save_data_path")
219 | name_file = f"{save_path.split('/')[-1]}_filter_{top_db}_{self.config.max_input_length_in_sec}"
220 | name_file_path = osp.join(save_path, name_file)
221 | name_dataset = f"{split}_dataset"
222 |
223 | if not osp.exists(name_file_path) or self.config.create_dataset:
224 | # if not self.config.create_dataset:
225 | # # try
226 | # path = f"asr-project/{self.config.wandb_project}/{name_file}:latest"
227 | # self.logger.info(f"Try loading {path} in artifacts ...")
228 |
229 | # file = ag_u.get_artifact(path, type="dataset")
230 |
231 | # shutil.copy2(file, getattr(self, f'{split}_save_data_path'))
232 |
233 | # file = open(name_file_path, "rb")
234 | # setattr(self, name_dataset, pickle.load(file))
235 | # self.logger.info(
236 | # f"Loaded filtered {split} dataset : {name_file_path}")
237 | # else:
238 | # # except
239 | self.logger.info(
240 | f"Length {split} dataset before filter {len(getattr(self, name_dataset))}"
241 | )
242 |
243 | setattr(
244 | self,
245 | name_dataset,
246 | getattr(self, name_dataset).map(
247 | lambda x: {"audio": trim(np.array(x["audio"]), top_db=top_db)[0]},
248 | num_proc=self.config.num_proc,
249 | load_from_cache_file=False,
250 | ),
251 | )
252 | setattr(
253 | self,
254 | name_dataset,
255 | getattr(self, name_dataset).filter(
256 | lambda x: len(x["audio"])
257 | < self.config.max_input_length_in_sec * self.sampling_rate,
258 | num_proc=self.config.num_proc,
259 | load_from_cache_file=False,
260 | ),
261 | )
262 |
263 | self.logger.info(
264 | f"Length {split} dataset after filter {len(getattr(self, name_dataset))}"
265 | )
266 |
267 | metadata_artifact = {
268 | "dataset_name": self.config.dataset_name,
269 | "subset": self.config.subset,
270 | "split": split,
271 | "sampling_rate": self.sampling_rate,
272 | "top_db": top_db,
273 | "max_input_length_in_sec": self.config.max_input_length_in_sec,
274 | }
275 |
276 | self._save_dataset(
277 | split,
278 | name_file_path,
279 | metadata_artifact,
280 | f"{split} dataset processed and filtered",
281 | )
282 |
283 | else:
284 |             with open(name_file_path, "rb") as file:
285 |                 setattr(self, name_dataset, pickle.load(file))
286 | self.logger.info(f"Loaded filtered {split} dataset : {name_file_path}")
287 |
288 | self.logger.info(f"Length {split} dataset : {len(getattr(self, name_dataset))}")
289 |
290 | def create_phonemes(self, split) -> None:
291 | """
292 |         Function to phonemize all sentences of the dataset
293 | """
294 |
295 | language = (
296 | self.config.language[:2] if self.config.language[:2] != "zh" else "cmn"
297 | )
298 | self.logger.info(f"Creating {split} phonemes language {language}...")
299 | backend = EspeakBackend(language)
300 | separator = Separator(phone=" ", word="| ", syllable="")
301 |
302 | name_dataset = f"{split}_dataset"
303 |
304 | setattr(
305 | self,
306 | name_dataset,
307 | getattr(self, name_dataset).add_column(
308 | "phonemes",
309 | backend.phonemize(
310 | getattr(self, name_dataset)["sentence"],
311 | njobs=self.config.num_proc,
312 | separator=separator,
313 | ),
314 | ),
315 | )
316 |
317 | def _save_dataset(
318 | self, split, name_file_path, metadata_artifact, description_artifact
319 | ):
320 |
321 |         with open(name_file_path, "wb") as file:
322 |             pickle.dump(getattr(self, f"{split}_dataset"), file)
323 |
324 | self.logger.info(f"Saved to {name_file_path}")
325 |
326 | self.push_artefact(name_file_path, metadata_artifact, description_artifact)
327 |
328 | def push_artefact(self, path_artifact, metadata, description):
329 | artifact = wandb.Artifact(
330 | name=osp.basename(path_artifact),
331 | type="dataset",
332 | metadata=metadata,
333 | description=description,
334 | )
335 | artifact.add_file(path_artifact)
336 | wandb.log_artifact(artifact, aliases=["latest"])
337 |
338 | def setup(self, stage=None):
339 | # Build dataset
340 | if stage in (None, "fit"):
341 |
342 | self.filtered_data("train")
343 | self.filtered_data("val")
344 |
345 | self.create_phonemes("train")
346 | self.create_phonemes("val")
347 |
348 | if stage == "test":
349 | self.create_phonemes("test")
350 |
351 | if stage == "predict":
352 | self.dataset = load_dataset(
353 | self.config.dataset_name,
354 | self.config.subset,
355 | split="other",
356 | use_auth_token=self.config.use_auth_token,
357 | download_mode=self.config.download_mode,
358 | cache_dir=self.config.cache_dir,
359 | )
360 |
361 | def train_dataloader(self):
362 | train_loader = DataLoader(
363 | self.train_dataset,
364 | shuffle=True,
365 | batch_size=self.config.batch_size,
366 | num_workers=self.config.num_workers,
367 | collate_fn=coll_fn,
368 | )
369 | return train_loader
370 |
371 | def val_dataloader(self):
372 | val_loader = DataLoader(
373 | self.val_dataset,
374 | shuffle=False,
375 | batch_size=self.config.batch_size,
376 | num_workers=self.config.num_workers,
377 | collate_fn=coll_fn,
378 | )
379 | return val_loader
380 |
381 | def test_dataloader(self):
382 | val_loader = DataLoader(
383 | self.test_dataset,
384 | shuffle=False,
385 | batch_size=self.config.batch_size,
386 | num_workers=self.config.num_workers,
387 | collate_fn=coll_fn,
388 | )
389 | return val_loader
390 |
391 | def predict_dataloader(self):
392 | predict_loader = DataLoader(
393 | self.dataset,
394 | batch_size=self.config.batch_size,
395 | num_workers=self.config.num_workers,
396 | shuffle=False,
397 | pin_memory=True,
398 | collate_fn=coll_fn,
399 | )
400 | return predict_loader
401 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multilingual-PR
2 |
3 | Implementation of the project ```Self-supervised pretraining for phoneme recognition, and generalization on foreign languages```
4 |
5 | > Authors: [Apavou Clément](https://github.com/clementapa) & [Belkada Younes](https://github.com/younesbelkada) & [Leo Tronchon](https://github.com/leot13) & [Arthur Zucker](https://github.com/ArthurZucker)
6 |
7 | 
8 | 
9 | 
10 |
11 |
12 |
13 |
14 |
15 | This repository is powered by HuggingFace :hugs:, Pytorch-Lightning and Weight & Biases.
16 |
17 | ## :bird: Introduction
18 |
19 | The scarcity of annotated data, and the heavy cost of producing it, limit our ability to train deep neural networks for audio processing tasks. Therefore, the speech community has developed feature-learning methods with a minimal need for annotated data, which mostly fall under unsupervised and self-supervised techniques.
20 |
21 | Recently, self-supervised learning methods for the textual modality have outperformed state-of-the-art methods on downstream tasks, by fine-tuning pretrained models on a relatively small amount of data. These approaches have since been tested on other modalities such as images and audio.
22 |
23 | Phoneme recognition is an exciting challenge that involves processing a raw audio recording and predicting the corresponding sequence of phonemes pronounced by the speaker. Throughout this project, we specifically compare three self-supervised models, Wav2vec (2019, 2020), HuBERT (2021) and WavLM (2022), pretrained on a corpus of English speech, which we use in various ways to perform phoneme recognition for different languages with a network trained with the Connectionist Temporal Classification (CTC) algorithm. Several questions are addressed:
24 |
25 | + *What is the impact of choosing English as a pretrained language, especially for languages that are very different from English? Which method(s) works best for transferring knowledge from English to other languages?*
26 | + *Which method extracts the best features for phoneme recognition?*
27 | + *What is the influence of the abundance of training data on the performance of models?*
28 |
29 | In this project, we address these questions by drawing conclusions from our experiments.
30 |
31 | ## :sparkles: Main features
32 |
33 | + Modular support for SOTA self-supervised speech models
34 | + Freedom to select any language available on CommonVoice hosted at [HuggingFace](https://huggingface.co/datasets/common_voice).
35 | + Nice visualization tools through wandb.
36 |
37 | ## :pencil2: Network Architecture for phoneme recognition
38 |
39 |
40 |
41 |
42 |
43 | Diagram of the models used for the experiments. N=22 and h=1024 for HuBERT Large and WavLM Large, and N=11 and h=768 for Wav2vec2 Base and WavLM Base. Made by us.
44 |
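To make the diagram concrete, here is a minimal sketch of this architecture: a pretrained encoder (optionally frozen), a linear projection over the phoneme vocabulary, and frame-level log-probabilities fed to CTC. The real implementation lives in `models/models.py` and `models/BaseModule.py`; the class and argument names below are purely illustrative.

```python
# Illustrative sketch only (assumed Hugging Face usage); see models/models.py
# and models/BaseModule.py for the project's actual code.
import torch
import torch.nn as nn
from transformers import AutoModel


class PhonemeRecognizer(nn.Module):
    def __init__(self, pretrained_name: str = "microsoft/wavlm-base", vocab_size: int = 40):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(pretrained_name)
        for p in self.encoder.parameters():      # "frozen features" setting
            p.requires_grad = False
        h = self.encoder.config.hidden_size      # h in the caption above
        self.head = nn.Linear(h, vocab_size)     # per-frame phoneme logits

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        frames = self.encoder(input_values).last_hidden_state   # (B, frames, h)
        return self.head(frames).log_softmax(dim=-1)            # CTC expects log-probs


# Training then applies torch.nn.CTCLoss between these frame-level
# log-probabilities and the target phoneme indices.
```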
45 |
46 | ## :books: Languages for which phoneme dictionaries are available
47 | Dutch (du), Spanish (es), French (fr), Italian (it), Kyrgyz (ky), Russian (ru), Swedish
48 | (sv), Turkish (tr), Tatar (tt) and Mandarin (zh). From https://github.com/facebookresearch/CPC_audio.
49 |
50 | ## :star2: Usage
51 |
52 | Please refer to our [example notebook](https://github.com/ASR-project/Multilingual-PR/blob/main/train_notebook.ipynb) if you want to train or test a model. To understand the command-line arguments that you can use, run `python main.py --help`:
53 | ```
54 | Hparams ['parameters.hparams']:
55 | Hyperparameters of for the run
56 |
57 | --wandb_entity str wandb (default: asr-project)
58 | --debug bool (default: False)
59 | --test bool test code before running, if testing, no checkpoints are written (default: True)
60 | --wandb_project str (default: test-asr)
61 | --root_dir str root_dir (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR)
62 | --seed_everything [int]
63 | basic params (default: None)
64 | --gpu int number or gpu (default: 1)
65 | --hparams.max_epochs int
66 | maximum number of epochs (default: 100)
67 | --weights_path str (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/weights)
68 | --tune_lr bool modes (default: False)
69 | --dev_run bool (default: False)
70 | --train bool (default: True)
71 | --best_model str (default: )
72 | --log_freq_audio int (default: 10)
73 | --log_nb_audio int (default: 2)
74 | --val_check_interval float
75 | trainer params (default: 1.0)
76 | --limit_train_batches float
77 | 1.0 (default: 1.0)
78 | --limit_val_batches float
79 | 1.0 (default: 1.0)
80 | --enable_progress_bar bool
81 | (default: True)
82 | --best_model_run str testing params (default: WavLM_sv)
83 | --early_stopping bool
84 | Early Stopping (default: True)
85 | --early_stopping_params typing.Dict[str, typing.Any]
86 | (default: {'monitor': 'val/per', 'patience': 10, 'mode': 'min', 'verbose': True})
87 |
88 | DatasetParams ['parameters.data_param']:
89 | Dataset Parameters
90 | ! The batch_size and number of crops should be defined here
91 |
92 |
93 | --dataset_name str Hugging Face datasets parameters (default: common_voice)
94 | --use_auth_token bool
95 | True if use mozilla-foundation datasets (default: False)
96 | --subset str (default: sv-SE)
97 | --download_mode str chosen language (see https://huggingface.co/datasets/common_voice) (default: reuse_dataset_if_exists)
98 | --cache_dir str (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets)
99 | --language str to create vocabulary of phonemes (default: sv)
100 | --root_path_annotation str
101 | (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets/common_voices_splits)
102 | --phoible_csv_path str
103 | (default: /home/arthur/Work/MVA-S2/Speech/Multilingual-PR/assets)
104 | --num_workers int Dataloader parameters (default: 20)
105 | --batch_size int (default: 2)
106 | --max_input_length_in_sec float
107 | Dataset processing parameters (default: 5)
108 | --num_proc int (default: 4)
109 | --create_dataset bool
110 | (default: False)
111 |
112 | NetworkParams ['parameters.network_param']:
113 | NetworkParams(network_name: str = 'WavLM', pretrained_name: Union[str, NoneType] = '', freeze: bool = True, freeze_transformer: bool = True, eos_token: str = '', bos_token: str = '', unk_token: str = '', pad_token: str = '', word_delimiter_token: str = '|')
114 |
115 | --network_name str Hubert, Wav2Vec2, WavLM (default: WavLM)
116 | --pretrained_name [str]
117 | (default: )
118 | --freeze bool (default: True)
119 | --freeze_transformer bool
120 | (default: True)
121 | --eos_token str Phoneme Tokenizer (default: )
122 | --bos_token str (default: )
123 | --unk_token str (default: )
124 | --pad_token str (default: )
125 | --word_delimiter_token str
126 | (default: |)
127 |
128 | OptimizerParams ['parameters.optim_param']:
129 | Optimization parameters
130 |
131 | --optimizer str (default: AdamW)
132 | --lr float (default: 0.02)
133 | --weight_decay float (default: 1e-08)
134 | --accumulate_grad_batches int
135 | 1 for no accumulation (default: 16)
136 | --scheduler [str] Scheduler parameters (default: None)
137 | --optim_param.max_epochs int
138 | Cosine, ReduceLROnPlateau, MultiStepLR, StepLR or None Cosine scheduler (default: 10)
139 | --warmup_epochs int (default: 1)
140 | --warmup_start_lr float
141 | (default: 0.0006)
142 | --eta_min float (default: 5e-06)
143 | --step_size int Step LR scheduler (default: 2)
144 | --gamma float also for multi step lr (default: 0.1)
145 | --milestones str MultiStepLR scheduler (default: [8, 10, 15])
146 | --min_lr float ReduceLROnPlateau scheduler (default: 5e-09)
147 | --patience int (default: 10)
148 | ```
149 |
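For example, a training run on Swedish with a frozen HuBERT backbone could be launched along the following lines (the flags come from the help output above; the exact values are only illustrative):

```
python main.py --network_name Hubert --subset sv-SE --language sv --freeze True --batch_size 4
```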
150 |
151 |
152 | ## :sound: Dataset
153 |
154 | The project is based on [Mozilla CommonVoice dataset](https://commonvoice.mozilla.org/fr) available on [HuggingFace](https://huggingface.co/datasets/common_voice).
155 | When the script is launched, the program automatically downloads the correct dataset and transforms the ground-truth sentences into phonemes using [phonemizer](https://github.com/bootphon/phonemizer). You are free to choose any language available on HuggingFace for which one of the phoneme dictionaries cited above exists. For our experiments we use:
156 | ```
157 | it, nl, tr, ru, sv-SE
158 | ```
159 | Feel free to try any other languages and submit a Pull Request :electric_plug:.
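Under the hood, the download and phonemization steps look roughly like the sketch below, a simplified excerpt of what `Datasets/datamodule.py` does (same espeak backend and separator settings):

```python
# Simplified sketch of the dataset preparation in Datasets/datamodule.py.
from datasets import Audio, load_dataset
from phonemizer.backend import EspeakBackend
from phonemizer.separator import Separator

# 1) download a CommonVoice split and resample the audio to 16 kHz
dataset = load_dataset("common_voice", "sv-SE", split="validation")
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

# 2) phonemize the ground-truth sentences with espeak
backend = EspeakBackend("sv")  # espeak code of the chosen language
separator = Separator(phone=" ", word="| ", syllable="")
dataset = dataset.add_column(
    "phonemes", backend.phonemize(dataset["sentence"], separator=separator)
)
print(dataset[0]["phonemes"])
```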
160 |
161 | ## :paperclip: Pre-trained models
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 | Schema of Wav2vec2, HuBERT and WavLM.
170 |
171 |
172 | For our experiments, we used models hosted on the Hugging Face Hub that are pre-trained on 960 hours of **English** audio data from the LibriSpeech dataset (16 kHz sampled speech audio). The following pre-trained models were used:
173 | - Wav2vec2 *Base*: [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h)
174 | - WavLM *Base*: [microsoft/wavlm-base](https://huggingface.co/microsoft/wavlm-base)
175 | - WavLM *Large*: [microsoft/wavlm-large](https://huggingface.co/microsoft/wavlm-large)
176 | - HuBERT *Large*: [facebook/hubert-large-ls960-ft](https://huggingface.co/facebook/hubert-large-ls960-ft)
177 |
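As a quick sanity check, any of these checkpoints can be loaded and run on raw 16 kHz audio in a few lines (assumed `transformers` usage; the project wraps this logic in `models/models.py`):

```python
# Quick sanity check of a pretrained checkpoint (assumed transformers usage).
import torch
from transformers import AutoFeatureExtractor, AutoModel

name = "microsoft/wavlm-base"
extractor = AutoFeatureExtractor.from_pretrained(name)
model = AutoModel.from_pretrained(name).eval()

waveform = torch.zeros(16000)  # one second of silence at 16 kHz
inputs = extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    features = model(**inputs).last_hidden_state
print(features.shape)  # roughly (1, 49, 768) for the Base models
```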
178 |
179 |
180 | ## :family: Language Family
181 |
182 |
183 |
184 | The language family tree can be found in the following figure. It gives insight into the genetic proximity of each language.
185 |
186 |
187 |
188 |
189 |
190 | | Language | Family | Proximity with English |
191 | |----------|--------|------------------------|
192 | | Italian :it: | *Romance* | 47.8 |
193 | | Russian :ru: | *East Slavic* | 60.3 |
194 | | Dutch 🇳🇱 | *West Germanic* | 27.2 |
195 | | Swedish 🇸🇪 | *North Germanic* | 26.7 |
196 | | Turkish :tr: | *Turkic* | 92.0 |
197 |
198 |
199 |
200 |
201 | Genetic proximity between the languages studied and English, computed [here](http://www.elinguistics.net/Compare_Languages.aspx). [1, 30]: Highly related languages, [30, 50]: Related languages, [50, 70]: Remotely related languages, [70, 78]: Very remotely related languages, [78, 100]: No recognizable relationship.
202 |
203 |
204 | **English** is a part of the *West Germanic* family.\
205 | Source: https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md and http://www.elinguistics.net/Compare_Languages.aspx
206 |
207 | ## :chart_with_upwards_trend: Main results
208 |
209 | Dataset: Common Voice Corpus 6.1: https://commonvoice.mozilla.org/fr/datasets
210 |
211 | Transfer of pretrained English models to other languages.
212 |
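All tables below report the phoneme error rate (PER): the Levenshtein edit distance between the predicted and reference phoneme sequences, divided by the length of the reference. The project's metric is implemented with torchmetrics in `utils/per.py`; the snippet below is only a minimal illustration of the definition.

```python
# Minimal illustration of the PER definition; the project's own implementation
# lives in utils/per.py (torchmetrics-based).
from typing import List


def per(reference: List[str], hypothesis: List[str]) -> float:
    """PER = (substitutions + deletions + insertions) / len(reference)."""
    # Levenshtein distance over phoneme tokens
    d = [[0] * (len(hypothesis) + 1) for _ in range(len(reference) + 1)]
    for i in range(len(reference) + 1):
        d[i][0] = i
    for j in range(len(hypothesis) + 1):
        d[0][j] = j
    for i, r in enumerate(reference, 1):
        for j, h in enumerate(hypothesis, 1):
            d[i][j] = min(
                d[i - 1][j] + 1,             # deletion
                d[i][j - 1] + 1,             # insertion
                d[i - 1][j - 1] + (r != h),  # substitution
            )
    return d[len(reference)][len(hypothesis)] / len(reference)


print(per("h ə l oʊ".split(), "h ɛ l oʊ".split()))  # 0.25
```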
213 | ### 🚀 Fine-tuning
214 |
215 |
216 |
217 |
218 |
219 | | Language | Training data (in hours) | Model| PER validation | PER test | Runs|
220 | |-|-|-|-|-|-|
221 | | Italian :it: | 62\.34| Wav2Vec2 *Base* | 19\.05| 17\.95 | []() |
222 | | | | Hubert *Large* | **14\.05** | **12\.67** | []() |
223 | | | | WavLM *Base* | 19\.83 | 25\.60 | []() |
224 | | Russian :ru: | 15\.55 | Wav2Vec2 *Base* | 32\.16 | 31\.66 | []() |
225 | | | | Hubert *Large* | 25\.10 | 24\.09 | []() |
226 | | | | WavLM *Base* | **20\.25** | **18\.88** | []() |
227 | | Dutch 🇳🇱 | 12\.78 | Wav2Vec2 *Base* | 16\.18 | 20\.83 | []() |
228 | | | | Hubert *Large* | **12\.77** | **16\.49** | []() |
229 | | | | WavLM *Base* | 15\.96 | 19\.91 | []() |
230 | | Swedish 🇸🇪 | 3\.22 | Wav2Vec2 *Base* | 26\.50 | 24\.16 | []() |
231 | | | | Hubert *Large* | **21\.77** | **19\.38** | []() |
232 | | | | WavLM *Base* | 26\.86 | 24\.61 | []() |
233 | | Turkish :tr: | 2\.52 | Wav2Vec2 *Base* | 19\.62 | 19\.03 | []() |
234 | | | | Hubert *Large* | **15\.51** | **14\.19** | []() |
235 | | | | WavLM *Base* | 19\.85 | 18\.95 | []() |
236 | | Average | \- | Wav2Vec2 *Base* | 22\.70 | 22\.73 | |
237 | | | | Hubert *Large* | **17\.84** | **17\.36** | |
238 | | | | WavLM *Base* | 20\.55 | 21\.59 | |
239 |
240 |
241 | Table of experiments when models are **fine-tuned**. Here, we compare 3 different pretrained models. The models were fine-tuned on the phoneme recognition task with different languages and varying amounts of training data.
242 |
243 |
244 | ### 🧊 Frozen Features
245 |
246 | | Language | Training data (in hours) | Model | PER validation | PER test | Runs |
247 | |-|-|-|-|-|-|
248 | | Italian :it: | 62\.34 | Wav2Vec2 *Base* | 38\.94 | 36\.84 | []() |
249 | | | | WavLM *Base* | **27\.29** | **25\.98** | []() |
250 | | | | Hubert *Large* | 23\.85 | 21\.15 | []() |
251 | | | | WavLM *Large* | **21\.02** | **18\.80** | []() |
252 | | Russian :ru: | 15\.55 | Wav2Vec2 *Base* | 50\.11 | 48\.69 | []() |
253 | | | | WavLM *Base* | **40\.66** | **38\.76** | []() |
254 | | | | Hubert *Large* | 38\.36 | 36\.18 | []() |
255 | | | | WavLM *Large* | **34\.48** | **32\.26** | []() |
256 | | Dutch 🇳🇱| 12\.78 | Wav2Vec2 *Base* | 40\.15 | 39\.23 | []() |
257 | | | | WavLM *Base* | **34\.94** | **35\.67** | []() |
258 | | | | Hubert *Large* | **27\.62** | **26\.68** | []() |
259 | | | | WavLM *Large* | 27\.71 | 27\.19 | []() |
260 | | Swedish 🇸🇪 | 3\.22 | Wav2Vec2 *Base* | 50\.30 | 45\.23 | []() |
261 | | | | WavLM *Base* | **43\.65** | **40\.55** | []() |
262 | | | | Hubert *Large* | 37\.34 | **32\.68** | []() |
263 | | | | WavLM *Large* | **37\.25** | 33\.14 | []() |
264 | | Turkish :tr: | 2\.52 | Wav2Vec2 *Base* | 53\.92 | 52\.08 | []() |
265 | | | | WavLM *Base* | **47\.18** | **45\.53** | []() |
266 | | | | Hubert *Large* | 39\.55 | 37\.08 | []() |
267 | | | | WavLM *Large* | **30\.66** | **30\.14** | []() |
268 | | Average | \- | Wav2Vec2 *Base* | 46\.68 | 44\.41 | |
269 | | | | WavLM *Base* | **38\.74** | **37\.30** | |
270 | | | | Hubert *Large* | 33\.34 | 30\.75 | |
271 | | | | WavLM *Large* | **30\.22** | **28\.31** | |
272 |
273 |
274 | Table of experiments using **frozen features**. Here, we compare 4 different pretrained models. The objective was to train a linear layer, using the pretrained models' frozen features, on the phoneme recognition task with different languages and varying amounts of training data.
275 |
276 |
277 | ### ⌚ Training data
278 |
279 | | Training set | Training data | Model | PER validation | PER test | Runs |
280 | |-|-|-|-|-|-|
281 | | 5% | \~ 10 min | Wav2Vec2 *Base* | 55\.35 | 50\.91 | []() |
282 | | || Hubert *Large* | 44\.96 | 39\.38 | []() |
283 | | || WavLM *Base* | 56\.22 | 51\.25 | []() |
284 | | 10% | \~ 20 min | Wav2Vec2 *Base* | 52\.97 | 49\.01 | []() |
285 | | || Hubert *Large* | 42\.61 | 37\.50 | []() |
286 | | || WavLM *Base* | 46\.54 | 43\.64 | []() |
287 | | 50% | \~ 2 h | Wav2Vec2 *Base* | 51\.23 | 46\.24 | []() |
288 | | || Hubert *Large* | 39\.91 | 35\.27 | []() |
289 | | || WavLM *Base* | 44\.57 | 42\.33 | []() |
290 | | 100% | \~ 3 h | Wav2Vec2 *Base* | 50\.30 | 45\.23 | []() |
291 | | || Hubert *Large* | 37\.34 | 32\.68 | []() |
292 | | || WavLM *Base* | 43\.65 | 40\.55 | []() |
293 |
294 |
295 |
296 | Variation in the amount of training data with frozen features of models pre-trained with the 3 different methods. Language: Swedish 🇸🇪.
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 | PER on the test and validation sets vs Training data for the Swedish language with frozen features.
305 |
306 |
307 | ## :pushpin: Project structure
308 |
309 | ```
310 | ├── agents
311 | | ├── BaseTrainer.py
312 | |
313 | ├── assets # database and vocab phonemes are put here
314 | |
315 | ├── config
316 | | ├── hparams.py # configuration file
317 | |
318 | ├── Datasets
319 | | |
320 | | ├── datamodule.py # datamodules PyTorch lightning for CommonVoice dataset
321 | |
322 | ├── models
323 | | ├── BaseModule.py # lightning module
324 | | ├── models.py # Wav2vec2 WavLM and Hubert using Hugging Face library
325 | |
326 | ├── utils # utils functions
327 | | ├── agent_utils.py
328 | | ├── callbacks.py
329 | | ├── dataset_utils.py
330 | | ├── logger.py
331 | | ├── metrics.py
332 | | ├── per.py # torch metrics implementation of the phoneme error rate
333 | |
334 | ├── train_notebook.ipynb # example notebook for training and testing
335 | |
336 | ├── main.py # main script to launch for training or inference
337 | |
338 | └── README.md
339 | ```
340 |
341 | ## ⚡ Powered by
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
--------------------------------------------------------------------------------
/agents/BaseTrainer.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import torch
3 | import wandb
4 | from models.BaseModule import BaseModule
5 | from pytorch_lightning.callbacks import (
6 | LearningRateMonitor,
7 | RichProgressBar,
8 | EarlyStopping,
9 | )
10 | from utils.agent_utils import get_artifact, get_datamodule
11 | from utils.callbacks import (
12 | AutoSaveModelCheckpoint,
13 | LogMetricsCallback,
14 | LogAudioPrediction,
15 | )
16 | from utils.logger import init_logger
17 |
18 | from utils.dataset_utils import create_vocabulary, create_vocabulary2
19 |
20 |
21 | class BaseTrainer:
22 | def __init__(self, config, run=None) -> None:
23 | self.config = config.hparams
24 | self.wb_run = run
25 | self.network_param = config.network_param
26 |
27 | self.logger = init_logger("BaseTrainer", "INFO")
28 |
29 | self.logger.info(
30 | f"Create vocabulary language : {config.data_param.language} ..."
31 | )
32 |
33 | if config.data_param.subset == "en":
34 | (
35 | config.network_param.vocab_file,
36 | config.network_param.len_vocab,
37 | ) = create_vocabulary(
38 | config.data_param.language,
39 | config.data_param.phoible_csv_path,
40 | eos_token=config.network_param.eos_token,
41 | bos_token=config.network_param.bos_token,
42 | unk_token=config.network_param.unk_token,
43 | pad_token=config.network_param.pad_token,
44 | word_delimiter_token=config.network_param.word_delimiter_token,
45 | )
46 | else:
47 | (
48 | config.network_param.vocab_file,
49 | config.network_param.len_vocab,
50 | ) = create_vocabulary2(
51 | config.data_param.language,
52 | config.data_param.root_path_annotation,
53 | eos_token=config.network_param.eos_token,
54 | bos_token=config.network_param.bos_token,
55 | unk_token=config.network_param.unk_token,
56 | pad_token=config.network_param.pad_token,
57 | word_delimiter_token=config.network_param.word_delimiter_token,
58 | )
59 |
60 | self.logger.info(f"Vocabulary file : {config.network_param.vocab_file}")
61 |
62 | self.logger.info("Loading Data module...")
63 | self.datamodule = get_datamodule(config.data_param)
64 |
65 | self.logger.info("Loading Model module...")
66 | self.pl_model = BaseModule(config.network_param, config.optim_param)
67 |
68 | self.wb_run.watch(self.pl_model.model)
69 |
70 | def run(self):
71 | if self.config.tune_lr:
72 | tune_lr_trainer = pl.Trainer(
73 | logger=self.wb_run,
74 | gpus=self.config.gpu,
75 | auto_lr_find=True,
76 | accelerator="auto",
77 | default_root_dir=self.wb_run.save_dir,
78 | )
79 | tune_lr_trainer.logger = self.wb_run
80 |
81 | if not self.config.debug:
82 | torch.autograd.set_detect_anomaly(False)
83 | torch.autograd.profiler.profile(False)
84 | torch.autograd.profiler.emit_nvtx(False)
85 | torch.backends.cudnn.benchmark = True
86 |
87 | trainer = pl.Trainer(
88 | logger=self.wb_run, # W&B integration
89 | callbacks=self.get_callbacks(),
90 |             gpus=self.config.gpu,  # number of GPUs to use
91 | max_epochs=self.config.max_epochs, # number of epochs
92 | log_every_n_steps=1,
93 | fast_dev_run=self.config.dev_run,
94 | amp_backend="apex",
95 | enable_progress_bar=self.config.enable_progress_bar,
96 | val_check_interval=self.config.val_check_interval,
97 | limit_train_batches=self.config.limit_train_batches,
98 | limit_val_batches=self.config.limit_val_batches,
99 | accumulate_grad_batches=self.config.accumulate_grad_batches,
100 | )
101 |
102 | trainer.logger = self.wb_run
103 |
104 | self.datamodule.load_data("train")
105 | self.datamodule.process_dataset("train", self.pl_model.processor)
106 |
107 | self.datamodule.load_data("val")
108 | self.datamodule.process_dataset("val", self.pl_model.processor)
109 |
110 | if self.config.tune_lr:
111 | tune_lr_trainer.tune(self.pl_model, datamodule=self.datamodule)
112 |
113 | trainer.fit(self.pl_model, datamodule=self.datamodule)
114 |
115 | @torch.no_grad()
116 | def predict(self):
117 | if not self.config.debug:
118 | torch.autograd.set_detect_anomaly(False)
119 | torch.autograd.profiler.profile(False)
120 | torch.autograd.profiler.emit_nvtx(False)
121 | torch.backends.cudnn.benchmark = True
122 |
123 | trainer = pl.Trainer(
124 | logger=self.wb_run, # W&B integration
125 | callbacks=self.get_callbacks(),
126 |             gpus=self.config.gpu,  # number of GPUs to use
127 | log_every_n_steps=1,
128 | fast_dev_run=self.config.dev_run,
129 | amp_backend="apex",
130 | enable_progress_bar=self.config.enable_progress_bar,
131 | )
132 |
133 | trainer.logger = self.wb_run
134 |
135 | self.datamodule.load_data("test")
136 | self.datamodule.process_dataset("test", self.pl_model.processor)
137 |
138 | path_model = f"{self.config.wandb_entity}/{self.config.wandb_project}/{self.config.best_model_run}:top-1"
139 | best_model_path = get_artifact(path_model, type="model")
140 |
141 | trainer.test(self.pl_model, self.datamodule, ckpt_path=best_model_path)
142 |
143 | return
144 |
145 | def get_callbacks(self):
146 | callbacks = [
147 | LearningRateMonitor(),
148 | LogMetricsCallback(),
149 | LogAudioPrediction(self.config.log_freq_audio, self.config.log_nb_audio),
150 | ]
151 |
152 | if self.config.enable_progress_bar:
153 | callbacks += [RichProgressBar()]
154 |
155 | if self.config.early_stopping:
156 | callbacks += [EarlyStopping(**self.config.early_stopping_params)]
157 |
158 | monitor = "val/per"
159 | mode = "min"
160 | wandb.define_metric(monitor, summary=mode)
161 | save_top_k = 1
162 | every_n_epochs = 1
163 | callbacks += [
164 | AutoSaveModelCheckpoint( # ModelCheckpoint
165 | config=(self.network_param).__dict__,
166 | project=self.config.wandb_project,
167 | entity=self.config.wandb_entity,
168 | monitor=monitor,
169 | mode=mode,
170 | filename="epoch-{epoch:02d}-val_per={val/per:.2f}",
171 | verbose=True,
172 | dirpath=self.config.weights_path + f"/{str(wandb.run.name)}",
173 | save_top_k=save_top_k,
174 | every_n_epochs=every_n_epochs,
175 | auto_insert_metric_name=False,
176 | )
177 | ] # our model checkpoint callback
178 |
179 | return callbacks
180 |
--------------------------------------------------------------------------------
/agents/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/agents/__init__.py
--------------------------------------------------------------------------------
/assets/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/.keep
--------------------------------------------------------------------------------
/assets/img_readme/ASR_language family.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/ASR_language family.png
--------------------------------------------------------------------------------
/assets/img_readme/Network.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/Network.drawio.png
--------------------------------------------------------------------------------
/assets/img_readme/PER test vs Training data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/PER test vs Training data.png
--------------------------------------------------------------------------------
/assets/img_readme/PER validation vs Training data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/PER validation vs Training data.png
--------------------------------------------------------------------------------
/assets/img_readme/hubert.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/hubert.jpeg
--------------------------------------------------------------------------------
/assets/img_readme/parrot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/parrot.png
--------------------------------------------------------------------------------
/assets/img_readme/wav2vec2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/wav2vec2.png
--------------------------------------------------------------------------------
/assets/img_readme/wavlm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/assets/img_readme/wavlm.png
--------------------------------------------------------------------------------
/conda_environment.yml:
--------------------------------------------------------------------------------
1 | name: speech-project
2 | channels:
3 | - pytorch
4 | - huggingface
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=conda_forge
9 | - _openmp_mutex=4.5=1_llvm
10 | - abseil-cpp=20210324.2=h2531618_0
11 | - absl-py=0.15.0=pyhd3eb1b0_0
12 | - aiohttp=3.8.1=py38h7f8727e_0
13 | - aiosignal=1.2.0=pyhd3eb1b0_0
14 | - appdirs=1.4.4=pyh9f0ad1d_0
15 | - arrow-cpp=6.0.1=py38h9de81b1_5_cpu
16 | - async-timeout=4.0.1=pyhd3eb1b0_0
17 | - attrs=21.4.0=pyhd3eb1b0_0
18 | - audioread=2.1.9=py38h578d9bd_2
19 | - aws-c-auth=0.6.8=hadad3cd_1
20 | - aws-c-cal=0.5.12=h70efedd_7
21 | - aws-c-common=0.6.17=h7f98852_0
22 | - aws-c-compression=0.2.14=h7c7754b_7
23 | - aws-c-event-stream=0.2.7=hd2be095_32
24 | - aws-c-http=0.6.10=h416565a_3
25 | - aws-c-io=0.10.14=he836878_0
26 | - aws-c-mqtt=0.7.10=h885097b_0
27 | - aws-c-s3=0.1.29=h8d70ed6_0
28 | - aws-c-sdkutils=0.1.1=h7c7754b_4
29 | - aws-checksums=0.1.12=h7c7754b_6
30 | - aws-crt-cpp=0.17.10=h6ab17b9_5
31 | - aws-sdk-cpp=1.9.160=h36ff4c5_0
32 | - blas=1.0=mkl
33 | - blinker=1.4=py38h06a4308_0
34 | - bottleneck=1.3.2=py38heb32a55_1
35 | - brotli=1.0.9=h7f98852_6
36 | - brotli-bin=1.0.9=h7f98852_6
37 | - brotlipy=0.7.0=py38h27cfd23_1003
38 | - bzip2=1.0.8=h7b6447c_0
39 | - c-ares=1.18.1=h7f8727e_0
40 | - ca-certificates=2022.3.18=h06a4308_0
41 | - cachetools=4.2.2=pyhd3eb1b0_0
42 | - certifi=2021.10.8=py38h06a4308_2
43 | - cffi=1.15.0=py38hd667e15_1
44 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
45 | - click=7.1.2=pyhd3eb1b0_0
46 | - colorama=0.4.4=pyhd3eb1b0_0
47 | - commonmark=0.9.1=pyhd3eb1b0_0
48 | - configparser=5.0.2=pyhd3eb1b0_0
49 | - cryptography=36.0.0=py38h9ce1e76_0
50 | - cudatoolkit=10.2.89=hfd86e86_1
51 | - cycler=0.11.0=pyhd8ed1ab_0
52 | - dataclasses=0.8=pyh6d0b6a4_7
53 | - datasets=1.18.3=py_0
54 | - decorator=5.1.1=pyhd8ed1ab_0
55 | - dill=0.3.4=pyhd3eb1b0_0
56 | - docker-pycreds=0.4.0=pyhd3eb1b0_0
57 | - ffmpeg=4.2.2=h20bf706_0
58 | - filelock=3.4.2=pyhd3eb1b0_0
59 | - fonttools=4.29.1=py38h497a2fe_0
60 | - freetype=2.11.0=h70c0345_0
61 | - frozenlist=1.2.0=py38h7f8727e_0
62 | - fsspec=2022.1.0=pyhd3eb1b0_0
63 | - future=0.18.2=py38_1
64 | - gettext=0.21.0=hf68c758_0
65 | - gflags=2.2.2=he6710b0_0
66 | - giflib=5.2.1=h7b6447c_0
67 | - gitdb=4.0.7=pyhd3eb1b0_0
68 | - gitpython=3.1.18=pyhd3eb1b0_1
69 | - glog=0.5.0=h2531618_0
70 | - gmp=6.2.1=h2531618_2
71 | - gnutls=3.6.15=he1e5248_0
72 | - google-auth=1.33.0=pyhd3eb1b0_0
73 | - google-auth-oauthlib=0.4.1=py_2
74 | - grpc-cpp=1.42.0=ha1441d3_1
75 | - grpcio=1.42.0=py38hce63b2e_0
76 | - huggingface_hub=0.2.1=pyhd3eb1b0_0
77 | - icu=58.2=he6710b0_3
78 | - idna=3.3=pyhd3eb1b0_0
79 | - importlib-metadata=4.8.2=py38h06a4308_0
80 | - importlib_metadata=4.8.2=hd3eb1b0_0
81 | - intel-openmp=2021.4.0=h06a4308_3561
82 | - jbig=2.1=hdba287a_0
83 | - joblib=1.1.0=pyhd3eb1b0_0
84 | - jpeg=9d=h7f8727e_0
85 | - kiwisolver=1.3.2=py38h1fd1430_1
86 | - krb5=1.19.2=hac12032_0
87 | - lame=3.100=h7b6447c_0
88 | - lcms2=2.12=h3be6417_0
89 | - ld_impl_linux-64=2.35.1=h7274673_9
90 | - lerc=3.0=h295c915_0
91 | - libbrotlicommon=1.0.9=h7f98852_6
92 | - libbrotlidec=1.0.9=h7f98852_6
93 | - libbrotlienc=1.0.9=h7f98852_6
94 | - libcurl=7.80.0=h0b77cf5_0
95 | - libdeflate=1.8=h7f8727e_5
96 | - libedit=3.1.20210910=h7f8727e_0
97 | - libev=4.33=h7f8727e_1
98 | - libevent=2.1.12=h8f2d780_0
99 | - libffi=3.3=he6710b0_2
100 | - libflac=1.3.4=h27087fc_0
101 | - libgcc-ng=11.2.0=h1d223b6_12
102 | - libgfortran-ng=7.5.0=h14aa051_20
103 | - libgfortran4=7.5.0=h14aa051_20
104 | - libidn2=2.3.2=h7f8727e_0
105 | - libllvm11=11.1.0=hf817b99_3
106 | - libnghttp2=1.46.0=hce63b2e_0
107 | - libogg=1.3.5=h27cfd23_1
108 | - libopus=1.3.1=h7b6447c_0
109 | - libpng=1.6.37=hbc83047_0
110 | - libprotobuf=3.19.1=h4ff587b_0
111 | - librosa=0.9.1=pyhd8ed1ab_0
112 | - libsndfile=1.0.31=h9c3ff4c_1
113 | - libssh2=1.9.0=h1ba5d50_1
114 | - libstdcxx-ng=11.2.0=he4da1e4_12
115 | - libtasn1=4.16.0=h27cfd23_0
116 | - libthrift=0.15.0=hcc01f38_0
117 | - libtiff=4.3.0=h6f004c6_2
118 | - libunistring=0.9.10=h27cfd23_0
119 | - libutf8proc=2.6.1=h27cfd23_0
120 | - libuv=1.40.0=h7b6447c_0
121 | - libvorbis=1.3.7=h7b6447c_0
122 | - libvpx=1.7.0=h439df22_0
123 | - libwebp=1.2.2=h55f646e_0
124 | - libwebp-base=1.2.2=h7f8727e_0
125 | - libxml2=2.9.12=h03d6c58_0
126 | - libzlib=1.2.11=h36c2ea0_1013
127 | - llvm-openmp=13.0.1=hf817b99_0
128 | - llvmlite=0.38.0=py38h4630a5e_0
129 | - lz4-c=1.9.3=h295c915_1
130 | - markdown=3.3.4=py38h06a4308_0
131 | - matplotlib-base=3.5.1=py38hf4fb855_0
132 | - mkl=2021.4.0=h06a4308_640
133 | - mkl-service=2.4.0=py38h7f8727e_0
134 | - mkl_fft=1.3.1=py38hd3c417c_0
135 | - mkl_random=1.2.2=py38h51133e4_0
136 | - multidict=5.2.0=py38h7f8727e_2
137 | - multiprocess=0.70.12.2=py38h7f8727e_0
138 | - munkres=1.1.4=pyh9f0ad1d_0
139 | - mypy_extensions=0.4.3=py38h06a4308_1
140 | - ncurses=6.3=h7f8727e_2
141 | - nettle=3.7.3=hbbd107a_1
142 | - numba=0.55.1=py38h4bf6c61_0
143 | - numexpr=2.8.1=py38h6abb31d_0
144 | - numpy=1.21.2=py38h20f2e39_0
145 | - numpy-base=1.21.2=py38h79a1101_0
146 | - oauthlib=3.2.0=pyhd8ed1ab_0
147 | - olefile=0.46=pyhd3eb1b0_0
148 | - openh264=2.1.1=h4ff587b_0
149 | - openssl=1.1.1n=h7f8727e_0
150 | - orc=1.7.1=h1be678f_1
151 | - packaging=21.3=pyhd3eb1b0_0
152 | - pandas=1.4.1=py38h295c915_0
153 | - parquet-cpp=1.5.1=h34088ae_4
154 | - pathtools=0.1.2=pyhd3eb1b0_1
155 | - pillow=8.4.0=py38h5aabda8_0
156 | - pip=21.2.4=py38h06a4308_0
157 | - pooch=1.6.0=pyhd8ed1ab_0
158 | - promise=2.3=py38h06a4308_0
159 | - protobuf=3.19.1=py38h295c915_0
160 | - psutil=5.8.0=py38h27cfd23_1
161 | - pyarrow=6.0.1=py38he7e5f7d_5_cpu
162 | - pyasn1=0.4.8=pyhd3eb1b0_0
163 | - pyasn1-modules=0.2.8=py_0
164 | - pycparser=2.21=pyhd3eb1b0_0
165 | - pydeprecate=0.3.2=pyhd8ed1ab_0
166 | - pygments=2.11.2=pyhd3eb1b0_0
167 | - pyjwt=2.3.0=pyhd8ed1ab_1
168 | - pyopenssl=22.0.0=pyhd3eb1b0_0
169 | - pyparsing=3.0.4=pyhd3eb1b0_0
170 | - pysocks=1.7.1=py38h06a4308_0
171 | - pysoundfile=0.10.3.post1=pyhd3deb0d_0
172 | - python=3.8.12=h12debd9_0
173 | - python-dateutil=2.8.2=pyhd3eb1b0_0
174 | - python-xxhash=2.0.2=py38h7f8727e_0
175 | - python_abi=3.8=1_cp38
176 | - pytorch=1.10.2=py3.8_cuda10.2_cudnn7.6.5_0
177 | - pytorch-lightning=1.5.10=pyhd8ed1ab_0
178 | - pytorch-mutex=1.0=cuda
179 | - pytz=2021.3=pyhd3eb1b0_0
180 | - pyyaml=6.0=py38h7f8727e_1
181 | - re2=2021.11.01=h9c3ff4c_0
182 | - readline=8.1.2=h7f8727e_1
183 | - regex=2021.11.2=py38h7f8727e_0
184 | - requests=2.27.1=pyhd3eb1b0_0
185 | - requests-oauthlib=1.3.0=py_0
186 | - resampy=0.2.2=py_0
187 | - rich=11.2.0=pyhd8ed1ab_0
188 | - rsa=4.7.2=pyhd3eb1b0_1
189 | - s2n=1.3.0=h9b69904_0
190 | - sacremoses=0.0.43=pyhd3eb1b0_0
191 | - scikit-learn=1.0.2=py38h51133e4_1
192 | - scipy=1.7.3=py38hc147768_0
193 | - sentry-sdk=1.5.6=pyhd8ed1ab_0
194 | - setuptools=58.0.4=py38h06a4308_0
195 | - shortuuid=1.0.8=py38h578d9bd_0
196 | - simple-parsing=0.0.18=pyhd8ed1ab_0
197 | - six=1.16.0=pyhd3eb1b0_1
198 | - smmap=4.0.0=pyhd3eb1b0_0
199 | - snappy=1.1.8=he6710b0_0
200 | - sqlite=3.37.2=hc218d9a_0
201 | - subprocess32=3.5.4=py_1
202 | - tensorboard=2.6.0=py_1
203 | - tensorboard-data-server=0.6.0=py38hca6d32c_0
204 | - tensorboard-plugin-wit=1.6.0=py_0
205 | - termcolor=1.1.0=py38h06a4308_1
206 | - threadpoolctl=3.1.0=pyh8a188c0_0
207 | - tk=8.6.11=h1ccaba5_0
208 | - tokenizers=0.10.3=py38hb317417_1
209 | - torchaudio=0.10.2=py38_cu102
210 | - torchmetrics=0.7.2=pyhd8ed1ab_0
211 | - torchvision=0.11.3=py38_cu102
212 | - tqdm=4.62.3=pyhd3eb1b0_1
213 | - typing-extensions=3.10.0.2=hd3eb1b0_0
214 | - typing_extensions=3.10.0.2=pyh06a4308_0
215 | - typing_inspect=0.7.1=pyhd3eb1b0_0
216 | - unicodedata2=14.0.0=py38h497a2fe_0
217 | - urllib3=1.26.8=pyhd3eb1b0_0
218 | - wandb=0.12.10=pyhd8ed1ab_0
219 | - werkzeug=2.0.3=pyhd3eb1b0_0
220 | - wheel=0.37.1=pyhd3eb1b0_0
221 | - x264=1!157.20191217=h7b6447c_0
222 | - xxhash=0.8.0=h7f8727e_3
223 | - xz=5.2.5=h7b6447c_0
224 | - yaml=0.2.5=h7b6447c_0
225 | - yarl=1.6.3=py38h27cfd23_0
226 | - yaspin=2.1.0=pyhd8ed1ab_0
227 | - zipp=3.7.0=pyhd3eb1b0_0
228 | - zlib=1.2.11=h36c2ea0_1013
229 | - zstd=1.5.0=ha4553b6_1
230 | - pip:
231 | - clldutils==3.10.1
232 | - colorlog==6.6.0
233 | - csvw==2.0.0
234 | - dlinfo==1.2.1
235 | - isodate==0.6.1
236 | - phonemizer==3.0.1
237 | - python-graphviz==0.19.1
238 | - rfc3986==1.5.0
239 | - segments==2.2.0
240 | - tabulate==0.8.9
241 | - torchviz==0.0.2
242 | - transformers==4.16.0
243 | - uritemplate==4.1.1
244 | - wget==3.2
245 | prefix: /home/clement/anaconda3/envs/speech-project
246 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/config/__init__.py
--------------------------------------------------------------------------------
/config/hparams.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import random
4 | from dataclasses import dataclass
5 | from os import path as osp
6 | from typing import Any, ClassVar, Dict, List, Optional
7 | from simple_parsing.helpers import Serializable, choice, dict_field, list_field
8 |
9 | import pytorch_lightning as pl
10 | import simple_parsing
11 | import torch
12 | import torch.optim
13 |
14 | ################################## Global parameters ##################################
15 |
16 |
17 | @dataclass
18 | class Hparams:
19 | """Hyperparameters of for the run"""
20 |
21 | # wandb
22 | wandb_entity: str = "asr-project" # name of the project
23 | debug: bool = (
24 | False # test code before running, if testing, no checkpoints are written
25 | )
26 | test: bool = True
27 | wandb_project: str = f"{'test-'*test}asr"
28 | root_dir: str = os.getcwd() # root_dir
29 |
30 | # basic params
31 | seed_everything: Optional[int] = None # seed for the whole run
32 |     gpu: int = 1  # number of gpus
33 | max_epochs: int = 100 # maximum number of epochs
34 | weights_path: str = osp.join(os.getcwd(), "weights")
35 |
36 | # modes
37 | tune_lr: bool = False # tune the model on first run
38 | dev_run: bool = False
39 | train: bool = True
40 |
41 | best_model: str = ""
42 |
43 | log_freq_audio: int = 10
44 | log_nb_audio: int = 2
45 |
46 | # trainer params
47 | val_check_interval: float = 1.0 # 1.0 (at the end of the epoch)
48 | limit_train_batches: float = 1.0 # 1.0
49 | limit_val_batches: float = 1.0 # 1.0
50 | enable_progress_bar: bool = True
51 |
52 | # testing params
53 | best_model_run: str = "WavLM_sv"
54 |
55 | # Early Stopping
56 | early_stopping: bool = True
57 | early_stopping_params: Dict[str, Any] = dict_field(
58 | dict(monitor="val/per", patience=10, mode="min", verbose=True)
59 | )
60 |
61 |
62 | @dataclass
63 | class NetworkParams:
64 | network_name: str = "WavLM" # Hubert, Wav2Vec2, WavLM
65 | pretrained_name: Optional[str] = ""
66 |
67 | freeze: bool = True
68 | freeze_transformer: bool = True
69 |
70 | # Phoneme Tokenizer
71 | eos_token: str = ""
72 | bos_token: str = ""
73 | unk_token: str = ""
74 | pad_token: str = ""
75 | word_delimiter_token: str = "|"
76 |
77 |
78 | @dataclass
79 | class DatasetParams:
80 | """Dataset Parameters
81 | ! The batch_size and number of crops should be defined here
82 | """
83 |
84 | # Hugging Face datasets parameters
85 | dataset_name: str = "common_voice" # https://huggingface.co/mozilla-foundation or https://huggingface.co/datasets/common_voice # dataset, use Eval for FT
86 | use_auth_token: bool = False # True if use mozilla-foundation datasets
87 | subset: str = (
88 | "sv-SE" # chosen language (see https://huggingface.co/datasets/common_voice)
89 | )
90 | download_mode: str = "reuse_dataset_if_exists"
91 | cache_dir: str = osp.join(os.getcwd(), "assets")
92 |
93 | # to create vocabulary of phonemes
94 | language: str = "sv"
95 | root_path_annotation: str = osp.join(os.getcwd(), "assets", "common_voices_splits")
96 | phoible_csv_path: str = osp.join(os.getcwd(), "assets")
97 |
98 | # Dataloader parameters
99 | num_workers: int = 20 # number of workers for dataloaders
100 | batch_size: int = 2
101 |
102 | # Dataset processing parameters
103 | max_input_length_in_sec: float = 5
104 | num_proc: int = 4
105 |
106 | create_dataset: bool = False
107 |
108 |
109 | @dataclass
110 | class OptimizerParams:
111 | """Optimization parameters"""
112 |
113 | optimizer: str = "AdamW"
114 | lr: float = 2e-2
115 | weight_decay: float = 1e-8
116 |
117 | accumulate_grad_batches: int = 16 # 1 for no accumulation
118 |
119 | # Scheduler parameters
120 | scheduler: Optional[
121 | str
122 | ] = None # Cosine, ReduceLROnPlateau, MultiStepLR, StepLR or None
123 |
124 | # Cosine scheduler
125 | max_epochs: int = 10
126 | warmup_epochs: int = 1
127 | warmup_start_lr: float = 6e-4
128 | eta_min: float = 5e-6
129 |
130 | # Step LR scheduler
131 | step_size: int = 2
132 | gamma: float = 0.1 # also for multi step lr
133 |
134 | # MultiStepLR scheduler
135 | milestones: List[Any] = list_field(8, 10, 15)
136 |
137 | # ReduceLROnPlateau scheduler
138 | min_lr: float = 5e-9
139 | patience: int = 10
140 |
141 |
142 | @dataclass
143 | class Parameters:
144 | """base options."""
145 |
146 | hparams: Hparams = Hparams()
147 | data_param: DatasetParams = DatasetParams()
148 | network_param: NetworkParams = NetworkParams()
149 | optim_param: OptimizerParams = OptimizerParams()
150 |
151 | def __post_init__(self):
152 | """Post-initialization code"""
153 | if self.hparams.seed_everything is None:
154 | self.hparams.seed_everything = random.randint(1, 10000)
155 |
156 | self.hparams.wandb_project = f"{'test-'*self.hparams.test}asr"
157 |
158 | self.network_param.phonemizer_lang = self.data_param.language
159 |         print(f"Phonemizer language: {self.network_param.phonemizer_lang}")
160 |
161 | random.seed(self.hparams.seed_everything)
162 | torch.manual_seed(self.hparams.seed_everything)
163 | pl.seed_everything(self.hparams.seed_everything)
164 |
165 | if self.network_param.pretrained_name == "":
166 | if self.network_param.network_name == "Wav2Vec2":
167 | # self.network_param.pretrained_name = "facebook/wav2vec2-xlsr-53-espeak-cv-ft"
168 | self.network_param.pretrained_name = "facebook/wav2vec2-base-960h"
169 | elif self.network_param.network_name == "WavLM":
170 | # self.network_param.pretrained_name = "microsoft/wavlm-base"
171 | self.network_param.pretrained_name = "microsoft/wavlm-large"
172 | elif self.network_param.network_name == "Hubert":
173 | self.network_param.pretrained_name = "facebook/hubert-large-ls960-ft"
174 | else:
175 | raise NotImplementedError(
176 | "Only Wav2Vec2, WavLM and Hubert are available !"
177 | )
178 | print(f"Pretrained model: {self.network_param.pretrained_name}")
179 |
180 | self.data_param.wandb_project = self.hparams.wandb_project
181 | self.hparams.accumulate_grad_batches = self.optim_param.accumulate_grad_batches
182 |
183 | @classmethod
184 | def parse(cls):
185 | parser = simple_parsing.ArgumentParser()
186 | parser.add_arguments(cls, dest="parameters")
187 | args = parser.parse_args()
188 | instance: Parameters = args.parameters
189 | return instance
190 |
--------------------------------------------------------------------------------
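
The nested dataclasses above are wired to the command line through simple_parsing, so every field doubles as a CLI flag (the flat flag names match those used in test.sh and the training notebook). A minimal usage sketch, not part of the repo:

from config.hparams import Parameters

# e.g. invoked as: python main.py --network_name Hubert --lr 1e-4 --subset sv-SE
parameters = Parameters.parse()
print(parameters.network_param.network_name)  # selected backbone
print(parameters.optim_param.lr)              # learning rate
print(parameters.data_param.subset)           # Common Voice subset
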
/main.py:
--------------------------------------------------------------------------------
1 | import faulthandler
2 | from pytest import param
3 |
4 | faulthandler.enable()
5 |
6 | from pytorch_lightning.loggers import WandbLogger
7 |
8 | # Third-party and project imports
9 | import wandb
10 | from agents.BaseTrainer import BaseTrainer
11 | from config.hparams import Parameters
12 | from utils.agent_utils import parse_params
13 |
14 |
15 | def main():
16 | parameters = Parameters.parse()
17 |
18 | # initialize wandb instance
19 | wdb_config = parse_params(parameters)
20 |
21 | if parameters.hparams.train:
22 | tags = [
23 | parameters.data_param.dataset_name,
24 | parameters.data_param.subset,
25 | parameters.optim_param.optimizer,
26 | parameters.network_param.network_name,
27 | f"{'not'*(not parameters.network_param.freeze)} freezed",
28 | parameters.network_param.pretrained_name,
29 | ]
30 |
31 | if parameters.hparams.limit_train_batches != 1.0:
32 | tags += [f"{parameters.hparams.limit_train_batches}_train"]
33 | if parameters.network_param.freeze_transformer:
34 | tags += ["transformer_freezed"]
35 |
36 | wandb.init(
37 | name=f"{parameters.network_param.network_name}_{parameters.data_param.language}{'_CNN_not_freezed'*(not parameters.network_param.freeze)}{f'_{parameters.hparams.limit_train_batches}_train'*(parameters.hparams.limit_train_batches!=1.0)}{'_tf_freezed'*(parameters.network_param.freeze_transformer)}",
38 | config=wdb_config,
39 | project=parameters.hparams.wandb_project,
40 | entity=parameters.hparams.wandb_entity,
41 | allow_val_change=True,
42 | job_type="train",
43 | tags=tags,
44 | )
45 |
46 | wandb_run = WandbLogger(
47 | config=wdb_config,
48 | project=parameters.hparams.wandb_project,
49 | entity=parameters.hparams.wandb_entity,
50 | allow_val_change=True,
51 | )
52 |
53 | agent = BaseTrainer(parameters, wandb_run)
54 | agent.run()
55 | else:
56 | tags = [
57 | parameters.data_param.dataset_name,
58 | parameters.data_param.subset,
59 | parameters.data_param.language,
60 | parameters.network_param.network_name,
61 | f"{'not'*(not parameters.network_param.freeze)} freezed",
62 | parameters.network_param.pretrained_name,
63 | "test",
64 | ]
65 | if parameters.hparams.limit_train_batches != 1.0:
66 | tags += [f"{parameters.hparams.limit_train_batches}_train"]
67 | if parameters.network_param.freeze_transformer:
68 | tags += ["transformer_freezed"]
69 |
70 | wandb_run = wandb.init(
71 | name=parameters.hparams.best_model_run + "_test",
72 | config=wdb_config,
73 | project=parameters.hparams.wandb_project,
74 | entity=parameters.hparams.wandb_entity,
75 | allow_val_change=True,
76 | job_type="test",
77 | tags=tags,
78 | )
79 |
80 | wandb_logger = WandbLogger(
81 | config=wdb_config,
82 | project=parameters.hparams.wandb_project,
83 | entity=parameters.hparams.wandb_entity,
84 | allow_val_change=True,
85 | )
86 |
87 | agent = BaseTrainer(parameters, wandb_logger)
88 | agent.predict()
89 |
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
--------------------------------------------------------------------------------
/models/BaseModule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from pytorch_lightning import LightningModule
5 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR
6 | from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
7 | from transformers import (
8 | Wav2Vec2PhonemeCTCTokenizer,
9 | Wav2Vec2Processor,
10 | Wav2Vec2FeatureExtractor,
11 | )
12 | from utils.agent_utils import get_model
13 | from utils.logger import init_logger
14 |
15 | from itertools import chain
16 |
17 |
18 | from torch.profiler import profile, record_function, ProfilerActivity
19 |
20 |
21 | class BaseModule(LightningModule):
22 | def __init__(self, network_param, optim_param):
23 | """
24 | method used to define our model parameters
25 | """
26 | super(BaseModule, self).__init__()
27 |
28 | logger = init_logger("BaseModule", "INFO")
29 |
30 | # Optimizer
31 | self.optim_param = optim_param
32 | self.lr = optim_param.lr
33 |
34 | logger.info(f"Optimizer : {optim_param.optimizer}, lr : {optim_param.lr}")
35 |
36 | # Tokenizer
37 | # https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py
38 | self.phonemes_tokenizer = Wav2Vec2PhonemeCTCTokenizer(
39 | vocab_file=network_param.vocab_file,
40 | eos_token=network_param.eos_token,
41 | bos_token=network_param.bos_token,
42 | unk_token=network_param.unk_token,
43 | pad_token=network_param.pad_token,
44 | word_delimiter_token=network_param.word_delimiter_token,
45 | do_phonemize=False,
46 | return_attention_mask=False,
47 | )
48 |
49 | network_param.vocab_size = self.phonemes_tokenizer.vocab_size
50 |
51 | # Loss function
52 | self.loss = nn.CTCLoss(
53 | blank=self.phonemes_tokenizer.encoder[network_param.word_delimiter_token]
54 | )
55 |
56 | # Feature_extractor
57 | feature_extractor = Wav2Vec2FeatureExtractor(
58 | feature_size=1,
59 | sampling_rate=16000,
60 | padding_value=0.0,
61 | do_normalize=True,
62 | return_attention_mask=False,
63 | )
64 |
65 | logger.info(f"Features extractor : {network_param.network_name}")
66 | self.processor = Wav2Vec2Processor(
67 | feature_extractor=feature_extractor, tokenizer=self.phonemes_tokenizer
68 | )
69 |
70 | # Model
71 | self.model = get_model(network_param.network_name, network_param)
72 | logger.info(f"Model: {network_param.network_name}")
73 |
74 | if network_param.freeze:
75 | self.model.model.freeze_feature_extractor()
76 |
77 | logger.info(f"Feature extactor:{' not'*(not network_param.freeze)} freezed")
78 |
79 | if network_param.freeze_transformer:
80 | self.model.model.requires_grad_(False)
81 | self.model.model.lm_head.requires_grad_(True)
82 |
83 | def forward(self, x):
84 | output = self.model(x)
85 | return output
86 |
87 | def training_step(self, batch, batch_idx):
88 | """needs to return a loss from a single batch"""
89 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx)
90 |         if torch.isnan(loss):
91 |             print("loss is NaN, model collapsed, exiting")
92 | exit(1)
93 | # Log loss
94 | self.log("train/loss", loss, batch_size=len(preds))
95 |
96 | return {
97 | "loss": loss,
98 | "logits": logits.detach(),
99 | "preds": preds,
100 | "targets": targets,
101 | }
102 |
103 | def validation_step(self, batch, batch_idx):
104 | """used for logging metrics"""
105 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx)
106 |
107 | # Log loss
108 | self.log("val/loss", loss)
109 |
110 | return {"loss": loss, "logits": logits, "preds": preds, "targets": targets}
111 |
112 | def test_step(self, batch, batch_idx):
113 | """used for logging metrics"""
114 | loss, logits, preds, targets = self._get_outputs(batch, batch_idx)
115 |
116 | # Log loss
117 | self.log("test/loss", loss)
118 |
119 | return {"loss": loss, "logits": logits, "preds": preds, "targets": targets}
120 |
121 | def configure_optimizers(self):
122 | """defines model optimizer"""
123 | optimizer = getattr(torch.optim, self.optim_param.optimizer)
124 | optimizer = optimizer(
125 | self.parameters(), lr=self.lr, weight_decay=self.optim_param.weight_decay
126 | )
127 |
128 |         if self.optim_param.scheduler is not None:
129 | if self.optim_param.scheduler == "Cosine":
130 | scheduler = LinearWarmupCosineAnnealingLR(
131 | optimizer,
132 | warmup_epochs=self.optim_param.warmup_epochs,
133 | max_epochs=self.optim_param.max_epochs,
134 | warmup_start_lr=self.optim_param.warmup_start_lr,
135 | eta_min=self.optim_param.eta_min,
136 | )
137 | elif self.optim_param.scheduler == "StepLR":
138 | scheduler = StepLR(
139 | optimizer,
140 | step_size=self.optim_param.step_size,
141 | gamma=self.optim_param.gamma,
142 | )
143 | elif self.optim_param.scheduler == "MultiStepLR":
144 | scheduler = MultiStepLR(
145 | optimizer,
146 | milestones=self.optim_param.milestones,
147 | gamma=self.optim_param.gamma,
148 | )
149 | else:
150 | scheduler = {
151 | "scheduler": ReduceLROnPlateau(
152 | optimizer,
153 | mode="min",
154 | patience=self.optim_param.patience,
155 | min_lr=self.optim_param.min_lr,
156 | ),
157 | "monitor": "val/loss",
158 | }
159 |
160 | return [[optimizer], [scheduler]]
161 |
162 | return optimizer
163 |
164 | def _get_outputs(self, batch, batch_idx):
165 | """convenience function since train/valid/test steps are similar"""
166 | x = batch
167 |
168 | # x['array'] gives the actual raw audio
169 | output = self(x["array"]).logits
170 |
171 | # process outputs
172 | log_probs = F.log_softmax(output, dim=-1)
173 | input_lengths = torch.LongTensor([len(b) for b in log_probs])
174 | log_probs = log_probs.permute(1, 0, 2)
175 |
176 | # process targets
177 | # extract the indices from the dictionary
178 | with self.processor.as_target_processor():
179 |             # tokenization but no phonemization
180 | x["labels"] = self.processor(x["phonemes"]).input_ids
181 |
182 | target_lengths = torch.LongTensor([len(targ) for targ in x["labels"]])
183 | targets = torch.Tensor(list(chain.from_iterable(x["labels"]))).int()
184 |
185 | loss = self.loss(log_probs, targets, input_lengths, target_lengths)
186 |
187 | # to compute metric and log samples
188 | phone_preds = self.processor.batch_decode(torch.argmax(output, dim=-1))
189 | phone_targets = self.processor.batch_decode(x["labels"])
190 |
191 | return loss, output, phone_preds, phone_targets
192 |
--------------------------------------------------------------------------------
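
For reference, a minimal sketch (not from the repo; shapes and the blank index are made up, whereas BaseModule uses the word-delimiter token as blank) of the tensor layout `_get_outputs` feeds to `nn.CTCLoss`: log-probabilities permuted to (time, batch, vocab) and the per-sample label lists flattened into one 1-D target tensor.

import torch
import torch.nn.functional as F
from itertools import chain

batch_size, time_steps, vocab_size = 2, 50, 40
logits = torch.randn(batch_size, time_steps, vocab_size)      # stand-in for the model output
log_probs = F.log_softmax(logits, dim=-1).permute(1, 0, 2)    # (T, N, C) as CTC expects

labels = [[3, 7, 7, 12], [5, 9]]                              # per-sample phoneme ids
input_lengths = torch.LongTensor([time_steps] * batch_size)
target_lengths = torch.LongTensor([len(t) for t in labels])
targets = torch.Tensor(list(chain.from_iterable(labels))).int()

loss = torch.nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)
print(loss.item())
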
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/models/__init__.py
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import HubertForCTC, Wav2Vec2ForCTC, WavLMForCTC
3 |
4 |
5 | class BaseModel(nn.Module):
6 | """
7 |     Base wrapper that loads the pretrained feature-extractor backbone according to the model type
8 | https://huggingface.co/blog/fine-tune-wav2vec2-english
9 | """
10 |
11 | def __init__(self, params):
12 | super().__init__()
13 | self.params = params
14 |
15 | def forward(self, x):
16 | outputs = self.model(x)
17 | return outputs
18 |
19 |
20 | class Wav2Vec2(BaseModel):
21 | """
22 | https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/wav2vec2#transformers.Wav2Vec2ForCTC
23 | """
24 |
25 | def __init__(self, params):
26 | super().__init__(params)
27 |
28 | self.model = Wav2Vec2ForCTC.from_pretrained(params.pretrained_name)
29 | in_features = self.model.lm_head.in_features
30 | self.model.lm_head = nn.Linear(
31 | in_features=in_features, out_features=self.params.vocab_size
32 | )
33 |
34 |
35 | class WavLM(BaseModel):
36 | """
37 | https://huggingface.co/docs/transformers/model_doc/wavlm#transformers.WavLMForCTC
38 | """
39 |
40 | def __init__(self, params):
41 | super().__init__(params)
42 | self.model = WavLMForCTC.from_pretrained(params.pretrained_name)
43 | in_features = self.model.lm_head.in_features
44 | self.model.lm_head = nn.Linear(
45 | in_features=in_features, out_features=self.params.vocab_size
46 | )
47 |
48 |
49 | class Hubert(BaseModel):
50 | """
51 | https://huggingface.co/docs/transformers/v4.16.2/en/model_doc/hubert#transformers.HubertForCTC
52 | """
53 |
54 | def __init__(self, params):
55 | super().__init__(params)
56 | self.model = HubertForCTC.from_pretrained(params.pretrained_name)
57 | in_features = self.model.lm_head.in_features
58 | self.model.lm_head = nn.Linear(
59 | in_features=in_features, out_features=self.params.vocab_size
60 | )
61 |
--------------------------------------------------------------------------------
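
Each wrapper above loads a pretrained CTC checkpoint and swaps its lm_head for a fresh linear projection sized to the phoneme vocabulary, so only that projection has to learn the new output space. A small sketch, not part of the repo: the SimpleNamespace stands in for the real NetworkParams object, whose vocab_size is normally filled in by BaseModule.

from types import SimpleNamespace
from models.models import WavLM

params = SimpleNamespace(pretrained_name="microsoft/wavlm-large", vocab_size=42)
model = WavLM(params)
print(model.model.lm_head)  # e.g. Linear(in_features=1024, out_features=42, bias=True)
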
/requirements.txt:
--------------------------------------------------------------------------------
1 | wandb
2 | pytest
3 | transformers
4 | datasets
5 | soundfile
6 | simple-parsing
7 | torch==1.10.0
8 | pytorch-lightning
9 | torchaudio
10 | phonemizer
11 | rich
12 | librosa
13 | wget
14 | lightning-bolts
--------------------------------------------------------------------------------
/requirements_cuda11-3.txt:
--------------------------------------------------------------------------------
1 | datasets==1.18.3
2 | numpy==1.21.2
3 | pandas==1.4.1
4 | phonemizer==3.0.1
5 | pytorch_lightning==1.5.10
6 | rich==11.2.0
7 | simple_parsing==0.0.18
8 | -f https://download.pytorch.org/whl/cu113/torch_stable.html
9 | torch==1.10.2+cu113
10 | tqdm==4.62.3
11 | transformers==4.16.0
12 | wandb
13 | pytest
14 | soundfile
15 | torchaudio
16 | librosa
17 | wget
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | # Testing script to reproduce our test experiments
2 |
3 | python main.py --train False --language nl --subset nl --network_name Hubert --best_model_run Hubert_nl_tf_freezed
4 | python main.py --train False --language nl --subset nl --network_name WavLM --best_model_run WavLM_nl_tf_freezed
5 | python main.py --train False --language nl --subset nl --network_name Wav2Vec2 --best_model_run Wav2Vec2_nl_tf_freezed
6 |
7 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_tf_freezed
8 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_tf_freezed
9 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_tf_freezed
10 |
11 | python main.py --train False --language it --subset it --network_name Hubert --best_model_run Hubert_it_tf_freezed
12 | python main.py --train False --language it --subset it --network_name WavLM --best_model_run WavLM_it_tf_freezed
13 | python main.py --train False --language it --subset it --network_name Wav2Vec2 --best_model_run Wav2Vec2_it_tf_freezed
14 |
15 | python main.py --train False --language ru --subset ru --network_name Hubert --best_model_run Hubert_ru_tf_freezed
16 | python main.py --train False --language ru --subset ru --network_name WavLM --best_model_run WavLM_ru_tf_freezed
17 | python main.py --train False --language ru --subset ru --network_name Wav2Vec2 --best_model_run Wav2Vec2_ru_tf_freezed
18 |
19 | python main.py --train False --language tr --subset tr --network_name Hubert --best_model_run Hubert_tr_tf_freezed
20 | python main.py --train False --language tr --subset tr --network_name WavLM --best_model_run WavLM_tr_tf_freezed
21 | python main.py --train False --language tr --subset tr --network_name Wav2Vec2 --best_model_run Wav2Vec2_tr_tf_freezed
22 |
23 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.05_train_tf_freezed --limit_train_batches 0.05
24 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.05_train_tf_freezed --limit_train_batches 0.05
25 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.05_train_tf_freezed --limit_train_batches 0.05
26 |
27 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.1_train_tf_freezed --limit_train_batches 0.1
28 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.1_train_tf_freezed --limit_train_batches 0.1
29 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.1_train_tf_freezed --limit_train_batches 0.1
30 |
31 | python main.py --train False --language sv --subset sv-SE --network_name Hubert --best_model_run Hubert_sv_0.5_train_tf_freezed --limit_train_batches 0.5
32 | python main.py --train False --language sv --subset sv-SE --network_name WavLM --best_model_run WavLM_sv_0.5_train_tf_freezed --limit_train_batches 0.5
33 | python main.py --train False --language sv --subset sv-SE --network_name Wav2Vec2 --best_model_run Wav2Vec2_sv_0.5_train_tf_freezed --limit_train_batches 0.5
--------------------------------------------------------------------------------
/train_notebook.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "id": "kyYQaEdwuQ1_"
8 | },
9 | "outputs": [],
10 | "source": [
11 | "!git clone https://github.com/ASR-project/Multilingual-PR.git\n",
12 | "%cd Multilingual-PR\n",
13 | "!pip install -q -r requirements.txt\n",
14 | "!pip install pytorch_lightning==1.5.10\n",
15 | "!apt-get install espeak -y"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {
22 | "id": "6OwxeTlzwXAV"
23 | },
24 | "outputs": [],
25 | "source": [
26 | "import wandb\n",
27 | "!wandb login"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "id": "qLg-sGwEFSoL"
35 | },
36 | "outputs": [],
37 | "source": [
38 | "!python main.py --gpu 1 --num_workers 2 --language ru --subset ru --network_name WavLM --train True --num_proc 1 --enable_progress_bar False --lr 2e-2"
39 | ]
40 | }
41 | ],
42 | "metadata": {
43 | "accelerator": "GPU",
44 | "colab": {
45 | "background_execution": "on",
46 | "collapsed_sections": [],
47 | "machine_shape": "hm",
48 | "name": "train_notebook.ipynb",
49 | "provenance": []
50 | },
51 | "kernelspec": {
52 | "display_name": "Python 3",
53 | "name": "python3"
54 | },
55 | "language_info": {
56 | "name": "python"
57 | }
58 | },
59 | "nbformat": 4,
60 | "nbformat_minor": 0
61 | }
62 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ASR-project/Multilingual-PR/e7c84948f7f65d62b9b1e085487557a44dc95564/utils/__init__.py
--------------------------------------------------------------------------------
/utils/agent_utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import errno
4 |
5 | import wandb
6 | from config.hparams import Parameters
7 | from Datasets.datamodule import BaseDataModule
8 | from rich.progress import (
9 | BarColumn,
10 | Progress,
11 | SpinnerColumn,
12 | TextColumn,
13 | TimeElapsedColumn,
14 | TimeRemainingColumn,
15 | )
16 |
17 |
18 | def get_net(network_name, network_param):
19 | """
20 | Get Network Architecture based on arguments provided
21 | """
22 |
23 | mod = importlib.import_module(f"models.{network_name}")
24 | net = getattr(mod, network_name)
25 | return net(network_param)
26 |
27 |
28 | def get_artifact(name: str, type: str) -> str:
29 | """Artifact utilities
30 |     Extracts the artifact from the name by downloading it locally.
31 |     Returns: str, path to the artifact
32 | """
33 | if name != "" and name is not None:
34 | artifact = wandb.run.use_artifact(name, type=type)
35 | artifact_dir = artifact.download()
36 | file_path = os.path.join(artifact_dir, os.listdir(artifact_dir)[0])
37 | return file_path
38 | else:
39 | return None
40 |
41 |
42 | def get_datamodule(data_param):
43 | """
44 | Fetch Datamodule Function Pointer
45 | """
46 | return BaseDataModule(data_param)
47 |
48 |
49 | def get_model(model_name, params):
50 | """
51 | get features extractors
52 | """
53 | try:
54 |         mod = importlib.import_module("models.models")
55 | net = getattr(mod, model_name)
56 | return net(params)
57 | except NotImplementedError:
58 |         raise NotImplementedError("Only Wav2Vec2, WavLM and Hubert are implemented")
59 |
60 |
61 | def parse_params(parameters: Parameters) -> dict:
62 | wdb_config = {}
63 | for k, v in vars(parameters).items():
64 | for key, value in vars(v).items():
65 | wdb_config[f"{k}-{key}"] = value
66 | return wdb_config
67 |
68 |
69 | def get_progress_bar():
70 | return Progress(
71 | "[progress.description]{task.description}",
72 | SpinnerColumn(),
73 | BarColumn(),
74 | "[progress.percentage]{task.percentage:>3.0f}%",
75 | TextColumn("[bold blue]{task.fields[info]}", justify="right"),
76 | TimeElapsedColumn(),
77 | TimeRemainingColumn(),
78 | "\n",
79 | )
80 |
81 |
82 | def create_directory(dir_path):
83 | try:
84 | os.makedirs(dir_path)
85 | except OSError as e:
86 | print(e)
87 | if e.errno != errno.EEXIST:
88 | raise
89 |
--------------------------------------------------------------------------------
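
parse_params flattens the nested Parameters object into a single-level dict for wandb, prefixing every field with its group name. A small sketch of the resulting keys, not from the repo:

from config.hparams import Parameters
from utils.agent_utils import parse_params

wdb_config = parse_params(Parameters())
print(wdb_config["network_param-network_name"])  # "WavLM"
print(wdb_config["optim_param-lr"])              # 0.02
print(wdb_config["hparams-max_epochs"])          # 100
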
/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from typing import Any, Dict, Optional
3 |
4 | import torch
5 | import wandb
6 | import pytorch_lightning as pl
7 | import numpy as np
8 |
9 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback
10 | from pytorch_lightning.utilities import rank_zero_info
11 | from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
12 |
13 | from utils.metrics import MetricsModule
14 |
15 |
16 | class AutoSaveModelCheckpoint(ModelCheckpoint):
17 | def __init__(
18 | self,
19 | config,
20 | project,
21 | entity,
22 | dirpath: Optional[_PATH] = None,
23 | filename: Optional[str] = None,
24 | monitor: Optional[str] = None,
25 | verbose: bool = False,
26 | save_last: Optional[bool] = None,
27 | save_top_k: int = 1,
28 | save_weights_only: bool = False,
29 | mode: str = "min",
30 | auto_insert_metric_name: bool = True,
31 | every_n_train_steps: Optional[int] = None,
32 | train_time_interval: Optional[timedelta] = None,
33 | every_n_epochs: Optional[int] = None,
34 | save_on_train_epoch_end: Optional[bool] = None,
35 | every_n_val_epochs: Optional[int] = None,
36 | ):
37 | super().__init__(
38 | dirpath,
39 | filename,
40 | monitor,
41 | verbose,
42 | save_last,
43 | save_top_k,
44 | save_weights_only,
45 | mode,
46 | auto_insert_metric_name,
47 | every_n_train_steps,
48 | train_time_interval,
49 | every_n_epochs,
50 | save_on_train_epoch_end,
51 | every_n_val_epochs,
52 | )
53 | self.config = config
54 | self.project = project
55 | self.entity = entity
56 |
57 | def _update_best_and_save(
58 | self,
59 | current: torch.Tensor,
60 | trainer: "pl.Trainer",
61 | monitor_candidates: Dict[str, _METRIC],
62 | ) -> None:
63 | k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
64 |
65 | del_filepath = None
66 | if len(self.best_k_models) == k and k > 0:
67 | del_filepath = self.kth_best_model_path
68 | self.best_k_models.pop(del_filepath)
69 |
70 | # do not save nan, replace with +/- inf
71 | if isinstance(current, torch.Tensor) and torch.isnan(current):
72 | current = torch.tensor(
73 | float("inf" if self.mode == "min" else "-inf"), device=current.device
74 | )
75 |
76 | filepath = self._get_metric_interpolated_filepath_name(
77 | monitor_candidates, trainer, del_filepath
78 | )
79 |
80 | # save the current score
81 | self.current_score = current
82 | self.best_k_models[filepath] = current
83 |
84 | if len(self.best_k_models) == k:
85 | # monitor dict has reached k elements
86 | _op = max if self.mode == "min" else min
87 | self.kth_best_model_path = _op(
88 | self.best_k_models, key=self.best_k_models.get
89 | )
90 | self.kth_value = self.best_k_models[self.kth_best_model_path]
91 |
92 | _op = min if self.mode == "min" else max
93 | self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
94 | self.best_model_score = self.best_k_models[self.best_model_path]
95 |
96 | if self.verbose:
97 | epoch = monitor_candidates.get("epoch")
98 | step = monitor_candidates.get("step")
99 | rank_zero_info(
100 | f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
101 | f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
102 | )
103 | trainer.save_checkpoint(filepath, self.save_weights_only)
104 |
105 | if del_filepath is not None and filepath != del_filepath:
106 | trainer.training_type_plugin.remove_checkpoint(del_filepath)
107 |
108 | reverse = False if self.mode == "min" else True
109 | score = sorted(self.best_k_models.values(), reverse=reverse)
110 | # indices = [(i+1) for i, x in enumerate(score) if x == current]
111 |         self.alias = "latest"
112 | self.name = f"{wandb.run.name}"
113 |
114 | self.filepath = filepath
115 |
116 | def log_artifact(self):
117 | rank_zero_info(f"Logging artifact")
118 |
119 | api = wandb.Api(overrides={"project": self.project, "entity": self.entity})
120 | model_artifact = wandb.Artifact(
121 | type="model", name=self.name, metadata=self.config
122 | )
123 |
124 | model_artifact.add_file(self.filepath)
125 | wandb.log_artifact(model_artifact, aliases=[self.alias])
126 | model_artifact.wait()
127 | rank_zero_info(f"Done. Saved '{self.name}' weights to wandb")
128 | rank_zero_info(f"Cleaning up artifacts")
129 | artifacts = []
130 | for art in list(api.artifact_versions("model", self.name)):
131 | try:
132 | per = art.logged_by().summary.get("val/per", 0)
133 | artifacts.append((art, per))
134 | except:
135 | pass
136 |
137 | artifacts = sorted(artifacts, key=lambda art: art[-1]["min"], reverse=False)
138 |
139 | for i, artifact in enumerate(artifacts):
140 | artifact[0].aliases = [f"top-{i+1}"]
141 | try:
142 | artifact[0].save()
143 | except:
144 | pass
145 |
146 | rank_zero_info(f"Done")
147 |
148 | def del_artifacts(self):
149 | api = wandb.Api(overrides={"project": self.project, "entity": self.entity})
150 | artifact_type, artifact_name = "model", f"{wandb.run.name}"
151 | try:
152 | for version in api.artifact_versions(artifact_type, artifact_name):
153 | # Clean previous versions with the same alias, to keep only the latest top k.
154 | if (
155 | len(version.aliases) == 0
156 | ): # this means that it does not have the latest alias
157 | # either this works, or I will have to remove the model with the alias first then log the next
158 | version.delete()
159 | except:
160 | print("error in del artifact to ignore")
161 | return
162 |
163 | def on_exception(
164 | self,
165 | trainer: "pl.Trainer",
166 | pl_module: "pl.LightningModule",
167 | exception: BaseException,
168 | ) -> None:
169 | self.log_artifact()
170 | return super().on_exception(trainer, pl_module, exception)
171 |
172 | def on_train_end(
173 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
174 | ) -> None:
175 | self.log_artifact()
176 |
177 |
178 | class LogMetricsCallback(Callback):
179 | def __init__(self):
180 | super().__init__()
181 |
182 | def on_fit_start(
183 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
184 | ) -> None:
185 | device = pl_module.device
186 |
187 | self.metrics_module_train = MetricsModule("train", device)
188 |
189 | self.metrics_module_validation = MetricsModule("val", device)
190 |
191 | def on_test_start(
192 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"
193 | ) -> None:
194 | device = pl_module.device
195 |
196 | self.metrics_module_test = MetricsModule("test", device)
197 |
198 | def on_train_batch_end(
199 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
200 | ):
201 | """Called when the train batch ends."""
202 |
203 | self.metrics_module_train.update_metrics(outputs["preds"], outputs["targets"])
204 |
205 | def on_train_epoch_end(self, trainer, pl_module):
206 | """Called when the train epoch ends."""
207 |
208 | self.metrics_module_train.log_metrics("train/", pl_module)
209 |
210 | def on_validation_batch_end(
211 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
212 | ):
213 | """Called when the validation batch ends."""
214 |
215 | self.metrics_module_validation.update_metrics(
216 | outputs["preds"], outputs["targets"]
217 | )
218 |
219 | def on_validation_epoch_end(self, trainer, pl_module):
220 | """Called when the validation epoch ends."""
221 |
222 | self.metrics_module_validation.log_metrics("val/", pl_module)
223 |
224 | def on_test_batch_end(
225 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
226 | ):
227 | """Called when the validation batch ends."""
228 |
229 | self.metrics_module_test.update_metrics(outputs["preds"], outputs["targets"])
230 |
231 | def on_test_epoch_end(self, trainer, pl_module):
232 | """Called when the validation epoch ends."""
233 |
234 | self.metrics_module_test.log_metrics("test/", pl_module)
235 |
236 |
237 | class LogAudioPrediction(Callback):
238 | def __init__(self, log_freq_audio, log_nb_audio) -> None:
239 | super().__init__()
240 | self.log_freq_audio = log_freq_audio
241 | self.log_nb_audio = log_nb_audio
242 |
243 | def on_validation_batch_end(
244 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
245 | ):
246 | """Called when the validation batch ends."""
247 |
248 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0:
249 | self.log_audio(
250 | pl_module,
251 | "val",
252 | batch,
253 | self.log_nb_audio,
254 | outputs,
255 | trainer.datamodule.sampling_rate,
256 | )
257 |
258 | def on_train_batch_end(
259 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
260 | ):
261 | """Called when the training batch ends."""
262 |
263 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0:
264 | self.log_audio(
265 | pl_module,
266 | "train",
267 | batch,
268 | self.log_nb_audio,
269 | outputs,
270 | trainer.datamodule.sampling_rate,
271 | )
272 |
273 | def on_test_batch_end(
274 | self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
275 | ):
276 | """Called when the test batch ends."""
277 |
278 | if batch_idx == 0 and pl_module.current_epoch % self.log_freq_audio == 0:
279 | self.log_audio(
280 | pl_module,
281 | "test",
282 | batch,
283 | self.log_nb_audio,
284 | outputs,
285 | trainer.datamodule.sampling_rate,
286 | )
287 |
288 | def log_audio(self, pl_module, name, batch, n, outputs, sampling_rate):
289 | x = batch
290 | audios = x["array"][:n].detach().cpu()
291 |
292 | samples = []
293 | for i in range(len(audios)):
294 | samples.append(
295 | [
296 | wandb.Audio(audios[i], sample_rate=sampling_rate),
297 | x["sentence"][i],
298 | outputs["targets"][i],
299 | outputs["preds"][i],
300 | ]
301 | )
302 |
303 | columns = ["Audio sample", "sentence", "target", "prediction"]
304 | table = wandb.Table(data=samples, columns=columns)
305 |
306 | wandb.run.log({f"{name}/predictions": table})
307 |
--------------------------------------------------------------------------------
/utils/constant.py:
--------------------------------------------------------------------------------
1 | CHARS_TO_REMOVE_REGEX = "[\,\?\.\!\-\;\:\"\“\%\‘\”\�'\。]"
2 |
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import tarfile
5 |
6 | import pandas as pd
7 | import torch
8 | import wget
9 | from torch.nn.utils.rnn import pad_sequence
10 |
11 | from utils.logger import init_logger
12 |
13 |
14 | def coll_fn(batch):
15 |
16 | batch_dict = {}
17 | batch_dict["array"] = pad_sequence(
18 | [torch.Tensor(b["audio"]) for b in batch], padding_value=0, batch_first=True
19 | )
20 | batch_dict["path"] = [b["path"] for b in batch]
21 | batch_dict["sentence"] = [b["sentence"] for b in batch]
22 | batch_dict["phonemes"] = [b["phonemes"] for b in batch]
23 |
24 | return batch_dict
25 |
26 |
27 | def create_vocabulary(
28 | ISO6393, path_csv, eos_token, bos_token, unk_token, pad_token, word_delimiter_token
29 | ):
30 |
31 | logger = init_logger("create_vocabulary", "INFO")
32 |
33 | df = pd.read_csv(osp.join(path_csv, "phoible.csv"))
34 |
35 | df_phoneme_target_lang = df[df["ISO6393"] == ISO6393]["Phoneme"]
36 | df_phoneme_target_lang.drop_duplicates(keep="first", inplace=True)
37 | df_phoneme_target_lang.reset_index(drop=True, inplace=True)
38 |
39 | phoneme_vocab = dict(df_phoneme_target_lang)
40 | phoneme_vocab = {v: k for k, v in phoneme_vocab.items()}
41 |
42 | phoneme_vocab[eos_token] = len(phoneme_vocab)
43 | phoneme_vocab[bos_token] = len(phoneme_vocab)
44 | phoneme_vocab[unk_token] = len(phoneme_vocab)
45 | phoneme_vocab[pad_token] = len(phoneme_vocab)
46 | phoneme_vocab[word_delimiter_token] = len(phoneme_vocab)
47 |
48 | logger.info(f"Length vocabulary : {len(phoneme_vocab)}")
49 |
50 | vocab_path = osp.join(os.getcwd(), "assets", "vocab_phoneme")
51 | file_dict = os.path.join(vocab_path, f"vocab-phoneme-{ISO6393}.json")
52 |
53 | if not os.path.exists(vocab_path):
54 | os.makedirs(vocab_path)
55 |
56 | with open(file_dict, "w") as vocab_file:
57 | json.dump(phoneme_vocab, vocab_file)
58 |
59 | return file_dict, len(phoneme_vocab)
60 |
61 |
62 | def create_vocabulary2(
63 | language, path, eos_token, bos_token, unk_token, pad_token, word_delimiter_token
64 | ):
65 |
66 | logger = init_logger("create_vocabulary", "INFO")
67 |
68 | if not osp.exists(path):
69 | assets_path = "/".join(path.split("/")[:-1])
70 | url = "https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz"
71 | wget.download(url, assets_path)
72 |
73 | tar = tarfile.open(osp.join(assets_path, "common_voices_splits.tar.gz"), "r:gz")
74 | tar.extractall(assets_path)
75 | tar.close()
76 |
77 | json_file = osp.join(path, language, "phonesMatches_reduced.json")
78 |
79 | with open(json_file) as file:
80 | phoneme_vocab = json.load(file)
81 |
82 | phoneme_vocab[eos_token] = len(phoneme_vocab)
83 | phoneme_vocab[bos_token] = len(phoneme_vocab)
84 | phoneme_vocab[unk_token] = len(phoneme_vocab)
85 | phoneme_vocab[pad_token] = len(phoneme_vocab)
86 | phoneme_vocab[word_delimiter_token] = len(phoneme_vocab)
87 |
88 | logger.info(f"Length vocabulary : {len(phoneme_vocab)}")
89 |
90 | vocab_path = osp.join(os.getcwd(), "assets", "vocab_phoneme")
91 | file_dict = os.path.join(vocab_path, f"vocab-phoneme-{language}.json")
92 |
93 | if not os.path.exists(vocab_path):
94 | os.makedirs(vocab_path)
95 |
96 | with open(file_dict, "w") as vocab_file:
97 | json.dump(phoneme_vocab, vocab_file)
98 |
99 | return file_dict, len(phoneme_vocab)
100 |
--------------------------------------------------------------------------------
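
coll_fn zero-pads the variable-length raw audio clips of a batch to the longest one and keeps the other fields as lists. A tiny sketch of the behaviour with toy values, not from the repo:

from utils.dataset_utils import coll_fn

batch = [
    {"audio": [0.1, 0.2, 0.3], "path": "a.mp3", "sentence": "hej", "phonemes": "h ɛ j"},
    {"audio": [0.4, 0.5],      "path": "b.mp3", "sentence": "nej", "phonemes": "n ɛ j"},
]
batch_dict = coll_fn(batch)
print(batch_dict["array"].shape)  # torch.Size([2, 3]) -- second clip padded with one zero
print(batch_dict["phonemes"])     # ['h ɛ j', 'n ɛ j']
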
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class CustomFormatter(logging.Formatter):
5 | """Logging Formatter to add colors and count warning / errors"""
6 |
7 | yellow = "\x1b[93;1m"
8 | red = "\x1b[31;1m"
9 | blue = "\x1b[36;1m"
10 | green = "\x1b[32;1m"
11 | reset = "\x1b[0m"
12 | orange = "\x1b[33;1m"
13 | date = "%(asctime)s ["
14 | level_name = "%(levelname)s"
15 | prefix = "]\t"
16 | other = "%(name)s\t%(message)s"
17 | datefmt = "%H:%M:%S"
18 |
19 | FORMATS = {
20 | logging.DEBUG: date + green + level_name + reset + prefix + other,
21 | logging.INFO: date + blue + level_name + reset + prefix + "\t" + other,
22 | logging.WARNING: date + yellow + level_name + reset + prefix + other,
23 | logging.ERROR: date + red + level_name + reset + prefix + other,
24 | logging.CRITICAL: date + orange + level_name + reset + prefix + other,
25 | }
26 |
27 | def format(self, record):
28 | log_fmt = self.FORMATS.get(record.levelno)
29 | formatter = logging.Formatter(log_fmt, "%H:%M:%S")
30 | return formatter.format(record)
31 |
32 |
33 | def init_logger(name, log_level):
34 |
35 | logger = logging.getLogger(name)
36 | logger.setLevel(getattr(logging, log_level))
37 | ch = logging.StreamHandler()
38 | ch.setLevel(logging.DEBUG)
39 | ch.setFormatter(CustomFormatter())
40 | logger.addHandler(ch)
41 | return logger
42 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from utils.per import PhonemeErrorRate
2 |
3 |
4 | class MetricsModule:
5 | def __init__(self, set_name, device) -> None:
6 | """
7 | set_name: val/train/test
8 | """
9 | self.device = device
10 | dict_metrics = {}
11 | dict_metrics["per"] = PhonemeErrorRate(compute_on_step=False).to(device)
12 |
13 | self.dict_metrics = dict_metrics
14 |
15 | def update_metrics(self, x, y):
16 |
17 | for _, m in self.dict_metrics.items():
18 | # metric on current batch
19 | m(x, y) # update metrics (torchmetrics method)
20 |
21 | def log_metrics(self, name, pl_module):
22 |
23 | for k, m in self.dict_metrics.items():
24 |
25 | # metric on all batches using custom accumulation
26 | metric = m.compute()
27 | pl_module.log(name + k, metric)
28 |
29 |             # Resetting internal state so the metric is ready for new data
30 | m.reset()
31 | m.to(self.device)
32 |
--------------------------------------------------------------------------------
/utils/per.py:
--------------------------------------------------------------------------------
1 | from torchmetrics import Metric
2 | import torch
3 | from torch import Tensor, tensor
4 | from typing import Any, Dict, List, Optional, Union, Tuple
5 |
6 |
7 | class PhonemeErrorRate(Metric):
8 | """
9 | https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/text/wer.py#L23-L93
10 | """
11 |
12 | def __init__(self, compute_on_step=False):
13 | super().__init__(compute_on_step=compute_on_step)
14 | self.add_state("errors", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
15 | self.add_state("total", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
16 |
17 | def update(self, preds, targets):
18 | """
19 | preds : list of sentence phoneme
20 | targets : list of sentence phoneme
21 | """
22 | errors, total = _per_update(preds, targets)
23 |
24 | self.errors += errors
25 | self.total += total
26 |
27 | def compute(self):
28 | return _per_compute(self.errors, self.total)
29 |
30 |
31 | def _per_update(
32 | preds: Union[str, List[str]],
33 | target: Union[str, List[str]],
34 | ) -> Tuple[Tensor, Tensor]:
35 | """Update the wer score with the current set of references and predictions.
36 | Args:
37 | preds: Transcription(s) to score as a string or list of strings
38 | target: Reference(s) for each speech input as a string or list of strings
39 | Returns:
40 | Number of edit operations to get from the reference to the prediction, summed over all samples
41 | Number of words overall references
42 | """
43 | if isinstance(preds, str):
44 | preds = [preds]
45 | if isinstance(target, str):
46 | target = [target]
47 | errors = tensor(0, dtype=torch.float)
48 | total = tensor(0, dtype=torch.float)
49 | for pred, tgt in zip(preds, target):
50 | pred_tokens = pred.split()
51 | tgt_tokens = tgt.split()
52 | errors += _edit_distance(pred_tokens, tgt_tokens)
53 | total += len(tgt_tokens)
54 | return errors, total
55 |
56 |
57 | def _per_compute(errors: Tensor, total: Tensor) -> Tensor:
58 | """Compute the word error rate.
59 | Args:
60 | errors: Number of edit operations to get from the reference to the prediction, summed over all samples
61 | total: Number of words overall references
62 | Returns:
63 | Word error rate score
64 | """
65 | return errors / total
66 |
67 |
68 | def _edit_distance(prediction_tokens: List[str], reference_tokens: List[str]) -> int:
69 | """Standard dynamic programming algorithm to compute the edit distance.
70 | Args:
71 | prediction_tokens: A tokenized predicted sentence
72 | reference_tokens: A tokenized reference sentence
73 | Returns:
74 | Edit distance between the predicted sentence and the reference sentence
75 | """
76 | dp = [[0] * (len(reference_tokens) + 1) for _ in range(len(prediction_tokens) + 1)]
77 | for i in range(len(prediction_tokens) + 1):
78 | dp[i][0] = i
79 | for j in range(len(reference_tokens) + 1):
80 | dp[0][j] = j
81 | for i in range(1, len(prediction_tokens) + 1):
82 | for j in range(1, len(reference_tokens) + 1):
83 | if prediction_tokens[i - 1] == reference_tokens[j - 1]:
84 | dp[i][j] = dp[i - 1][j - 1]
85 | else:
86 | dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
87 | return dp[-1][-1]
88 |
--------------------------------------------------------------------------------
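
A quick worked example of the metric above, not part of the repo: one substituted phoneme out of six reference phonemes gives a PER of 1/6.

from utils.per import PhonemeErrorRate

per = PhonemeErrorRate(compute_on_step=False)
per.update(["h ɛ j", "n ɛ j"], ["h e j", "n ɛ j"])  # preds vs. targets, space-separated phonemes
print(per.compute())  # tensor(0.1667)
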