├── .gitignore ├── LICENSE ├── README.md ├── charsiu_forced_alignment_demo.ipynb ├── charsiu_textless_demo.ipynb ├── charsiu_tutorial.ipynb ├── experiments ├── common_voice_preprocess.py ├── common_voice_pretraining.py ├── masked_phone_modeling.py ├── train_alignment.py ├── train_asr.py ├── train_attention_aligner.py └── train_frame_classification.py ├── local ├── SA1.TXT ├── SA1.TextGrid ├── SA1.WAV ├── SSB00050015.TextGrid ├── SSB00050015_16k.wav ├── SSB16240001.TextGrid ├── SSB16240001_16k.wav ├── good_ali.pdf ├── image.png ├── sample.TextGrid ├── vocab-ctc.json └── vocab.json ├── misc └── data.md └── src ├── Charsiu.py ├── __pycache__ ├── models.cpython-38.pyc └── utils.cpython-38.pyc ├── data_loader.py ├── evaluation.py ├── models.py ├── processors.py ├── utils.py └── vocab-ctc.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | /models/ 81 | /examples/ 82 | 83 | # Mac OS-specific storage files 84 | .DS_Store 85 | 86 | # vim 87 | *.swp 88 | *.swo 89 | 90 | # Mypy cache 91 | .mypy_cache/ 92 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 jzhu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Charsiu: A transformer-based phonetic aligner [[arXiv]](https://arxiv.org/abs/2110.03876)
2 | 
3 | ### Updates
4 | - 2.10.2022. We release phone- and word-level alignments for 860k utterances from the English subset of Common Voice. Check out [this link](misc/data.md#alignments-for-english-datasets).
5 | - 1.31.2022. We release phone- and word-level alignments for over a million Mandarin utterances. Check out [this link](misc/data.md#alignments-for-mandarin-speech-datasets).
6 | - 1.26.2022. Word alignment functionality has been added to `charsiu_forced_aligner`.
7 | 
8 | ### Intro
9 | **Charsiu** is a phonetic alignment tool that can:
10 | - recognise phonemes in a given audio file
11 | - perform forced alignment using phone transcriptions created in the previous step or provided by the user
12 | - directly predict the phone-to-audio alignment from audio (text-independent alignment)
13 | 
14 | The aligner is under active development. New functions, new languages and detailed documentation will be added soon! Give us a star if you like our project!
15 | **Fun fact**: Char Siu is one of the most representative dishes of Cantonese cuisine 🍲 (see [wiki](https://en.wikipedia.org/wiki/Char_siu)).
16 | 
17 | 
18 | 
19 | ### Table of contents
20 | - [Tutorial](README.md#Tutorial)
21 | - [Usage](README.md#Usage)
22 | - [Pretrained models](README.md#Pretrained-models)
23 | - [Development plan](README.md#Development-plan)
24 | - [Dependencies](README.md#Dependencies)
25 | - [Training](README.md#Training)
26 | - [Attribution and Citation](README.md#attribution-and-citation)
27 | - [References](README.md#References)
28 | - [Disclaimer](README.md#Disclaimer)
29 | - [Support or Contact](README.md#support-or-contact)
30 | 
31 | 
32 | 
33 | 
34 | ### Tutorial
35 | **[!NEW]** A step-by-step tutorial for linguists: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lingjzhu/charsiu/blob/development/charsiu_tutorial.ipynb)
36 | 
37 | You can directly run our model in the cloud via Google Colab!
38 | - Forced alignment: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lingjzhu/charsiu/blob/development/charsiu_forced_alignment_demo.ipynb)
39 | - Textless alignment: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lingjzhu/charsiu/blob/development/charsiu_textless_demo.ipynb)
40 | 
41 | ### Usage
42 | ```
43 | git clone https://github.com/lingjzhu/charsiu
44 | cd charsiu
45 | ```
46 | #### Forced alignment
47 | ```Python
48 | from Charsiu import charsiu_forced_aligner
49 | # if there are import errors, uncomment the following lines and add the path to charsiu
50 | # import sys
51 | # sys.path.append('path_to_charsiu/src')
52 | 
53 | # initialize model
54 | charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
55 | # perform forced alignment
56 | alignment = charsiu.align(audio='./local/SA1.WAV',
57 |                           text='She had your dark suit in greasy wash water all year.')
58 | # perform forced alignment and save the output as a textgrid file
59 | charsiu.serve(audio='./local/SA1.WAV',
60 |               text='She had your dark suit in greasy wash water all year.',
61 |               save_to='./local/SA1.TextGrid')
62 | 
63 | 
64 | # Chinese
65 | charsiu = charsiu_forced_aligner(aligner='charsiu/zh_w2v2_tiny_fc_10ms',lang='zh')
66 | charsiu.align(audio='./local/SSB00050015_16k.wav',text='经广州日报报道后成为了社会热点。')
67 | charsiu.serve(audio='./local/SSB00050015_16k.wav', text='经广州日报报道后成为了社会热点。',
68 |               save_to='./local/SSB00050015.TextGrid')
69 | 
70 | # A NumPy array of the speech signal can also be passed to the model.
71 | import soundfile as sf
72 | y, sr = sf.read('./local/SSB00050015_16k.wav')
73 | charsiu.align(audio=y,text='经广州日报报道后成为了社会热点。')
74 | ```
75 | 
76 | 
77 | #### Textless alignment
78 | ```Python
79 | from Charsiu import charsiu_predictive_aligner
80 | # English
81 | # initialize a model
82 | charsiu = charsiu_predictive_aligner(aligner='charsiu/en_w2v2_fc_10ms')
83 | # perform textless alignment
84 | alignment = charsiu.align(audio='./local/SA1.WAV')
85 | # Or
86 | # perform textless alignment and output the results to a textgrid file
87 | charsiu.serve(audio='./local/SA1.WAV', save_to='./local/SA1.TextGrid')
88 | 
89 | 
90 | # Chinese
91 | charsiu = charsiu_predictive_aligner(aligner='charsiu/zh_xlsr_fc_10ms',lang='zh')
92 | 
93 | charsiu.align(audio='./local/SSB16240001_16k.wav')
94 | # Or
95 | charsiu.serve(audio='./local/SSB16240001_16k.wav', save_to='./local/SSB16240001.TextGrid')
96 | ```
97 | 
98 | ### Pretrained models
99 | Pretrained models are available at the 🤗 *HuggingFace* model hub: https://huggingface.co/charsiu.
100 | 
101 | 
102 | ### Development plan
103 | 
104 | - Package
105 | 
106 | | Items | Progress |
107 | |:------------------:|:--------:|
108 | | Documentation | Nov 2021 |
109 | | Textgrid support | √ |
110 | | Word Segmentation | √ |
111 | | Model compression | TBD |
112 | | IPA support | TBD |
113 | 
114 | - Multilingual support
115 | 
116 | | Language | Progress |
117 | |:------------------:|:--------:|
118 | | English (American) | √ |
119 | | Mandarin Chinese | √ |
120 | | German | TBD |
121 | | Spanish | TBD |
122 | | English (British) | TBD |
123 | | Cantonese | TBD |
124 | | AAVE | TBD |
125 | 
126 | 
127 | 
128 | 
129 | 
130 | ### Dependencies
131 | pytorch
132 | transformers
133 | datasets
134 | librosa
135 | g2p_en
136 | praatio
137 | g2pM
138 | 
139 | 
140 | ### Training
141 | The training pipeline is coming soon!
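In the meantime, the dependencies listed above must be installed before running either the usage examples or the research scripts in `experiments/` (see the note below). A minimal install sketch (package names are taken from the Dependencies list and the Colab demos; versions are not pinned here, and `soundfile` is included because the usage examples import it):
```
pip install torch torchvision torchaudio
pip install transformers datasets
pip install g2p_en g2pM praatio librosa soundfile
```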
142 | 
143 | Note: Training code is in `experiments/`. These scripts are the original research code used to train the models and still need to be reorganized.
144 | 
145 | 
146 | ### Attribution and Citation
147 | For now, you can cite this tool as:
148 | 
149 | ```
150 | @article{zhu2022charsiu,
151 |   title={Phone-to-audio alignment without text: A Semi-supervised Approach},
152 |   author={Zhu, Jian and Zhang, Cong and Jurgens, David},
153 |   journal={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
154 |   year={2022}
155 | }
156 | ```
157 | Or share a direct web link: https://github.com/lingjzhu/charsiu/.
158 | 
159 | 
160 | 
161 | 
162 | ### References
163 | [Transformers](https://huggingface.co/transformers/)
164 | [s3prl](https://github.com/s3prl/s3prl)
165 | [Montreal Forced Aligner](https://montreal-forced-aligner.readthedocs.io/en/latest/)
166 | 
167 | 
168 | ### Disclaimer
169 | 
170 | This tool is a beta version and is still under active development. It may have bugs and quirks, alongside the difficulties and provisos which are described throughout the documentation.
171 | This tool is distributed under the MIT license. Please see the [license](https://github.com/lingjzhu/charsiu/blob/main/LICENSE) for details.
172 | 
173 | By using this tool, you acknowledge:
174 | 
175 | * That you understand that this tool does not produce perfect camera-ready data, and that all results should be hand-checked for sanity's sake, or at the very least, noise should be taken into account.
176 | 
177 | * That you understand that this tool is a work in progress which may contain bugs. Future versions will be released, and bug fixes (and additions) will not necessarily be advertised.
178 | 
179 | * That this tool may break with future updates of the various dependencies, and that the authors are not required to repair the package when that happens.
180 | 
181 | * That you understand that the authors are not required or necessarily available to fix bugs which are encountered (although you're welcome to submit bug reports to Jian Zhu (lingjzhu@umich.edu), if needed), nor to modify the tool to your needs.
182 | 
183 | * That you will acknowledge the authors of the tool if you use, modify, fork, or re-use the code in your future work.
184 | 
185 | * That rather than re-distributing this tool to other researchers, you will instead advise them to download the latest version from the website.
186 | 
187 | ... and, most importantly:
188 | 
189 | * That neither the authors, our collaborators, nor the University of Michigan or any related universities as a whole, are responsible for the results obtained from the proper or improper usage of the tool, and that the tool is provided as-is, as a service to our fellow linguists.
190 | 
191 | All that said, thanks for using our tool, and we hope it works wonderfully for you!
192 | 
193 | ### Support or Contact
194 | Please contact Jian Zhu ([lingjzhu@umich.edu](mailto:lingjzhu@umich.edu)) for technical support.
195 | Contact Cong Zhang ([cong.zhang@ru.nl](mailto:cong.zhang@ru.nl)) if you would like to receive more instructions on how to use the package.
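### Working with the alignment output

The aligners return alignments as plain Python lists of `(start, end, label)` tuples with times in seconds (this format is taken from the printed outputs in the demo notebooks), so results can be post-processed directly without TextGrid tooling. Below is a minimal illustrative sketch using the English textless aligner from the Usage section; as noted in the Disclaimer, outputs should still be hand-checked.

```Python
from collections import defaultdict
from Charsiu import charsiu_predictive_aligner

# initialize the English textless aligner (same model as in the Usage section)
charsiu = charsiu_predictive_aligner(aligner='charsiu/en_w2v2_fc_10ms')
# `alignment` is assumed to be a list of (start_sec, end_sec, phone) tuples
alignment = charsiu.align(audio='./local/SA1.WAV')

# accumulate the total duration of each phone label in the utterance
durations = defaultdict(float)
for start, end, phone in alignment:
    durations[phone] += end - start

# print labels sorted by total duration, longest first
for phone, duration in sorted(durations.items(), key=lambda kv: -kv[1]):
    print(f'{phone}\t{duration:.3f} s')
```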
196 | 197 | 198 | 199 | -------------------------------------------------------------------------------- /charsiu_forced_alignment_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.8.5" 21 | }, 22 | "colab": { 23 | "name": "charsiu_demo.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "0zxKOeyTROc2" 32 | }, 33 | "source": [ 34 | "!pip install torch torchvision torchaudio\n", 35 | "!pip install datasets transformers\n", 36 | "!pip install g2p_en praatio librosa" 37 | ], 38 | "execution_count": null, 39 | "outputs": [] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "DIROcsj7Rv4g", 48 | "outputId": "4cda97cf-0ed3-4027-cf17-800c4d0816d3" 49 | }, 50 | "source": [ 51 | "import os\n", 52 | "from os.path import exists, join, expanduser\n", 53 | "\n", 54 | "os.chdir(expanduser(\"~\"))\n", 55 | "charsiu_dir = 'charsiu'\n", 56 | "if exists(charsiu_dir):\n", 57 | " !rm -rf /root/charsiu\n", 58 | "if not exists(charsiu_dir):\n", 59 | " ! git clone -b development https://github.com/lingjzhu/$charsiu_dir\n", 60 | " ! cd charsiu && git checkout && cd -\n", 61 | " \n", 62 | "os.chdir(charsiu_dir) " 63 | ], 64 | "execution_count": 2, 65 | "outputs": [ 66 | { 67 | "output_type": "stream", 68 | "name": "stdout", 69 | "text": [ 70 | "Cloning into 'charsiu'...\n", 71 | "remote: Enumerating objects: 308, done.\u001b[K\n", 72 | "remote: Counting objects: 100% (308/308), done.\u001b[K\n", 73 | "remote: Compressing objects: 100% (254/254), done.\u001b[K\n", 74 | "remote: Total 308 (delta 149), reused 148 (delta 48), pack-reused 0\u001b[K\n", 75 | "Receiving objects: 100% (308/308), 508.52 KiB | 4.99 MiB/s, done.\n", 76 | "Resolving deltas: 100% (149/149), done.\n", 77 | "Your branch is up to date with 'origin/development'.\n", 78 | "/root\n" 79 | ] 80 | } 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "GmHNb4OxRVD8", 87 | "outputId": "504264d8-5abe-4ba2-a56f-6db5fe5456ba", 88 | "colab": { 89 | "base_uri": "https://localhost:8080/" 90 | } 91 | }, 92 | "source": [ 93 | "import sys\n", 94 | "import torch\n", 95 | "from datasets import load_dataset\n", 96 | "import matplotlib.pyplot as plt\n", 97 | "sys.path.append('src/')\n", 98 | "#sys.path.insert(0,'src')\n", 99 | "from Charsiu import charsiu_forced_aligner, charsiu_attention_aligner" 100 | ], 101 | "execution_count": 3, 102 | "outputs": [ 103 | { 104 | "output_type": "stream", 105 | "name": "stdout", 106 | "text": [ 107 | "[nltk_data] Downloading package averaged_perceptron_tagger to\n", 108 | "[nltk_data] /root/nltk_data...\n", 109 | "[nltk_data] Unzipping taggers/averaged_perceptron_tagger.zip.\n", 110 | "[nltk_data] Downloading package cmudict to /root/nltk_data...\n", 111 | "[nltk_data] Unzipping corpora/cmudict.zip.\n" 112 | ] 113 | } 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "metadata": { 119 | "id": "N2wZBRx_WOfv" 120 | }, 121 | "source": [ 122 | "timit = 
load_dataset('timit_asr')" 123 | ], 124 | "execution_count": null, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "kBzpi5mSjiyL", 131 | "colab": { 132 | "base_uri": "https://localhost:8080/" 133 | }, 134 | "outputId": "bb57f11a-5ebd-45c5-f593-89253b77151a" 135 | }, 136 | "source": [ 137 | "# load data\n", 138 | "sample = timit['train'][0]\n", 139 | "text = sample['text']\n", 140 | "audio_path = sample['file']\n", 141 | "print('Text transcription:%s'%(text))\n", 142 | "print('Audio path: %s'%audio_path)" 143 | ], 144 | "execution_count": 5, 145 | "outputs": [ 146 | { 147 | "output_type": "stream", 148 | "name": "stdout", 149 | "text": [ 150 | "Text transcription:Would such an act of refusal be useful?\n", 151 | "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n" 152 | ] 153 | } 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "metadata": { 159 | "colab": { 160 | "base_uri": "https://localhost:8080/" 161 | }, 162 | "id": "q7paWfYdROc5", 163 | "outputId": "d494e91a-863d-4ca5-f9bb-e8c3160d3543" 164 | }, 165 | "source": [ 166 | "# initialize model\n", 167 | "charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')" 168 | ], 169 | "execution_count": null, 170 | "outputs": [ 171 | { 172 | "output_type": "stream", 173 | "name": "stderr", 174 | "text": [ 175 | "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", 176 | "/usr/local/lib/python3.7/dist-packages/transformers/configuration_utils.py:341: UserWarning: Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 Transformers. 
Using `model.gradient_checkpointing_enable()` instead, or if you are using the `Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`.\n", 177 | " \"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 \"\n" 178 | ] 179 | } 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": { 185 | "id": "zYUK0SUKxtt_" 186 | }, 187 | "source": [ 188 | "Forced alignment with a neural forced alignment model" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "metadata": { 194 | "id": "yHW92QgDROc4" 195 | }, 196 | "source": [ 197 | "alignment = charsiu.align(audio=audio_path,text=text)" 198 | ], 199 | "execution_count": null, 200 | "outputs": [] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "metadata": { 205 | "colab": { 206 | "base_uri": "https://localhost:8080/" 207 | }, 208 | "id": "yCmbdfpzXrQ3", 209 | "outputId": "e2017f28-4fc1-4394-e21e-47771a59d740" 210 | }, 211 | "source": [ 212 | "print(alignment)\n", 213 | "print('\\n Ground Truth \\n')\n", 214 | "print([(s/16000,e/16000,p) for s,e,p in zip(sample['phonetic_detail']['start'],sample['phonetic_detail']['stop'],sample['phonetic_detail']['utterance'])])" 215 | ], 216 | "execution_count": null, 217 | "outputs": [ 218 | { 219 | "output_type": "stream", 220 | "name": "stdout", 221 | "text": [ 222 | "[(0.0, 0.08, '[SIL]'), (0.08, 0.15, 'W'), (0.15, 0.19, 'UH'), (0.19, 0.24, 'D'), (0.24, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.58, 'CH'), (0.58, 0.6, 'AE'), (0.6, 0.68, 'N'), (0.68, 0.82, 'AE'), (0.82, 0.93, 'K'), (0.93, 0.99, 'T'), (0.99, 1.04, 'AH'), (1.04, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.22, 'AH'), (1.22, 1.33, 'F'), (1.33, 1.41, 'Y'), (1.41, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.62, 'AH'), (1.62, 1.68, 'L'), (1.68, 1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.13, 'S'), (2.13, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.47, 'L'), (2.47, 2.48, '[SIL]')]\n", 223 | "\n", 224 | " Ground Truth \n", 225 | "\n", 226 | "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n" 227 | ] 228 | } 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "metadata": { 234 | "colab": { 235 | "base_uri": "https://localhost:8080/" 236 | }, 237 | "id": "gJkF2z91ROc5", 238 | "outputId": "1b2c838d-b794-489e-aacd-7cccb149fb81" 239 | }, 240 | "source": [ 241 | "# save alignment\n", 242 | "charsiu.serve(audio=audio_path,text=text,save_to='./local/sample.TextGrid')" 243 | ], 244 | "execution_count": null, 245 | "outputs": [ 246 | { 247 | "output_type": "stream", 248 | "name": "stdout", 249 | "text": [ 250 | "Alignment output has been saved to ./local/sample.TextGrid\n" 251 | ] 252 | } 253 | ] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": { 258 | "id": "swtLnWRlvTdr" 259 | }, 260 | "source": [ 261 | "Forced Alignment with An Attention 
Alignment Model" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "metadata": { 267 | "id": "tHhEVix-ugEU", 268 | "colab": { 269 | "base_uri": "https://localhost:8080/" 270 | }, 271 | "outputId": "882770bc-cc54-456f-de6f-d8cbf71bc25c" 272 | }, 273 | "source": [ 274 | "# load data\n", 275 | "sample = timit['train'][0]\n", 276 | "text = sample['text']\n", 277 | "audio_path = sample['file']\n", 278 | "print('Text transcription:%s'%(text))\n", 279 | "print('Audio path: %s'%audio_path)" 280 | ], 281 | "execution_count": 7, 282 | "outputs": [ 283 | { 284 | "output_type": "stream", 285 | "name": "stdout", 286 | "text": [ 287 | "Text transcription:Would such an act of refusal be useful?\n", 288 | "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n" 289 | ] 290 | } 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "metadata": { 296 | "id": "8mKL4RzfuP2a" 297 | }, 298 | "source": [ 299 | "# intialize model\n", 300 | "charsiu = charsiu_attention_aligner('charsiu/en_w2v2_fs_10ms')" 301 | ], 302 | "execution_count": null, 303 | "outputs": [] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "id": "m-SpQfeIvnBu" 309 | }, 310 | "source": [ 311 | "alignment = charsiu.align(audio=audio_path,text=text)" 312 | ], 313 | "execution_count": null, 314 | "outputs": [] 315 | }, 316 | { 317 | "cell_type": "code", 318 | "metadata": { 319 | "id": "lbbV0gpuvJW2", 320 | "colab": { 321 | "base_uri": "https://localhost:8080/" 322 | }, 323 | "outputId": "3dbf4bd3-b36a-44b8-c810-5cfc579958df" 324 | }, 325 | "source": [ 326 | "print(alignment)\n", 327 | "print('\\n Ground Truth \\n')\n", 328 | "print([(s/16000,e/16000,p) for s,e,p in zip(sample['phonetic_detail']['start'],sample['phonetic_detail']['stop'],sample['phonetic_detail']['utterance'])])" 329 | ], 330 | "execution_count": 9, 331 | "outputs": [ 332 | { 333 | "output_type": "stream", 334 | "name": "stdout", 335 | "text": [ 336 | "[(0.0, 0.11, '[SIL]'), (0.11, 0.15, 'W'), (0.15, 0.2, 'UH'), (0.2, 0.27, 'D'), (0.27, 0.38, 'S'), (0.38, 0.5, 'AH'), (0.5, 0.58, 'CH'), (0.58, 0.63, 'AE'), (0.63, 0.69, 'N'), (0.69, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 1.0, 'T'), (1.0, 1.05, 'AH'), (1.05, 1.12, 'V'), (1.12, 1.18, 'R'), (1.18, 1.24, 'AH'), (1.24, 1.34, 'F'), (1.34, 1.43, 'Y'), (1.43, 1.5, 'UW'), (1.5, 1.58, 'Z'), (1.58, 1.64, 'AH'), (1.64, 1.73, 'L'), (1.73, 1.79, 'B'), (1.79, 1.9, 'IY'), (1.9, 2.01, 'Y'), (2.01, 2.08, 'UW'), (2.08, 2.17, 'S'), (2.17, 2.24, 'F'), (2.24, 2.31, 'AH'), (2.31, 2.41, 'L'), (2.41, 2.48, '[SIL]')]\n", 337 | "\n", 338 | " Ground Truth \n", 339 | "\n", 340 | "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n" 341 | ] 342 | } 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "metadata": { 348 | "id": "V_Y9eH5KlO4r", 
349 | "outputId": "35419b67-3f65-4b35-8d17-6ea3fb458436", 350 | "colab": { 351 | "base_uri": "https://localhost:8080/" 352 | } 353 | }, 354 | "source": [ 355 | "charsiu.serve(audio=audio_path,text=text,save_to='./local/sample.TextGrid')" 356 | ], 357 | "execution_count": 11, 358 | "outputs": [ 359 | { 360 | "output_type": "stream", 361 | "name": "stderr", 362 | "text": [ 363 | "/usr/local/lib/python3.7/dist-packages/transformers/models/wav2vec2/modeling_wav2vec2.py:984: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", 364 | " return (input_length - kernel_size) // stride + 1\n" 365 | ] 366 | }, 367 | { 368 | "output_type": "stream", 369 | "name": "stdout", 370 | "text": [ 371 | "Alignment output has been saved to ./local/sample.TextGrid\n" 372 | ] 373 | } 374 | ] 375 | } 376 | ] 377 | } -------------------------------------------------------------------------------- /charsiu_textless_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "kernelspec": { 6 | "display_name": "Python 3", 7 | "language": "python", 8 | "name": "python3" 9 | }, 10 | "language_info": { 11 | "codemirror_mode": { 12 | "name": "ipython", 13 | "version": 3 14 | }, 15 | "file_extension": ".py", 16 | "mimetype": "text/x-python", 17 | "name": "python", 18 | "nbconvert_exporter": "python", 19 | "pygments_lexer": "ipython3", 20 | "version": "3.8.5" 21 | }, 22 | "colab": { 23 | "name": "charsiu_demo.ipynb", 24 | "provenance": [] 25 | } 26 | }, 27 | "cells": [ 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "0zxKOeyTROc2" 32 | }, 33 | "source": [ 34 | "!pip install torch torchvision torchaudio\n", 35 | "!pip install datasets transformers\n", 36 | "!pip install g2p_en praatio librosa" 37 | ], 38 | "execution_count": null, 39 | "outputs": [] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/" 46 | }, 47 | "id": "DIROcsj7Rv4g", 48 | "outputId": "290c10ae-5ee4-4b9f-86de-8de075c80a76" 49 | }, 50 | "source": [ 51 | "import os\n", 52 | "from os.path import exists, join, expanduser\n", 53 | "\n", 54 | "os.chdir(expanduser(\"~\"))\n", 55 | "charsiu_dir = 'charsiu'\n", 56 | "if exists(charsiu_dir):\n", 57 | " !rm -rf /root/charsiu\n", 58 | "if not exists(charsiu_dir):\n", 59 | " ! git clone -b development https://github.com/lingjzhu/$charsiu_dir\n", 60 | " ! 
cd charsiu && git checkout && cd -\n", 61 | " \n", 62 | "os.chdir(charsiu_dir) " 63 | ], 64 | "execution_count": 2, 65 | "outputs": [ 66 | { 67 | "output_type": "stream", 68 | "name": "stdout", 69 | "text": [ 70 | "Cloning into 'charsiu'...\n", 71 | "remote: Enumerating objects: 322, done.\u001b[K\n", 72 | "remote: Counting objects: 100% (322/322), done.\u001b[K\n", 73 | "remote: Compressing objects: 100% (260/260), done.\u001b[K\n", 74 | "remote: Total 322 (delta 158), reused 159 (delta 55), pack-reused 0\u001b[K\n", 75 | "Receiving objects: 100% (322/322), 511.27 KiB | 12.47 MiB/s, done.\n", 76 | "Resolving deltas: 100% (158/158), done.\n", 77 | "Your branch is up to date with 'origin/development'.\n", 78 | "/root\n" 79 | ] 80 | } 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "metadata": { 86 | "id": "GmHNb4OxRVD8" 87 | }, 88 | "source": [ 89 | "import sys\n", 90 | "import torch\n", 91 | "from itertools import groupby\n", 92 | "from datasets import load_dataset\n", 93 | "import matplotlib.pyplot as plt\n", 94 | "\n", 95 | "sys.path.insert(0,'src')\n", 96 | "from Charsiu import charsiu_chain_attention_aligner, charsiu_chain_forced_aligner, charsiu_predictive_aligner" 97 | ], 98 | "execution_count": null, 99 | "outputs": [] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "metadata": { 104 | "id": "q7paWfYdROc5" 105 | }, 106 | "source": [ 107 | "# download timit\n", 108 | "timit = load_dataset('timit_asr')" 109 | ], 110 | "execution_count": null, 111 | "outputs": [] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "metadata": { 116 | "id": "psXcfdsd48NJ", 117 | "outputId": "8b13c545-1e59-4e4e-cb4e-d3b021a88e3f", 118 | "colab": { 119 | "base_uri": "https://localhost:8080/" 120 | } 121 | }, 122 | "source": [ 123 | "# load data\n", 124 | "sample = timit['train'][0]\n", 125 | "text = sample['text']\n", 126 | "audio_path = sample['file']\n", 127 | "print('Text transcription:%s'%(text))\n", 128 | "print('Audio path: %s'%audio_path)" 129 | ], 130 | "execution_count": 6, 131 | "outputs": [ 132 | { 133 | "output_type": "stream", 134 | "name": "stdout", 135 | "text": [ 136 | "Text transcription:Would such an act of refusal be useful?\n", 137 | "Audio path: /root/.cache/huggingface/datasets/downloads/extracted/404950a46da14eac65eb4e2a8317b1372fb3971d980d91d5d5b221275b1fd7e0/data/TRAIN/DR4/MMDM0/SI681.WAV\n" 138 | ] 139 | } 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": { 145 | "id": "gUye9Hgpzpxb" 146 | }, 147 | "source": [ 148 | "Phone recognizer + Neural Forced Alignment" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "metadata": { 154 | "id": "yHW92QgDROc4" 155 | }, 156 | "source": [ 157 | "# load model\n", 158 | "charsiu = charsiu_chain_attention_aligner(aligner='charsiu/en_w2v2_fs_10ms',recognizer='charsiu/en_w2v2_ctc_libris_and_cv')" 159 | ], 160 | "execution_count": null, 161 | "outputs": [] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "metadata": { 166 | "id": "gJkF2z91ROc5" 167 | }, 168 | "source": [ 169 | "alignment = charsiu.align(audio=audio_path)" 170 | ], 171 | "execution_count": null, 172 | "outputs": [] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "metadata": { 177 | "id": "gLE1r9LQ5CYq", 178 | "outputId": "cdd35c8f-620b-4d04-d51e-bf37d9ebcf21", 179 | "colab": { 180 | "base_uri": "https://localhost:8080/" 181 | } 182 | }, 183 | "source": [ 184 | "print(alignment)\n", 185 | "print('\\n Ground Truth \\n')\n", 186 | "print([(s/16000,e/16000,p) for s,e,p in 
zip(sample['phonetic_detail']['start'],sample['phonetic_detail']['stop'],sample['phonetic_detail']['utterance'])])" 187 | ], 188 | "execution_count": 17, 189 | "outputs": [ 190 | { 191 | "output_type": "stream", 192 | "name": "stdout", 193 | "text": [ 194 | "[(0.0, 0.11, '[SIL]'), (0.11, 0.16, 'W'), (0.16, 0.21, 'IH'), (0.21, 0.27, 'DH'), (0.27, 0.38, 'S'), (0.38, 0.49, 'AH'), (0.49, 0.58, 'CH'), (0.58, 0.63, 'AE'), (0.63, 0.69, 'N'), (0.69, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 1.0, 'T'), (1.0, 1.04, 'IH'), (1.04, 1.12, 'V'), (1.12, 1.18, 'R'), (1.18, 1.24, 'AH'), (1.24, 1.34, 'F'), (1.34, 1.43, 'Y'), (1.43, 1.5, 'UW'), (1.5, 1.58, 'Z'), (1.58, 1.64, 'AH'), (1.64, 1.73, 'L'), (1.73, 1.79, 'B'), (1.79, 1.9, 'IY'), (1.9, 2.01, 'Y'), (2.01, 2.09, 'UW'), (2.09, 2.17, 'S'), (2.17, 2.24, 'F'), (2.24, 2.31, 'AH'), (2.31, 2.4, 'L'), (2.4, 2.48, '[SIL]')]\n", 195 | "\n", 196 | " Ground Truth \n", 197 | "\n", 198 | "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n" 199 | ] 200 | } 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "metadata": { 206 | "id": "5R2M4YMHUf-X" 207 | }, 208 | "source": [ 209 | "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')" 210 | ], 211 | "execution_count": null, 212 | "outputs": [] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "metadata": { 217 | "id": "yCmbdfpzXrQ3" 218 | }, 219 | "source": [ 220 | "# load model\n", 221 | "charsiu = charsiu_chain_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms',recognizer='charsiu/en_w2v2_ctc_libris_and_cv')" 222 | ], 223 | "execution_count": null, 224 | "outputs": [] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "metadata": { 229 | "id": "SdZqWsE45Swv" 230 | }, 231 | "source": [ 232 | "alignment = charsiu.align(audio=audio_path)" 233 | ], 234 | "execution_count": null, 235 | "outputs": [] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "metadata": { 240 | "id": "68sFGMxN5Y7b", 241 | "outputId": "493ed56f-64fb-43a4-ed64-e5257896a2b0", 242 | "colab": { 243 | "base_uri": "https://localhost:8080/" 244 | } 245 | }, 246 | "source": [ 247 | "print(alignment)\n", 248 | "print('\\n Ground Truth \\n')\n", 249 | "print([(s/16000,e/16000,p) for s,e,p in zip(sample['phonetic_detail']['start'],sample['phonetic_detail']['stop'],sample['phonetic_detail']['utterance'])])" 250 | ], 251 | "execution_count": 24, 252 | "outputs": [ 253 | { 254 | "output_type": "stream", 255 | "name": "stdout", 256 | "text": [ 257 | "[(0.0, 0.08, '[SIL]'), (0.08, 0.16, 'W'), (0.16, 0.21, 'IH'), (0.21, 0.22, 'DH'), (0.22, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.58, 'CH'), (0.58, 0.6, 'AE'), (0.6, 0.68, 'N'), (0.68, 0.82, 'AE'), (0.82, 0.93, 'K'), (0.93, 0.99, 'T'), (0.99, 1.05, 'IH'), (1.05, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.22, 'AH'), (1.22, 1.33, 'F'), (1.33, 1.41, 'Y'), (1.41, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.62, 'AH'), (1.62, 1.68, 'L'), (1.68, 
1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.13, 'S'), (2.13, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.47, 'L'), (2.47, 2.48, '[SIL]')]\n", 258 | "\n", 259 | " Ground Truth \n", 260 | "\n", 261 | "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n" 262 | ] 263 | } 264 | ] 265 | }, 266 | { 267 | "cell_type": "code", 268 | "metadata": { 269 | "id": "PqI70asA5btU" 270 | }, 271 | "source": [ 272 | "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')" 273 | ], 274 | "execution_count": null, 275 | "outputs": [] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": { 280 | "id": "O2hj-1ZJ1tfA" 281 | }, 282 | "source": [ 283 | "Direct inference with frame classification model" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "metadata": { 289 | "id": "yG_eg8KJ1snD" 290 | }, 291 | "source": [ 292 | "charsiu = charsiu_predictive_aligner(aligner='charsiu/en_w2v2_fc_10ms')" 293 | ], 294 | "execution_count": null, 295 | "outputs": [] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "metadata": { 300 | "id": "qmkfJ9gD1tE0" 301 | }, 302 | "source": [ 303 | "alignment = charsiu.align(audio=audio_path)" 304 | ], 305 | "execution_count": 27, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "IRQr20Av2Iq7", 312 | "outputId": "7b1802ac-0ea9-4582-ec92-b1e9282122c0", 313 | "colab": { 314 | "base_uri": "https://localhost:8080/" 315 | } 316 | }, 317 | "source": [ 318 | "print(alignment)\n", 319 | "print('\\n Ground Truth \\n')\n", 320 | "print([(s/16000,e/16000,p) for s,e,p in zip(sample['phonetic_detail']['start'],sample['phonetic_detail']['stop'],sample['phonetic_detail']['utterance'])])" 321 | ], 322 | "execution_count": 28, 323 | "outputs": [ 324 | { 325 | "output_type": "stream", 326 | "name": "stdout", 327 | "text": [ 328 | "[(0.0, 0.08, '[SIL]'), (0.08, 0.15, 'W'), (0.15, 0.2, 'UH'), (0.2, 0.24, 'D'), (0.24, 0.38, 'S'), (0.38, 0.46, 'AH'), (0.46, 0.57, 'CH'), (0.57, 0.62, 'AH'), (0.62, 0.68, 'N'), (0.68, 0.83, 'AE'), (0.83, 0.94, 'K'), (0.94, 0.99, 'T'), (0.99, 1.05, 'IH'), (1.05, 1.13, 'V'), (1.13, 1.17, 'R'), (1.17, 1.21, 'IH'), (1.21, 1.22, 'AH'), (1.22, 1.34, 'F'), (1.34, 1.42, 'Y'), (1.42, 1.48, 'UW'), (1.48, 1.53, 'Z'), (1.53, 1.63, 'AH'), (1.63, 1.68, 'L'), (1.68, 1.78, 'B'), (1.78, 1.88, 'IY'), (1.88, 1.99, 'Y'), (1.99, 2.08, 'UW'), (2.08, 2.14, 'S'), (2.14, 2.23, 'F'), (2.23, 2.27, 'AH'), (2.27, 2.48, 'L')]\n", 329 | "\n", 330 | " Ground Truth \n", 331 | "\n", 332 | "[(0.0, 0.1225, 'h#'), (0.1225, 0.154125, 'w'), (0.154125, 0.2175, 'ix'), (0.2175, 0.25, 'dcl'), (0.25, 0.3725, 's'), (0.3725, 0.4675, 'ah'), (0.4675, 0.4925, 'tcl'), (0.4925, 0.5875, 'ch'), (0.5875, 0.6225, 'ix'), (0.6225, 0.6675, 'n'), (0.6675, 0.8425, 'ae'), (0.8425, 0.98, 'kcl'), (0.98, 0.9925, 't'), (0.9925, 1.0575, 'ix'), (1.0575, 1.1435625, 
'v'), (1.1435625, 1.180125, 'r'), (1.180125, 1.2175, 'ix'), (1.2175, 1.3576875, 'f'), (1.3576875, 1.40725, 'y'), (1.40725, 1.5025, 'ux'), (1.5025, 1.574375, 'zh'), (1.574375, 1.6925, 'el'), (1.6925, 1.76, 'bcl'), (1.76, 1.785, 'b'), (1.785, 1.8825, 'iy'), (1.8825, 1.9895, 'y'), (1.9895, 2.0775, 'ux'), (2.0775, 2.165, 's'), (2.165, 2.248, 'f'), (2.248, 2.3575, 'el'), (2.3575, 2.495, 'h#')]\n" 333 | ] 334 | } 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "metadata": { 340 | "id": "dwTpN_5c2Zm1", 341 | "colab": { 342 | "base_uri": "https://localhost:8080/" 343 | }, 344 | "outputId": "ac8b79e3-9c5c-439f-be60-2773c730e19f" 345 | }, 346 | "source": [ 347 | "charsiu.serve(audio=audio_path, save_to='sample.TextGrid')" 348 | ], 349 | "execution_count": 29, 350 | "outputs": [ 351 | { 352 | "output_type": "stream", 353 | "name": "stdout", 354 | "text": [ 355 | "Alignment output has been saved to sample.TextGrid\n" 356 | ] 357 | } 358 | ] 359 | } 360 | ] 361 | } -------------------------------------------------------------------------------- /experiments/common_voice_preprocess.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import soundfile as sf 3 | from tqdm.contrib.concurrent import process_map 4 | import re 5 | import subprocess 6 | from tqdm import tqdm 7 | 8 | common_voice_train = load_dataset("common_voice", "en", split="train+validation",cache_dir='/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr') 9 | 10 | file_pairs = [(i['path'],re.search(r'(.*?)\.mp3',i['path']).group(1)+'.wav') for i in tqdm(common_voice_train)] 11 | 12 | def convert_and_resample(item): 13 | command = ['sox', item[0],'-r','16000',item[1]] 14 | subprocess.run(command) 15 | 16 | r = process_map(convert_and_resample, file_pairs, max_workers=8, chunksize=1) 17 | -------------------------------------------------------------------------------- /experiments/common_voice_pretraining.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import re 6 | import argparse 7 | import transformers 8 | import soundfile as sf 9 | import librosa 10 | import jiwer 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | 16 | 17 | from dataclasses import dataclass, field 18 | from typing import Any, Dict, List, Optional, Union 19 | from g2p_en import G2p 20 | import numpy as np 21 | from datasets import load_dataset, load_metric, load_from_disk 22 | from transformers import Wav2Vec2CTCTokenizer 23 | from transformers import Wav2Vec2FeatureExtractor 24 | from transformers import Trainer,TrainingArguments 25 | from transformers import Wav2Vec2Processor 26 | from transformers import Wav2Vec2ForPreTraining,Wav2Vec2Config 27 | from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices,Wav2Vec2ForPreTrainingOutput 28 | 29 | 30 | 31 | def prepare_common_voice_dataset(batch): 32 | # check that all files have the correct sampling rate 33 | 34 | batch["input_values"] = re.search(r'(.*?)\.mp3', batch['path']).group(1)+'.wav' 35 | return batch 36 | 37 | def prepare_multicn_dataset(batch): 38 | # check that all files have the correct sampling rate 39 | 40 | batch["input_values"] = batch['path'] 41 | return batch 42 | 43 | 44 | def audio_preprocess(path): 45 | 46 | features,sr = sf.read(path) 47 | assert sr == 16000 48 | return processor(features, 
sampling_rate=16000).input_values.squeeze() 49 | 50 | 51 | @dataclass 52 | class DataCollatorWithPadding: 53 | """ 54 | Data collator that will dynamically pad the inputs received. 55 | Args: 56 | processor (:class:`~transformers.Wav2Vec2Processor`) 57 | The processor used for proccessing the data. 58 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 59 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 60 | among: 61 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 62 | sequence if provided). 63 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 64 | maximum acceptable input length for the model if that argument is not provided. 65 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 66 | different lengths). 67 | max_length (:obj:`int`, `optional`): 68 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 69 | max_length_labels (:obj:`int`, `optional`): 70 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 71 | pad_to_multiple_of (:obj:`int`, `optional`): 72 | If set will pad the sequence to a multiple of the provided value. 73 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 74 | 7.5 (Volta). 75 | """ 76 | 77 | processor: Wav2Vec2Processor 78 | padding: Union[bool, str] = True 79 | return_attention_mask: Optional[bool] = True 80 | max_length: Optional[int] = None 81 | max_length_labels: Optional[int] = None 82 | pad_to_multiple_of: Optional[int] = None 83 | pad_to_multiple_of_labels: Optional[int] = None 84 | 85 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 86 | # split inputs and labels since they have to be of different lenghts and need 87 | # different padding methods 88 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 89 | 90 | batch = self.processor.pad( 91 | input_features, 92 | padding=self.padding, 93 | max_length=self.max_length, 94 | pad_to_multiple_of=self.pad_to_multiple_of, 95 | return_attention_mask=self.return_attention_mask, 96 | return_tensors="pt", 97 | ) 98 | batch_size, raw_sequence_length = batch['input_values'].shape 99 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 100 | batch['mask_time_indices'] = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.1, mask_length=2,device='cpu') 101 | 102 | return batch 103 | 104 | 105 | 106 | tokenizer = Wav2Vec2CTCTokenizer("vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token=" ") 107 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 108 | 109 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 110 | 111 | 112 | 113 | if __name__ == "__main__": 114 | 115 | lang = 'en' 116 | # loading data 117 | if lang == 'en': 118 | common_voice = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/common_voice_filtered') 119 | data_prepared = common_voice.map(prepare_common_voice_dataset, remove_columns=common_voice.column_names) 120 | print('English 
data ready!') 121 | elif lang == 'zh': 122 | multicn = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/multicn_list') 123 | data_prepared = multicn.map(prepare_multicn_dataset, remove_columns=multicn.column_names) 124 | print('Chinese data ready!') 125 | 126 | 127 | 128 | 129 | # data loader 130 | data_collator = DataCollatorWithPadding(processor=processor, padding=True) 131 | 132 | 133 | # load model 134 | config = Wav2Vec2Config() 135 | config.num_attention_heads = 6 136 | config.hidden_size = 384 137 | config.num_hidden_layers = 6 138 | config.num_negatives = 20 139 | model = Wav2Vec2ForPreTraining(config) 140 | 141 | if lang == 'en': 142 | pre_trained_model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base") 143 | 144 | layers = {'wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.feature_extractor.conv_layers.6.conv.weight','quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_q.bias'} 145 | pretrained_dict = {k: v for k, v in pre_trained_model.state_dict().items() if k in layers} 146 | print('Loaded Wav2Vec2 English') 147 | elif lang == 'zh': 148 | pre_trained_model = Wav2Vec2ForPreTraining.from_pretrained('facebook/wav2vec2-large-xlsr-53') 149 | 150 | layers = {'wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.feature_extractor.conv_layers.6.conv.weight'} 151 | pretrained_dict = {k:v for k,v in pre_trained_model.state_dict().items() if k in layers} 152 | print('Loaded xlsr-53 weights') 153 | 154 | 155 | state_dict = model.state_dict() 156 | state_dict.update(pretrained_dict) 157 | model.load_state_dict(state_dict) 158 | 159 | model.freeze_feature_extractor() 160 | model.wav2vec2.feature_extractor.conv_layers[6].conv.stride = (1,) 161 | model.config.conv_stride[-1] = 1 162 | 163 | del pre_trained_model 164 | 165 | 166 | if lang == 'en': 167 | output_dir = "/scratch/lingjzhu_root/lingjzhu1/lingjzhu/asr/wav2vec2-common_voice-pretraining" 168 | elif lang == 'zh': 169 | output_dir = "/scratch/lingjzhu_root/lingjzhu1/lingjzhu/asr/wav2vec2-multicn-pretraining" 170 | 171 | 172 | # training settings 173 | training_args = TrainingArguments( 174 | output_dir=output_dir, 175 | group_by_length=True, 176 | per_device_train_batch_size=4, 177 | gradient_accumulation_steps=40, 178 | # evaluation_strategy="steps", 179 | num_train_epochs=4, 180 | fp16=True, 181 | save_steps=1000, 182 | # eval_steps=1000, 183 | logging_steps=1000, 184 | learning_rate=5e-4, 185 | weight_decay=1e-6, 186 | warmup_steps=1000, 187 | save_total_limit=2, 188 | ignore_data_skip=True, 189 | ) 190 | 191 | 192 | trainer = Trainer( 193 | model=model, 194 | data_collator=data_collator, 195 | args=training_args, 196 | # 
compute_metrics=compute_metrics, 197 | train_dataset=data_prepared, 198 | # eval_dataset=libris_train_prepared, 199 | tokenizer=processor.feature_extractor, 200 | ) 201 | 202 | 203 | trainer.train() 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | -------------------------------------------------------------------------------- /experiments/masked_phone_modeling.py: -------------------------------------------------------------------------------- 1 | import re 2 | from g2p_en import G2p 3 | import json 4 | import torch 5 | import argparse 6 | from dataclasses import dataclass, field 7 | from transformers import BertTokenizerFast 8 | from transformers import BertForMaskedLM, BertConfig 9 | from transformers import Wav2Vec2CTCTokenizer 10 | from transformers import Wav2Vec2FeatureExtractor 11 | from transformers import Wav2Vec2Processor 12 | from transformers import Trainer,TrainingArguments 13 | from datasets import load_dataset 14 | from typing import Any, Dict, List, Optional, Union 15 | from torch.utils.data import DataLoader 16 | 17 | 18 | 19 | def get_phones(sen): 20 | ''' 21 | convert texts to phone sequence 22 | ''' 23 | sen = re.sub(r'\d','',sen) 24 | sen = g2p(sen) 25 | sen = [re.sub(r'\d','',p) for p in sen] 26 | return sen 27 | 28 | def get_phone_ids(phones): 29 | ''' 30 | convert phone sequence to ids 31 | ''' 32 | ids = [] 33 | punctuation = set('.,!?') 34 | for p in phones: 35 | if re.match(r'^\w+?$',p): 36 | ids.append(mapping_phone2id.get(p,mapping_phone2id['[UNK]'])) 37 | elif p in punctuation: 38 | ids.append(mapping_phone2id.get('[SIL]')) 39 | ids = [0]+ids 40 | if ids[-1]!=0: 41 | ids.append(0) 42 | return ids #append silence token at the beginning 43 | 44 | 45 | def torch_mask_tokens(inputs, special_tokens_mask,mlm_probability=0.2): 46 | """ 47 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 48 | """ 49 | 50 | labels = inputs.clone() 51 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`) 52 | probability_matrix = torch.full(labels.shape, mlm_probability) 53 | # We'll use the attention mask here 54 | special_tokens_mask = special_tokens_mask.bool() 55 | 56 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0) 57 | masked_indices = torch.bernoulli(probability_matrix).bool() 58 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 59 | 60 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 61 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices 62 | # 63 | inputs[indices_replaced] = torch.tensor(mapping_phone2id['[MASK]']) 64 | 65 | # 10% of the time, we replace masked input tokens with random word 66 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced 67 | random_words = torch.randint(len(mapping_phone2id), labels.shape, dtype=torch.long) 68 | inputs[indices_random] = random_words[indices_random] 69 | 70 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 71 | return inputs, labels 72 | 73 | 74 | 75 | @dataclass 76 | class DataCollatorMPMWithPadding: 77 | """ 78 | Data collator that will dynamically pad the inputs received. 79 | Args: 80 | processor (:class:`~transformers.Wav2Vec2Processor`) 81 | The processor used for proccessing the data. 
82 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 83 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 84 | among: 85 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 86 | sequence if provided). 87 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 88 | maximum acceptable input length for the model if that argument is not provided. 89 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 90 | different lengths). 91 | max_length (:obj:`int`, `optional`): 92 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 93 | max_length_labels (:obj:`int`, `optional`): 94 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 95 | pad_to_multiple_of (:obj:`int`, `optional`): 96 | If set will pad the sequence to a multiple of the provided value. 97 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 98 | 7.5 (Volta). 99 | """ 100 | 101 | processor: Wav2Vec2Processor 102 | padding: Union[bool, str] = True 103 | return_attention_mask: Optional[bool] = True 104 | max_length: Optional[int] = None 105 | max_length_labels: Optional[int] = 256 106 | pad_to_multiple_of: Optional[int] = None 107 | pad_to_multiple_of_labels: Optional[int] = None 108 | 109 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 110 | # split inputs and labels since they have to be of different lenghts and need 111 | # different padding methods 112 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["input_ids"]))[:self.max_length_labels]} for feature in features] 113 | 114 | 115 | with self.processor.as_target_processor(): 116 | batch = self.processor.pad( 117 | label_features, 118 | padding=self.padding, 119 | max_length=self.max_length_labels, 120 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 121 | return_attention_mask=self.return_attention_mask, 122 | return_tensors="pt", 123 | ) 124 | 125 | inputs, labels = torch_mask_tokens(batch['input_ids'],1-batch['attention_mask']) 126 | # replace padding with -100 to ignore loss correctly 127 | labels = labels.masked_fill(batch.attention_mask.ne(1), -100) 128 | 129 | batch['input_ids'] = inputs 130 | batch["labels"] = labels 131 | 132 | return batch 133 | 134 | 135 | g2p = G2p() 136 | mapping_phone2id = json.load(open("./vocab.json",'r')) 137 | mapping_id2phone = {v:k for k,v in mapping_phone2id.items()} 138 | 139 | tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="") 140 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 141 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 142 | 143 | 144 | if __name__ == "__main__": 145 | 146 | parser = argparse.ArgumentParser() 147 | parser.add_argument('--data_dir', type=str, default='/scratch/lingjzhu_root/lingjzhu1/lingjzhu/asr') 148 | parser.add_argument('--out_dir',type=str,default="./models/bert-phones") 149 | 150 | args = parser.parse_args() 151 | 152 | 153 | # load text dataset 154 | books = 
load_dataset('bookcorpus',cache_dir=args.data_dir) 155 | books = books.rename_column('text','input_ids') 156 | 157 | # load model 158 | config = BertConfig() 159 | config.vocab_size = len(mapping_phone2id) 160 | config.hidden_size = 384 161 | config.num_hidden_layers = 4 162 | config.intermediate_size=768 163 | model = BertForMaskedLM(config) 164 | 165 | data_collator = DataCollatorMPMWithPadding(processor=processor, padding=True) 166 | 167 | training_args = TrainingArguments( 168 | output_dir=args.out_dir, 169 | group_by_length=True, 170 | per_device_train_batch_size=64, 171 | gradient_accumulation_steps=32, 172 | num_train_epochs=1, 173 | fp16=True, 174 | save_steps=2000, 175 | logging_steps=2000, 176 | learning_rate=1e-3, 177 | weight_decay=0.00001, 178 | warmup_steps=1000, 179 | save_total_limit=2, 180 | ignore_data_skip=True 181 | ) 182 | 183 | 184 | trainer = Trainer( 185 | model=model, 186 | data_collator=data_collator, 187 | args=training_args, 188 | train_dataset=books['train'], 189 | # eval_dataset=books['val'], 190 | ) 191 | 192 | 193 | trainer.train(resume_from_checkpoint='models/bert-phones/checkpoint-10000') 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /experiments/train_alignment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import re 5 | import os 6 | import argparse 7 | import transformers 8 | import soundfile as sf 9 | import librosa 10 | import json 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | from itertools import chain,groupby 19 | 20 | from dataclasses import dataclass, field 21 | from typing import Any, Dict, List, Optional, Union, Tuple 22 | from g2p_en import G2p 23 | from datasets import load_dataset, load_from_disk 24 | from transformers import Wav2Vec2CTCTokenizer 25 | from transformers import Wav2Vec2FeatureExtractor 26 | from transformers import Trainer,TrainingArguments 27 | from transformers import Wav2Vec2Processor 28 | from transformers import Wav2Vec2ForCTC 29 | from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices,Wav2Vec2ForPreTrainingOutput 30 | from transformers.file_utils import ModelOutput 31 | from transformers.modeling_outputs import CausalLMOutput, MaskedLMOutput 32 | from transformers import BertForMaskedLM 33 | from transformers import AdamW 34 | 35 | 36 | class ForwardSumLoss(torch.nn.Module): 37 | def __init__(self, blank_logprob=-1): 38 | super(ForwardSumLoss, self).__init__() 39 | self.log_softmax = torch.nn.LogSoftmax(dim=3) 40 | self.blank_logprob = blank_logprob 41 | # self.off_diag_penalty = off_diag_penalty 42 | self.CTCLoss = nn.CTCLoss(zero_infinity=True) 43 | 44 | def forward(self, attn_logprob, text_lens, mel_lens): 45 | """ 46 | Args: 47 | attn_logprob: batch x 1 x max(mel_lens) x max(text_lens) 48 | batched tensor of attention log 49 | probabilities, padded to length 50 | of longest sequence in each dimension 51 | text_lens: batch-D vector of length of 52 | each text sequence 53 | mel_lens: batch-D vector of length of 54 | each mel sequence 55 | """ 56 | # The CTC loss module assumes the existence of a blank token 57 | # that can be optionally inserted anywhere in the sequence for 58 | # a fixed probability. 
59 | # A row must be added to the attention matrix to account for this 60 | attn_logprob_pd = F.pad(input=attn_logprob, 61 | pad=(1, 0, 0, 0, 0, 0, 0, 0), 62 | value=self.blank_logprob) 63 | cost_total = 0.0 64 | # for-loop over batch because of variable-length 65 | # sequences 66 | for bid in range(attn_logprob.shape[0]): 67 | # construct the target sequence. Every 68 | # text token is mapped to a unique sequence number, 69 | # thereby ensuring the monotonicity constraint 70 | target_seq = torch.arange(1, text_lens[bid]+1) 71 | target_seq=target_seq.unsqueeze(0) 72 | curr_logprob = attn_logprob_pd[bid].permute(1, 0, 2) 73 | curr_logprob = curr_logprob[:mel_lens[bid],:,:text_lens[bid]+1] 74 | 75 | # curr_logprob = curr_logprob + self.off_diagonal_loss(curr_logprob,text_lens[bid]+1,mel_lens[bid]) 76 | curr_logprob = self.log_softmax(curr_logprob[None])[0] 77 | cost = self.CTCLoss(curr_logprob, 78 | target_seq, 79 | input_lengths=mel_lens[bid:bid+1], 80 | target_lengths=text_lens[bid:bid+1]) 81 | cost_total += cost 82 | # average cost over batch 83 | cost_total = cost_total/attn_logprob.shape[0] 84 | return cost_total 85 | 86 | def off_diagonal_loss(self,log_prob,N, T, g=0.2): 87 | 88 | n = torch.arange(N).to(log_prob.device) 89 | t = torch.arange(T).to(log_prob.device) 90 | t = t.unsqueeze(1).repeat(1,N) 91 | n = n.unsqueeze(0).repeat(T,1) 92 | 93 | 94 | # W = 1 - torch.exp(-(n/N - t/T)**2/(2*g**2)) 95 | 96 | # penalty = log_prob*W.unsqueeze(1) 97 | # return torch.mean(penalty) 98 | W = torch.exp(-(n/N - t/T)**2/(2*g**2)) 99 | 100 | return torch.log_softmax(W.unsqueeze(1),dim=-1) 101 | 102 | 103 | class ConvBank(nn.Module): 104 | def __init__(self, input_dim, output_class_num, kernels, cnn_size, hidden_size, dropout, **kwargs): 105 | super(ConvBank, self).__init__() 106 | self.drop_p = dropout 107 | 108 | self.in_linear = nn.Linear(input_dim, hidden_size) 109 | latest_size = hidden_size 110 | 111 | # conv bank 112 | self.cnns = nn.ModuleList() 113 | assert len(kernels) > 0 114 | for kernel in kernels: 115 | self.cnns.append(nn.Conv1d(latest_size, cnn_size, kernel, padding=kernel//2)) 116 | latest_size = cnn_size * len(kernels) 117 | 118 | self.out_linear = nn.Linear(latest_size, output_class_num) 119 | 120 | def forward(self, features): 121 | hidden = F.dropout(F.relu(self.in_linear(features)), p=self.drop_p) 122 | 123 | conv_feats = [] 124 | hidden = hidden.transpose(1, 2).contiguous() 125 | for cnn in self.cnns: 126 | conv_feats.append(cnn(hidden)) 127 | hidden = torch.cat(conv_feats, dim=1).transpose(1, 2).contiguous() 128 | hidden = F.dropout(F.relu(hidden), p=self.drop_p) 129 | 130 | predicted = self.out_linear(hidden) 131 | return predicted 132 | 133 | 134 | class Wav2Vec2ForAlignment(Wav2Vec2ForCTC): 135 | 136 | def __init__(self,config): 137 | super().__init__(config) 138 | self.cnn = ConvBank(config.hidden_size,384,[1,3],384,384,0.1) 139 | self.align_loss = ForwardSumLoss() 140 | 141 | def freeze_wav2vec2(self): 142 | for param in self.wav2vec2.parameters(): 143 | param.requires_grad = False 144 | 145 | def initialize_phone_model(self,path): 146 | 147 | self.bert = BertForMaskedPhoneLM.from_pretrained(path) 148 | 149 | 150 | def forward( 151 | self, 152 | input_values, 153 | attention_mask=None, 154 | output_attentions=None, 155 | output_hidden_states=None, 156 | mask_time_indices=None, 157 | return_dict=None, 158 | labels=None, 159 | labels_attention_mask=None, 160 | text_len=None, 161 | frame_len=None 162 | ): 163 | 164 | 165 | 166 | outputs = self.wav2vec2( 167 | input_values, 
168 | attention_mask=attention_mask, 169 | output_attentions=output_attentions, 170 | output_hidden_states=output_hidden_states, 171 | mask_time_indices=mask_time_indices, 172 | return_dict=return_dict, 173 | ) 174 | 175 | # acoustic embeddings 176 | hidden_states = outputs[0] 177 | hidden_states = self.dropout(hidden_states) 178 | hidden_states = self.cnn(hidden_states) 179 | 180 | # phone embeddings 181 | phone_hidden = self.bert(input_ids=labels,attention_mask=labels_attention_mask).hidden_states[-1] 182 | 183 | # compute cross attention 184 | att = torch.bmm(hidden_states,phone_hidden.transpose(2,1)) 185 | attention_mask = (1-labels_attention_mask)*-10000.0 186 | att = torch.log_softmax(att+attention_mask.unsqueeze(1).repeat(1,att.size(1),1),dim=-1) 187 | 188 | 189 | loss = None 190 | if self.training: 191 | loss = self.align_loss(att.unsqueeze(1),text_len,frame_len) 192 | 193 | 194 | return CausalLMOutput( 195 | loss=loss, logits=att, hidden_states=outputs.hidden_states, attentions=outputs.attentions 196 | ) 197 | 198 | 199 | class BertForMaskedPhoneLM(BertForMaskedLM): 200 | 201 | def __init__(self,config): 202 | super().__init__(config) 203 | self.cnn = ConvBank(config.hidden_size,384,[1],384,384,0.1) 204 | 205 | def freeze_feature_extractor(self): 206 | for param in self.bert.parameters(): 207 | param.requires_grad = False 208 | 209 | def forward( 210 | self, 211 | input_ids=None, 212 | attention_mask=None, 213 | token_type_ids=None, 214 | position_ids=None, 215 | head_mask=None, 216 | inputs_embeds=None, 217 | encoder_hidden_states=None, 218 | encoder_attention_mask=None, 219 | labels=None, 220 | output_attentions=None, 221 | output_hidden_states=True, 222 | ): 223 | 224 | 225 | outputs = self.bert( 226 | input_ids, 227 | attention_mask=attention_mask, 228 | token_type_ids=token_type_ids, 229 | position_ids=position_ids, 230 | head_mask=head_mask, 231 | inputs_embeds=inputs_embeds, 232 | encoder_hidden_states=encoder_hidden_states, 233 | encoder_attention_mask=encoder_attention_mask, 234 | output_attentions=output_attentions, 235 | output_hidden_states=output_hidden_states 236 | ) 237 | 238 | 239 | prediction_scores = self.cnn(outputs.hidden_states[-1]) 240 | 241 | masked_lm_loss = None 242 | if labels is not None: 243 | loss_fct = CrossEntropyLoss() # -100 index = padding token 244 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 245 | 246 | return MaskedLMOutput( 247 | loss=masked_lm_loss, 248 | logits=prediction_scores, 249 | hidden_states=outputs.hidden_states, 250 | attentions=outputs.attentions, 251 | ) 252 | 253 | 254 | 255 | def get_phones(sen): 256 | ''' 257 | convert texts to phone sequence 258 | ''' 259 | sen = g2p(sen) 260 | sen = [re.sub(r'\d','',p) for p in sen] 261 | return sen 262 | 263 | def get_phone_ids(phones): 264 | ''' 265 | convert phone sequence to ids 266 | ''' 267 | ids = [] 268 | punctuation = set('.,!?') 269 | for p in phones: 270 | if re.match(r'^\w+?$',p): 271 | ids.append(mapping_phone2id.get(p,mapping_phone2id['[UNK]'])) 272 | elif p in punctuation: 273 | ids.append(mapping_phone2id.get('[SIL]')) 274 | ids = [0]+ids 275 | if ids[-1]!=0: 276 | ids.append(0) 277 | return ids #append silence token at the beginning 278 | 279 | def audio_preprocess(path): 280 | 281 | # features,_ = sf.read(path) 282 | features, _ = librosa.core.load(path,sr=16000) 283 | return processor(features, sampling_rate=16000).input_values.squeeze() 284 | 285 | 286 | 287 | def prepare_common_voice_dataset(batch): 288 | # check that 
all files have the correct sampling rate 289 | 290 | batch["input_values"] = re.search(r'(.*?)\.mp3', batch['path']).group(1)+'.wav' 291 | batch['labels'] = batch['sentence'] 292 | return batch 293 | 294 | 295 | @dataclass 296 | class SpeechCollatorWithPadding: 297 | """ 298 | Data collator that will dynamically pad the inputs received. 299 | Args: 300 | processor (:class:`~transformers.Wav2Vec2Processor`) 301 | The processor used for proccessing the data. 302 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 303 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 304 | among: 305 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 306 | sequence if provided). 307 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 308 | maximum acceptable input length for the model if that argument is not provided. 309 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 310 | different lengths). 311 | max_length (:obj:`int`, `optional`): 312 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 313 | max_length_labels (:obj:`int`, `optional`): 314 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 315 | pad_to_multiple_of (:obj:`int`, `optional`): 316 | If set will pad the sequence to a multiple of the provided value. 317 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 318 | 7.5 (Volta). 319 | """ 320 | 321 | processor: Wav2Vec2Processor 322 | padding: Union[bool, str] = True 323 | return_attention_mask: Optional[bool] = True 324 | max_length: Optional[int] = None 325 | max_length_labels: Optional[int] = 256 326 | pad_to_multiple_of: Optional[int] = None 327 | pad_to_multiple_of_labels: Optional[int] = None 328 | 329 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 330 | # split inputs and labels since they have to be of different lenghts and need 331 | # different padding methods 332 | 333 | # get phone features 334 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["labels"]))[:self.max_length_labels]} for feature in features] 335 | text_len = [len(i['input_ids']) for i in label_features] 336 | 337 | with self.processor.as_target_processor(): 338 | label_batch = self.processor.pad( 339 | label_features, 340 | padding=self.padding, 341 | max_length=self.max_length_labels, 342 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 343 | return_attention_mask=self.return_attention_mask, 344 | return_tensors="pt", 345 | ) 346 | 347 | # get speech features 348 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 349 | mel_len = [model._get_feat_extract_output_lengths(len(i['input_values'])) for i in input_features] 350 | batch = self.processor.pad( 351 | input_features, 352 | padding=self.padding, 353 | max_length=self.max_length, 354 | pad_to_multiple_of=self.pad_to_multiple_of, 355 | return_attention_mask=self.return_attention_mask, 356 | return_tensors="pt", 357 | ) 358 | batch_size, raw_sequence_length = batch['input_values'].shape 359 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 360 | 
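        # sequence_length is the number of frames the wav2vec2 feature encoder
        # will produce for the padded waveform. _compute_mask_indices samples
        # random spans of mask_length=2 frames along that axis, with the amount
        # of masking governed by mask_prob=0.075, mirroring the time masking
        # used in wav2vec 2.0 pre-training. frame_len and text_len keep the
        # true (unpadded) lengths so that ForwardSumLoss can ignore padding.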
batch['mask_time_indices'] = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.075, mask_length=2,device='cpu') 361 | batch['frame_len'] = torch.tensor(mel_len) 362 | 363 | batch['labels'] = label_batch['input_ids'] 364 | batch['text_len'] = torch.tensor(text_len) 365 | batch['labels_attention_mask'] = label_batch['attention_mask'] 366 | return batch 367 | 368 | 369 | if __name__ == "__main__": 370 | 371 | output_dir = '/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/model/neural_aligner_common_voice_10ms' 372 | 373 | device = 'cuda' 374 | 375 | 376 | ''' 377 | Load tokenizers and processors 378 | ''' 379 | g2p = G2p() 380 | mapping_phone2id = json.load(open("./vocab.json",'r')) 381 | mapping_id2phone = {v:k for k,v in mapping_phone2id.items()} 382 | 383 | tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="") 384 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 385 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 386 | 387 | ''' 388 | Load dataset 389 | ''' 390 | common_voice = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/common_voice_filtered') 391 | common_voice = common_voice.map(prepare_common_voice_dataset, remove_columns=common_voice.column_names) 392 | print(len(common_voice)) 393 | print(common_voice[0]) 394 | 395 | speech_collator = SpeechCollatorWithPadding(processor=processor) 396 | 397 | 398 | ''' 399 | Load model 400 | ''' 401 | resolution = 10 402 | 403 | model = Wav2Vec2ForAlignment.from_pretrained('./models/alignment_initialized_weights') 404 | model.initialize_phone_model('./models/bert-phones/checkpoint-36000') 405 | weights = torch.load('./models/alignment_initial_weights') 406 | model.load_state_dict(weights) 407 | 408 | if resolution == 10: 409 | model.wav2vec2.feature_extractor.conv_layers[6].conv.stride = (1,) 410 | model.config.conv_stride[-1] = 1 411 | print('The resolution is %s'%resolution) 412 | model.freeze_feature_extractor() 413 | 414 | ''' 415 | Training loop 416 | ''' 417 | training_args = TrainingArguments( 418 | output_dir=output_dir, 419 | group_by_length=True, 420 | per_device_train_batch_size=4, 421 | gradient_accumulation_steps=16, 422 | # evaluation_strategy="steps", 423 | num_train_epochs=2, 424 | fp16=True, 425 | save_steps=500, 426 | # eval_steps=1000, 427 | logging_steps=100, 428 | learning_rate=5e-4, 429 | weight_decay=1e-6, 430 | warmup_steps=1000, 431 | save_total_limit=2, 432 | ignore_data_skip=True, 433 | ) 434 | 435 | 436 | trainer = Trainer( 437 | model=model, 438 | data_collator=speech_collator, 439 | args=training_args, 440 | # compute_metrics=compute_metrics, 441 | train_dataset=common_voice, 442 | # eval_dataset=libris_train_prepared, 443 | tokenizer=processor.feature_extractor, 444 | ) 445 | 446 | 447 | trainer.train() 448 | -------------------------------------------------------------------------------- /experiments/train_asr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import re 6 | import json 7 | import transformers 8 | import soundfile as sf 9 | import jiwer 10 | import torch 11 | import argparse 12 | 13 | from dataclasses import dataclass, field 14 | from typing import Any, Dict, List, Optional, Union 15 | from g2p_en import G2p 16 | import numpy as np 17 | from datasets import 
concatenate_datasets, load_dataset, load_metric, load_from_disk 18 | from transformers import Wav2Vec2CTCTokenizer 19 | from transformers import Wav2Vec2FeatureExtractor 20 | from transformers import Trainer,TrainingArguments 21 | from transformers import Wav2Vec2Processor 22 | from transformers import Wav2Vec2ForCTC 23 | 24 | 25 | def get_phones(sen): 26 | ''' 27 | convert texts to phone sequence 28 | ''' 29 | sen = g2p(sen) 30 | sen = [re.sub(r'\d','',p) for p in sen] 31 | return sen 32 | 33 | def get_phone_ids(phones): 34 | ''' 35 | convert phone sequence to ids 36 | ''' 37 | ids = [] 38 | punctuation = set('.,!?') 39 | for p in phones: 40 | if re.match(r'^\w+?$',p): 41 | ids.append(mapping_phone2id.get(p,mapping_phone2id['[UNK]'])) 42 | elif p in punctuation: 43 | ids.append(mapping_phone2id.get('[SIL]')) 44 | ids = [0]+ids 45 | if ids[-1]!=0: 46 | ids.append(0) 47 | return ids #append silence token at the beginning 48 | 49 | def prepare_dataset(batch): 50 | # check that all files have the correct sampling rate 51 | 52 | batch["input_values"] = batch['file'] 53 | batch["labels"] = batch['text'] 54 | return batch 55 | 56 | def prepare_common_voice_dataset(batch): 57 | # check that all files have the correct sampling rate 58 | 59 | batch["input_values"] = re.search(r'(.*?)\.mp3', batch['path']).group(1)+'.wav' 60 | batch['labels'] = batch['sentence'] 61 | return batch 62 | 63 | def audio_preprocess(path): 64 | 65 | features, sr = sf.read(path) 66 | return processor(features, sampling_rate=16000).input_values.squeeze() 67 | 68 | 69 | 70 | @dataclass 71 | class DataCollatorCTCWithPadding: 72 | """ 73 | Data collator that will dynamically pad the inputs received. 74 | Args: 75 | processor (:class:`~transformers.Wav2Vec2Processor`) 76 | The processor used for proccessing the data. 77 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 78 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 79 | among: 80 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 81 | sequence if provided). 82 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 83 | maximum acceptable input length for the model if that argument is not provided. 84 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 85 | different lengths). 86 | max_length (:obj:`int`, `optional`): 87 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 88 | max_length_labels (:obj:`int`, `optional`): 89 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 90 | pad_to_multiple_of (:obj:`int`, `optional`): 91 | If set will pad the sequence to a multiple of the provided value. 92 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 93 | 7.5 (Volta). 
94 | """ 95 | 96 | processor: Wav2Vec2Processor 97 | padding: Union[bool, str] = True 98 | max_length: Optional[int] = None 99 | max_length_labels: Optional[int] = None 100 | pad_to_multiple_of: Optional[int] = None 101 | pad_to_multiple_of_labels: Optional[int] = None 102 | 103 | 104 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 105 | # split inputs and labels since they have to be of different lenghts and need 106 | # different padding methods 107 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 108 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["labels"]))} for feature in features] 109 | 110 | batch = self.processor.pad( 111 | input_features, 112 | padding=self.padding, 113 | max_length=self.max_length, 114 | pad_to_multiple_of=self.pad_to_multiple_of, 115 | return_attention_mask=True, 116 | return_tensors="pt", 117 | ) 118 | 119 | with self.processor.as_target_processor(): 120 | labels_batch = self.processor.pad( 121 | label_features, 122 | padding=self.padding, 123 | max_length=self.max_length_labels, 124 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 125 | return_attention_mask=True, 126 | return_tensors="pt", 127 | ) 128 | 129 | # replace padding with -100 to ignore loss correctly 130 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 131 | 132 | batch["labels"] = labels 133 | 134 | return batch 135 | 136 | wer_metric = load_metric("wer") 137 | 138 | def compute_metrics(pred): 139 | pred_logits = pred.predictions 140 | pred_ids = np.argmax(pred_logits, axis=-1) 141 | 142 | pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id 143 | 144 | pred_str = processor.batch_decode(pred_ids) 145 | # we do not want to group tokens when computing the metrics 146 | label_str = processor.batch_decode(pred.label_ids, group_tokens=False) 147 | 148 | wer = wer_metric.compute(predictions=pred_str, references=label_str) 149 | 150 | return {"wer": wer} 151 | 152 | 153 | 154 | 155 | def map_to_result(batch): 156 | model.to("cuda") 157 | input_values = processor( 158 | batch["speech"], 159 | sampling_rate=16000, 160 | return_tensors="pt" 161 | ).input_values.to("cuda") 162 | 163 | with torch.no_grad(): 164 | logits = model(input_values).logits 165 | 166 | pred_ids = torch.argmax(logits, dim=-1) 167 | batch["pred_str"] = processor.batch_decode(pred_ids)[0] 168 | 169 | return batch 170 | 171 | g2p = G2p() 172 | 173 | mapping_phone2id = json.load(open("vocab-ctc.json",'r')) 174 | mapping_id2phone = {v:k for k,v in mapping_phone2id.items()} 175 | 176 | 177 | tokenizer = Wav2Vec2CTCTokenizer("vocab-ctc.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="") 178 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 179 | 180 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 181 | 182 | 183 | 184 | if __name__ == "__main__": 185 | 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument('--train_data', type=str, default='/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/librispeech_train360') 188 | parser.add_argument('--val_data', default=None,type=str) 189 | parser.add_argument('--test_data',default=None,type=str) 190 | parser.add_argument('--out_dir',type=str,default="./models/wav2vec2-base-360") 191 | 192 | args = parser.parse_args() 193 | 194 | # 
loading data 195 | libris_train = load_from_disk(args.train_data) 196 | libris_train_prepared = libris_train.map(prepare_dataset, remove_columns=libris_train.column_names) 197 | 198 | common_voice = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/common_voice_filtered') 199 | common_voice = common_voice.map(prepare_common_voice_dataset, remove_columns=common_voice.column_names) 200 | 201 | libris_train_prepared = concatenate_datasets([libris_train_prepared, common_voice]) 202 | 203 | 204 | if args.val_data: 205 | libris_val = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/librispeech_val') 206 | libris_val_prepared = libris_train.map(prepare_dataset, remove_columns=libris_val.column_names, batch_size=8, batched=True) 207 | 208 | if args.test_data: 209 | libris_test = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/librispeech_test') 210 | libris_test_prepared = libris_train.map(prepare_dataset, remove_columns=libris_test.column_names, batch_size=8, batched=True) 211 | 212 | 213 | 214 | 215 | # data loader 216 | data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True) 217 | 218 | 219 | # load model 220 | model = Wav2Vec2ForCTC.from_pretrained( 221 | "facebook/wav2vec2-base", 222 | gradient_checkpointing=True, 223 | ctc_loss_reduction="mean", 224 | pad_token_id=processor.tokenizer.pad_token_id, 225 | vocab_size = len(processor.tokenizer) 226 | ) 227 | # model.wav2vec2.feature_extractor.conv_layers[6].conv.stride = (1,) 228 | # model.config.conv_stride[-1] = 1 229 | model.freeze_feature_extractor() 230 | 231 | 232 | # training settings 233 | training_args = TrainingArguments( 234 | output_dir=args.out_dir, 235 | group_by_length=True, 236 | per_device_train_batch_size=4, 237 | gradient_accumulation_steps=32, 238 | num_train_epochs=2, 239 | fp16=True, 240 | save_steps=500, 241 | logging_steps=500, 242 | learning_rate=3e-4, 243 | weight_decay=0.00005, 244 | warmup_steps=1000, 245 | save_total_limit=2, 246 | ) 247 | 248 | 249 | trainer = Trainer( 250 | model=model, 251 | data_collator=data_collator, 252 | args=training_args, 253 | compute_metrics=compute_metrics, 254 | train_dataset=libris_train_prepared, 255 | # eval_dataset=libris_val_prepared, 256 | tokenizer=processor.feature_extractor, 257 | ) 258 | 259 | 260 | trainer.train() 261 | 262 | ''' 263 | results = timit["test"].map(map_to_result) 264 | print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["target_text"]))) 265 | show_random_elements(results.remove_columns(["speech", "sampling_rate"])) 266 | model.to("cuda") 267 | input_values = processor(timit["test"][2]["speech"], sampling_rate=timit["test"][0]["sampling_rate"], return_tensors="pt").input_values.to("cuda") 268 | 269 | with torch.no_grad(): 270 | logits = model(input_values).logits 271 | 272 | pred_ids = torch.argmax(logits, dim=-1) 273 | 274 | print(timit["test"][2]["target_text"]) 275 | # convert ids to tokens 276 | " ".join(processor.tokenizer.convert_ids_to_tokens(pred_ids[0].tolist())) 277 | ''' 278 | 279 | 280 | 281 | 282 | 283 | 284 | -------------------------------------------------------------------------------- /experiments/train_attention_aligner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import re 5 | import os 6 | import argparse 7 | import transformers 8 | import soundfile as sf 9 | import librosa 10 | 11 | import numpy as np 12 | 
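# NOTE: this script trains Wav2Vec2ForAttentionAlignment (defined below), which
# puts a frame-to-phone cross-attention on top of wav2vec 2.0 and optimises the
# wav2vec 2.0 contrastive loss plus a ForwardSum alignment loss over the
# attention energies, weighted by --weight. A typical invocation (paths here are
# placeholders; the training-set path is currently hard-coded in __main__ via
# load_from_disk rather than read from --train_data) might look like:
#   python experiments/train_attention_aligner.py \
#       --out_dir ./models/attention_aligner --weight 1.0 --sampling_rate 16000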
import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from itertools import chain,groupby 18 | import json 19 | from collections import defaultdict 20 | from praatio import textgrid 21 | 22 | from dataclasses import dataclass, field 23 | from typing import Any, Dict, List, Optional, Union, Tuple 24 | from g2p_en import G2p 25 | from datasets import load_dataset, load_metric, load_from_disk 26 | from transformers import Wav2Vec2CTCTokenizer 27 | from transformers import Wav2Vec2FeatureExtractor 28 | from transformers import Trainer,TrainingArguments 29 | from transformers import Wav2Vec2Processor 30 | from transformers import Wav2Vec2Config, Wav2Vec2ForPreTraining 31 | from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices,Wav2Vec2ForPreTrainingOutput 32 | from transformers.file_utils import ModelOutput 33 | from transformers.modeling_outputs import CausalLMOutput, MaskedLMOutput 34 | from transformers import BertForMaskedLM, BertConfig 35 | from transformers import AdamW 36 | 37 | 38 | class ForwardSumLoss(torch.nn.Module): 39 | def __init__(self, blank_logprob=-1): 40 | super(ForwardSumLoss, self).__init__() 41 | self.log_softmax = torch.nn.LogSoftmax(dim=3) 42 | self.blank_logprob = blank_logprob 43 | # self.off_diag_penalty = off_diag_penalty 44 | self.CTCLoss = nn.CTCLoss(zero_infinity=True) 45 | 46 | def forward(self, attn_logprob, text_lens, mel_lens): 47 | """ 48 | Args: 49 | attn_logprob: batch x 1 x max(mel_lens) x max(text_lens) 50 | batched tensor of attention log 51 | probabilities, padded to length 52 | of longest sequence in each dimension 53 | text_lens: batch-D vector of length of 54 | each text sequence 55 | mel_lens: batch-D vector of length of 56 | each mel sequence 57 | """ 58 | # The CTC loss module assumes the existence of a blank token 59 | # that can be optionally inserted anywhere in the sequence for 60 | # a fixed probability. 61 | # A row must be added to the attention matrix to account for this 62 | attn_logprob_pd = F.pad(input=attn_logprob, 63 | pad=(1, 0, 0, 0, 0, 0, 0, 0), 64 | value=self.blank_logprob) 65 | cost_total = 0.0 66 | # for-loop over batch because of variable-length 67 | # sequences 68 | for bid in range(attn_logprob.shape[0]): 69 | # construct the target sequence. 
Every 70 | # text token is mapped to a unique sequence number, 71 | # thereby ensuring the monotonicity constraint 72 | target_seq = torch.arange(1, text_lens[bid]+1) 73 | target_seq=target_seq.unsqueeze(0) 74 | curr_logprob = attn_logprob_pd[bid].permute(1, 0, 2) 75 | curr_logprob = curr_logprob[:mel_lens[bid],:,:text_lens[bid]+1] 76 | 77 | # curr_logprob = curr_logprob + self.off_diagonal_loss(curr_logprob,text_lens[bid]+1,mel_lens[bid]) 78 | curr_logprob = self.log_softmax(curr_logprob[None])[0] 79 | cost = self.CTCLoss(curr_logprob, 80 | target_seq, 81 | input_lengths=mel_lens[bid:bid+1], 82 | target_lengths=text_lens[bid:bid+1]) 83 | cost_total += cost 84 | # average cost over batch 85 | cost_total = cost_total/attn_logprob.shape[0] 86 | return cost_total 87 | 88 | def off_diagonal_prior(self,log_prob,N, T, g=0.2): 89 | 90 | n = torch.arange(N).to(log_prob.device) 91 | t = torch.arange(T).to(log_prob.device) 92 | t = t.unsqueeze(1).repeat(1,N) 93 | n = n.unsqueeze(0).repeat(T,1) 94 | 95 | 96 | # W = 1 - torch.exp(-(n/N - t/T)**2/(2*g**2)) 97 | 98 | # penalty = log_prob*W.unsqueeze(1) 99 | # return torch.mean(penalty) 100 | W = torch.exp(-(n/N - t/T)**2/(2*g**2)) 101 | 102 | return torch.log_softmax(W.unsqueeze(1),dim=-1) 103 | 104 | 105 | class ConvBank(nn.Module): 106 | def __init__(self, input_dim, output_class_num, kernels, cnn_size, hidden_size, dropout, **kwargs): 107 | super(ConvBank, self).__init__() 108 | self.drop_p = dropout 109 | 110 | self.in_linear = nn.Linear(input_dim, hidden_size) 111 | latest_size = hidden_size 112 | 113 | # conv bank 114 | self.cnns = nn.ModuleList() 115 | assert len(kernels) > 0 116 | for kernel in kernels: 117 | self.cnns.append(nn.Conv1d(latest_size, cnn_size, kernel, padding=kernel//2)) 118 | latest_size = cnn_size * len(kernels) 119 | 120 | self.out_linear = nn.Linear(latest_size, output_class_num) 121 | 122 | def forward(self, features): 123 | hidden = F.dropout(F.relu(self.in_linear(features)), p=self.drop_p) 124 | 125 | conv_feats = [] 126 | hidden = hidden.transpose(1, 2).contiguous() 127 | for cnn in self.cnns: 128 | conv_feats.append(cnn(hidden)) 129 | hidden = torch.cat(conv_feats, dim=1).transpose(1, 2).contiguous() 130 | hidden = F.dropout(F.relu(hidden), p=self.drop_p) 131 | 132 | predicted = self.out_linear(hidden) 133 | return predicted 134 | 135 | class RNN(nn.Module): 136 | 137 | def __init__(self,hidden_dim,out_dim): 138 | super().__init__() 139 | 140 | self.lstm = nn.LSTM(hidden_dim,hidden_dim,bidirectional=True,num_layers=1,batch_first=True) 141 | self.linear = nn.Sequential(nn.Linear(2*hidden_dim,hidden_dim), 142 | nn.ReLU(), 143 | nn.Linear(hidden_dim,out_dim)) 144 | 145 | 146 | def forward(self, embeddings, lens): 147 | 148 | packed_input = pack_padded_sequence(embeddings, lens.cpu(), batch_first=True,enforce_sorted=False) 149 | packed_output, (ht, ct)= self.lstm(packed_input) 150 | out, _ = pad_packed_sequence(packed_output, batch_first=True) 151 | out = self.linear(out) 152 | return out 153 | 154 | 155 | class Wav2Vec2ForAttentionAlignment(Wav2Vec2ForPreTraining): 156 | 157 | def __init__(self,config): 158 | super().__init__(config) 159 | self.bert = BertForMaskedPhoneLM(config.bert_config) 160 | self.cnn = ConvBank(config.hidden_size,384,[1],384,384,0.1) 161 | #self.lm_head = nn.Linear(config.hidden_size,config.vocab_size) 162 | # self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) 163 | self.phone_rnn = RNN(384,config.vocab_size) 164 | 165 | self.attention = Attention(384) 166 | self.align_loss = 
ForwardSumLoss() 167 | 168 | def freeze_wav2vec2(self): 169 | for param in self.wav2vec2.parameters(): 170 | param.requires_grad = False 171 | 172 | def initialize_phone_model(self,path): 173 | 174 | self.bert = BertForMaskedPhoneLM.from_pretrained(path) 175 | 176 | 177 | def forward( 178 | self, 179 | input_values, 180 | attention_mask=None, 181 | output_attentions=None, 182 | output_hidden_states=None, 183 | mask_time_indices=None, 184 | return_dict=None, 185 | labels=None, 186 | labels_attention_mask=None, 187 | text_len=None, 188 | frame_len=None 189 | ): 190 | 191 | 192 | 193 | outputs = self.wav2vec2( 194 | input_values, 195 | attention_mask=attention_mask, 196 | output_attentions=output_attentions, 197 | output_hidden_states=output_hidden_states, 198 | mask_time_indices=mask_time_indices, 199 | return_dict=return_dict, 200 | ) 201 | 202 | # acoustic embeddings 203 | frame_hidden = outputs[0] 204 | # frame_hidden = self.dropout(frame_hidden) 205 | frame_hidden = self.cnn(frame_hidden) 206 | 207 | # phone embeddings 208 | phone_hidden = self.bert(input_ids=labels,attention_mask=labels_attention_mask).hidden_states[-1] 209 | 210 | # compute cross attention 211 | att_out,energy = self.attention(frame_hidden,phone_hidden,labels_attention_mask) 212 | 213 | 214 | # start masked modeling 215 | # 0. remove the blank symbol 216 | # 1. project all transformed features (including masked) to final vq dim 217 | transformer_features = self.project_hid(torch.tanh(att_out)) 218 | 219 | 220 | # 2. quantize all (unmasked) extracted features and project to final vq dim 221 | extract_features = self.dropout_features(outputs[1]) 222 | quantized_features, codevector_perplexity = self.quantizer(extract_features, mask_time_indices) 223 | quantized_features = self.project_q(quantized_features) 224 | 225 | 226 | # if attention_mask is passed, make sure that padded feature vectors cannot be sampled 227 | if attention_mask is not None: 228 | # compute reduced attention_mask correponding to feature vectors 229 | attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) 230 | 231 | # loss_fct = nn.CrossEntropyLoss() 232 | 233 | # phone_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 234 | 235 | 236 | loss = None 237 | if self.training: 238 | # for training, we sample negatives 239 | # 3. sample K negatives (distractors) quantized states for contrastive loss 240 | 241 | negative_quantized_features = self._sample_negatives( 242 | quantized_features, self.config.num_negatives, attention_mask=attention_mask 243 | ) 244 | 245 | # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` 246 | # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf 247 | logits = self.compute_contrastive_logits( 248 | quantized_features[None, :], 249 | negative_quantized_features, 250 | transformer_features, 251 | self.config.contrastive_logits_temperature, 252 | ) 253 | 254 | # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), 255 | # its cosine similarity will be masked 256 | neg_is_pos = (quantized_features == negative_quantized_features).all(-1) 257 | if neg_is_pos.any(): 258 | logits[1:][neg_is_pos] = float("-inf") 259 | 260 | # 6. 
compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = 261 | # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) 262 | preds = logits.transpose(0, 2).reshape(-1, logits.size(0)) 263 | target = ((1 - attention_mask.long()) * -100).transpose(0, 1).flatten() 264 | contrastive_loss = nn.functional.cross_entropy(preds.float(), target, reduction="mean") 265 | 266 | # 7. compute diversity loss: \mathbf{L}_d 267 | # num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups 268 | # diversity_loss = (num_codevectors - codevector_perplexity) / num_codevectors 269 | 270 | # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d 271 | expanded_labels_attention_mask = (1-labels_attention_mask)*-10000.0 272 | expanded_labels_attention_mask = expanded_labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 273 | att = torch.log_softmax(energy+expanded_labels_attention_mask,dim=-1) 274 | align_loss = self.align_loss(att.unsqueeze(1),text_len,frame_len) 275 | 276 | # expanded_attention_mask = attention_mask.unsqueeze(2).repeat(1,1,energy.size(2)) * labels_attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 277 | # expanded_attention_mask = (1-expanded_attention_mask)*-10000.0 278 | # phone_attention = torch.softmax((energy+expanded_attention_mask).transpose(2,1),dim=-1) 279 | # phone_emb = torch.bmm(phone_attention,frame_hidden) 280 | # prediction_scores = self.phone_rnn(phone_emb,text_len) 281 | # labels = labels.masked_fill(labels_attention_mask.ne(1), -100) 282 | # inter_phone = F.cosine_similarity(phone_emb[:,:-1,:],phone_emb[:,1:,:],dim=-1)*labels_attention_mask[:,1:] 283 | # interphone_loss = torch.sum(inter_phone)/torch.sum(labels_attention_mask[:,1:]) 284 | 285 | 286 | loss = contrastive_loss + WEIGHT*align_loss #+ interphone_loss 287 | 288 | 289 | return CausalLMOutput( 290 | loss=loss, logits=transformer_features, hidden_states=outputs.hidden_states, attentions=energy 291 | ) 292 | 293 | 294 | class BertForMaskedPhoneLM(BertForMaskedLM): 295 | 296 | def __init__(self,config): 297 | super().__init__(config) 298 | self.cnn = ConvBank(config.hidden_size,384,[1],384,384,0.1) 299 | 300 | def freeze_feature_extractor(self): 301 | for param in self.bert.parameters(): 302 | param.requires_grad = False 303 | 304 | def forward( 305 | self, 306 | input_ids=None, 307 | attention_mask=None, 308 | token_type_ids=None, 309 | position_ids=None, 310 | head_mask=None, 311 | inputs_embeds=None, 312 | encoder_hidden_states=None, 313 | encoder_attention_mask=None, 314 | labels=None, 315 | output_attentions=None, 316 | output_hidden_states=True, 317 | ): 318 | 319 | 320 | outputs = self.bert( 321 | input_ids, 322 | attention_mask=attention_mask, 323 | token_type_ids=token_type_ids, 324 | position_ids=position_ids, 325 | head_mask=head_mask, 326 | inputs_embeds=inputs_embeds, 327 | encoder_hidden_states=encoder_hidden_states, 328 | encoder_attention_mask=encoder_attention_mask, 329 | output_attentions=output_attentions, 330 | output_hidden_states=output_hidden_states 331 | ) 332 | 333 | 334 | prediction_scores = self.cnn(outputs.hidden_states[-1]) 335 | 336 | masked_lm_loss = None 337 | if labels is not None: 338 | loss_fct = CrossEntropyLoss() # -100 index = padding token 339 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) 340 | 341 | return MaskedLMOutput( 342 | loss=masked_lm_loss, 343 | logits=prediction_scores, 344 | hidden_states=outputs.hidden_states, 345 | attentions=outputs.attentions, 
346 | ) 347 | 348 | 349 | 350 | class Attention(nn.Module): 351 | 352 | def __init__(self,hidden_dim): 353 | super().__init__() 354 | self.q = nn.Linear(hidden_dim, hidden_dim) 355 | self.k = nn.Linear(hidden_dim, hidden_dim) 356 | # self.v = nn.Linear(hidden_dim*2, hidden_dim*2) 357 | self.layer_norm = nn.LayerNorm(hidden_dim) 358 | 359 | def forward(self,frame_hidden, phone_hidden,labels_attention_mask): 360 | 361 | frame_hidden = self.q(frame_hidden) 362 | phone_hidden = self.k(phone_hidden) 363 | 364 | energy = torch.bmm(frame_hidden,phone_hidden.transpose(2,1)) 365 | attention_mask = (1-labels_attention_mask)*-10000.0 366 | energy = energy+attention_mask.unsqueeze(1).repeat(1,energy.size(1),1) 367 | 368 | att_matrix = torch.softmax(energy,dim=-1) 369 | att_out = torch.bmm(att_matrix,phone_hidden) 370 | att_out = torch.cat([att_out,frame_hidden],dim=-1) 371 | # att_out = self.layer_norm(att_out + frame_hidden) 372 | 373 | return att_out, energy 374 | 375 | 376 | 377 | 378 | 379 | def get_phones(sen): 380 | ''' 381 | convert texts to phone sequence 382 | ''' 383 | sen = g2p(sen) 384 | sen = [re.sub(r'\d','',p) for p in sen] 385 | return sen 386 | 387 | def get_phone_ids(phones): 388 | ''' 389 | convert phone sequence to ids 390 | ''' 391 | ids = [] 392 | punctuation = set('.,!?') 393 | for p in phones: 394 | if re.match(r'^\w+?$',p): 395 | ids.append(mapping_phone2id.get(p,mapping_phone2id['[UNK]'])) 396 | elif p in punctuation: 397 | ids.append(mapping_phone2id.get('[SIL]')) 398 | ids = [0]+ids 399 | if ids[-1]!=0: 400 | ids.append(0) 401 | return ids #append silence token at the beginning 402 | 403 | def audio_preprocess(path): 404 | if SAMPLING_RATE==32000: 405 | features, _ = librosa.core.load(path,sr=32000) 406 | else: 407 | features, _ = sf.read(path) 408 | return processor(features, sampling_rate=16000,return_tensors='pt').input_values.squeeze() 409 | 410 | def seq2duration(phones,resolution=0.02): 411 | counter = 0 412 | out = [] 413 | for p,group in groupby(phones): 414 | length = len(list(group)) 415 | out.append((round(counter*resolution,2),round((counter+length)*resolution,2),p)) 416 | counter += length 417 | return out 418 | 419 | def prepare_common_voice_dataset(batch): 420 | # check that all files have the correct sampling rate 421 | 422 | batch["input_values"] = re.search(r'(.*?)\.mp3', batch['path']).group(1)+'.wav' 423 | batch['labels'] = batch['sentence'] 424 | return batch 425 | 426 | 427 | @dataclass 428 | class SpeechCollatorWithPadding: 429 | 430 | processor: Wav2Vec2Processor 431 | padding: Union[bool, str] = True 432 | return_attention_mask: Optional[bool] = True 433 | max_length: Optional[int] = None 434 | max_length_labels: Optional[int] = None 435 | pad_to_multiple_of: Optional[int] = None 436 | pad_to_multiple_of_labels: Optional[int] = None 437 | 438 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 439 | # split inputs and labels since they have to be of different lenghts and need 440 | # different padding methods 441 | 442 | # get phone features 443 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["labels"]))} for feature in features] 444 | text_len = [len(i['input_ids']) for i in label_features] 445 | 446 | with self.processor.as_target_processor(): 447 | labels_batch = self.processor.pad( 448 | label_features, 449 | padding=self.padding, 450 | max_length=self.max_length_labels, 451 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 452 | 
return_attention_mask=self.return_attention_mask, 453 | return_tensors="pt", 454 | ) 455 | 456 | 457 | # get speech features 458 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 459 | mel_len = [model._get_feat_extract_output_lengths(len(i['input_values'])) for i in input_features] 460 | batch = self.processor.pad( 461 | input_features, 462 | padding=self.padding, 463 | max_length=self.max_length, 464 | pad_to_multiple_of=self.pad_to_multiple_of, 465 | return_attention_mask=self.return_attention_mask, 466 | return_tensors="pt", 467 | ) 468 | batch_size, raw_sequence_length = batch['input_values'].shape 469 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 470 | 471 | mask_prob = torch.randint(size=(1,),low=1,high=40)/100 472 | batch['mask_time_indices'] = _compute_mask_indices((batch_size, sequence_length), mask_prob=mask_prob, mask_length=2,device='cpu') 473 | batch['frame_len'] = torch.tensor(mel_len) 474 | batch["text_len"] = torch.tensor(text_len) 475 | batch['labels'] = labels_batch["input_ids"]#.masked_fill(labels_batch.attention_mask.ne(1), -100) 476 | batch['labels_attention_mask'] = labels_batch['attention_mask'] 477 | 478 | return batch 479 | 480 | 481 | 482 | 483 | if __name__ == "__main__": 484 | 485 | parser = argparse.ArgumentParser() 486 | parser.add_argument('--train_data', type=str, default='/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/common_voice_filtered') 487 | parser.add_argument('--val_data', default=None,type=str) 488 | parser.add_argument('--test_data',default=None,type=str) 489 | parser.add_argument('--out_dir',type=str,default="./models/wav2vec2-base-cv-attention-align-10ms-1.0") 490 | parser.add_argument('--weight',type=float,default=1.0) 491 | parser.add_argument('--sampling_rate',type=float,default=16000) 492 | 493 | args = parser.parse_args() 494 | 495 | WEIGHT = args.weight 496 | SAMPLING_RATE = args.sampling_rate 497 | 498 | g2p = G2p() 499 | mapping_phone2id = json.load(open("vocab-ctc.json",'r')) 500 | mapping_id2phone = {v:k for k,v in mapping_phone2id.items()} 501 | 502 | tokenizer = Wav2Vec2CTCTokenizer("vocab-ctc.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="") 503 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 504 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 505 | 506 | 507 | config = Wav2Vec2Config.from_pretrained('facebook/wav2vec2-base') 508 | bert_config = BertConfig.from_pretrained('./models/bert-phones/checkpoint-36000') 509 | config.bert_config = bert_config 510 | config.pad_token_id = tokenizer.pad_token_id 511 | config.vocab_size = len(tokenizer) 512 | config.ctc_loss_reduction = 'mean' 513 | model = Wav2Vec2ForAttentionAlignment(config) 514 | #model.freeze_feature_extractor() 515 | # model.initialize_phone_model('../bert-phone') 516 | 517 | # weights = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base').state_dict() 518 | # del weights['lm_head.bias'] 519 | # del weights['lm_head.weight'] 520 | weights = torch.load('./models/neural_attention_aligner_forwardsum_10ms_true_quantizer.pt').state_dict() 521 | state_dict = model.state_dict() 522 | # weights = {k:v for k,v in weights.items() if k in state_dict.keys()} 523 | state_dict.update(weights) 524 | 525 | model.load_state_dict(state_dict) 526 | if SAMPLING_RATE != 32000: 527 | 
model.wav2vec2.feature_extractor.conv_layers[6].conv.stride = (1,) 528 | model.config.conv_stride[-1] = 1 529 | model.freeze_feature_extractor() 530 | #model.bert.freeze_feature_extractor() 531 | model.config.bert_config = None 532 | 533 | model.config.num_negatives = 50 534 | for param in model.quantizer.parameters(): 535 | param.requires_grad = False 536 | for param in model.project_q.parameters(): 537 | param.requires_grad = False 538 | 539 | 540 | ''' 541 | Load dataset 542 | ''' 543 | common_voice = load_from_disk('/gpfs/accounts/lingjzhu_root/lingjzhu1/lingjzhu/asr/common_voice_filtered') 544 | common_voice = common_voice.map(prepare_common_voice_dataset, remove_columns=common_voice.column_names) 545 | print(len(common_voice)) 546 | print(common_voice[0]) 547 | 548 | data_collator = SpeechCollatorWithPadding(processor=processor) 549 | 550 | 551 | 552 | # training settings 553 | training_args = TrainingArguments( 554 | output_dir=args.out_dir, 555 | group_by_length=True, 556 | per_device_train_batch_size=4, 557 | gradient_accumulation_steps=16, 558 | # evaluation_strategy="steps", 559 | num_train_epochs=1, 560 | fp16=False, 561 | save_steps=500, 562 | # eval_steps=1000, 563 | logging_steps=500, 564 | learning_rate=3e-4, 565 | weight_decay=0.0001, 566 | warmup_steps=500, 567 | save_total_limit=2, 568 | ignore_data_skip=True, 569 | ) 570 | 571 | 572 | trainer = Trainer( 573 | model=model, 574 | data_collator=data_collator, 575 | args=training_args, 576 | # compute_metrics=compute_metrics, 577 | train_dataset=common_voice, 578 | # eval_dataset=libris_train_prepared, 579 | tokenizer=processor.feature_extractor, 580 | ) 581 | 582 | 583 | trainer.train() 584 | -------------------------------------------------------------------------------- /experiments/train_frame_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import re 6 | import transformers 7 | import soundfile as sf 8 | import torch 9 | import json 10 | import numpy as np 11 | 12 | from dataclasses import dataclass, field 13 | from typing import Any, Dict, List, Optional, Union 14 | from g2p_en import G2p 15 | from datasets import load_dataset, load_metric, load_from_disk 16 | from transformers import Wav2Vec2CTCTokenizer 17 | from transformers import Wav2Vec2FeatureExtractor 18 | from transformers import Trainer,TrainingArguments 19 | from transformers import Wav2Vec2Processor 20 | from transformers import Wav2Vec2ForCTC, Wav2Vec2Config 21 | from transformers.modeling_outputs import CausalLMOutput, MaskedLMOutput 22 | 23 | 24 | def prepare_dataset_20ms(batch): 25 | 26 | batch["input_values"] = batch['file'] 27 | batch["labels"] = [mapping_phone2id[p] for p in batch['frame_labels']] 28 | assert len(batch['frame_labels']) == len(batch['labels']) 29 | return batch 30 | 31 | 32 | def prepare_dataset_10ms(batch): 33 | 34 | batch["input_values"] = batch['file'] 35 | # batch["labels"] = [mapping_phone2id[p] for p in batch['frame_labels_10ms']] 36 | batch["labels"] = [mapping_phone2id[p] for p in batch['labels']] 37 | assert len(batch['frame_labels_10ms']) == len(batch['labels']) 38 | return batch 39 | 40 | def prepare_test_dataset_10ms(batch): 41 | 42 | batch["input_values"] = batch['file'] 43 | batch["labels"] = [mapping_phone2id[p] for p in batch['frame_labels_10ms']] 44 | # batch["labels"] = [mapping_phone2id[p] for p in batch['labels']] 45 | assert len(batch['frame_labels_10ms']) == len(batch['labels']) 46 | return batch 
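# Illustrative sketch (an assumption, not code from this repository): the
# 'frame_labels' / 'frame_labels_10ms' columns used above are taken to hold one
# phone symbol per feature frame at a 20 ms or 10 ms frame shift. Such a
# sequence can be produced from (start, end, phone) intervals -- the inverse of
# seq2duration() in train_attention_aligner.py. The helper name duration2seq and
# its default resolution are made up for illustration.
def duration2seq(intervals, resolution=0.01):
    '''convert (start, end, phone) intervals into a per-frame phone sequence'''
    frames = []
    for start, end, phone in intervals:
        n_frames = int(round((end - start) / resolution))
        frames.extend([phone] * n_frames)
    # (a production version would track a cumulative frame counter instead of
    #  rounding each interval separately, to avoid accumulating drift)
    return frames
# e.g. duration2seq([(0.0, 0.1, '[SIL]'), (0.1, 0.25, 'SH')])
#      -> ['[SIL]'] * 10 + ['SH'] * 15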
47 | 48 | def prepare_dataset_cv(batch): 49 | 50 | batch["input_values"] = batch['path'].replace('.mp3','.wav') 51 | batch["labels"] = [mapping_phone2id[p] for p in batch['labels']] 52 | assert len(batch['labels']) == len(batch['labels']) 53 | return batch 54 | 55 | 56 | def audio_preprocess(path): 57 | 58 | features,sr = sf.read(path) 59 | assert sr == 16000 60 | return processor(features, sampling_rate=16000).input_values.squeeze() 61 | 62 | 63 | 64 | @dataclass 65 | class DataCollatorClassificationWithPadding: 66 | """ 67 | Data collator that will dynamically pad the inputs received. 68 | Args: 69 | processor (:class:`~transformers.Wav2Vec2Processor`) 70 | The processor used for proccessing the data. 71 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 72 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 73 | among: 74 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 75 | sequence if provided). 76 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 77 | maximum acceptable input length for the model if that argument is not provided. 78 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 79 | different lengths). 80 | max_length (:obj:`int`, `optional`): 81 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 82 | max_length_labels (:obj:`int`, `optional`): 83 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 84 | pad_to_multiple_of (:obj:`int`, `optional`): 85 | If set will pad the sequence to a multiple of the provided value. 86 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 87 | 7.5 (Volta). 
88 | """ 89 | 90 | processor: Wav2Vec2Processor 91 | padding: Union[bool, str] = True 92 | return_attention_mask: Optional[bool] = True 93 | max_length: Optional[int] = None 94 | max_length_labels: Optional[int] = None 95 | pad_to_multiple_of: Optional[int] = None 96 | pad_to_multiple_of_labels: Optional[int] = None 97 | 98 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 99 | # split inputs and labels since they have to be of different lenghts and need 100 | # different padding methods 101 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 102 | label_features = [{"input_ids": feature["labels"]} for feature in features] 103 | 104 | batch = self.processor.pad( 105 | input_features, 106 | padding=self.padding, 107 | max_length=self.max_length, 108 | pad_to_multiple_of=self.pad_to_multiple_of, 109 | return_attention_mask=self.return_attention_mask, 110 | return_tensors="pt", 111 | ) 112 | with self.processor.as_target_processor(): 113 | labels_batch = self.processor.pad( 114 | label_features, 115 | padding=self.padding, 116 | max_length=self.max_length_labels, 117 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 118 | return_tensors="pt", 119 | ) 120 | 121 | # replace padding with -100 to ignore loss correctly 122 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 123 | 124 | batch["labels"] = labels 125 | 126 | return batch 127 | 128 | class Wav2Vec2ForFrameClassification(Wav2Vec2ForCTC): 129 | 130 | def forward( 131 | self, 132 | input_values, 133 | attention_mask=None, 134 | output_attentions=None, 135 | output_hidden_states=None, 136 | return_dict=None, 137 | labels=None, 138 | ): 139 | 140 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 141 | 142 | outputs = self.wav2vec2( 143 | input_values, 144 | attention_mask=attention_mask, 145 | output_attentions=output_attentions, 146 | output_hidden_states=output_hidden_states, 147 | return_dict=return_dict, 148 | ) 149 | 150 | hidden_states = outputs[0] 151 | hidden_states = self.dropout(hidden_states) 152 | 153 | logits = self.lm_head(hidden_states) 154 | 155 | loss = None 156 | if labels is not None: 157 | 158 | if labels.max() >= self.config.vocab_size: 159 | raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") 160 | 161 | # retrieve loss input_lengths from attention_mask 162 | attention_mask = ( 163 | attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) 164 | ) 165 | input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) 166 | 167 | loss = torch.nn.functional.cross_entropy(logits.view(-1,logits.size(2)), labels.flatten(), reduction="mean") 168 | 169 | 170 | 171 | if not return_dict: 172 | output = (logits,) + outputs[2:] 173 | return ((loss,) + output) if loss is not None else output 174 | 175 | return CausalLMOutput( 176 | loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions 177 | ) 178 | 179 | 180 | def compute_metrics(pred): 181 | pred_logits = pred.predictions 182 | pred_ids = np.argmax(pred_logits, axis=-1) 183 | 184 | # pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id 185 | 186 | comparison = (pred_ids == pred.label_ids) 187 | comparison = comparison[pred.label_ids != -100].flatten() 188 | acc = np.sum(comparison)/len(comparison) 189 | 190 | return {"phone_accuracy": 
acc} 191 | 192 | 193 | 194 | 195 | 196 | 197 | tokenizer = Wav2Vec2CTCTokenizer("./dict/vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="") 198 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 199 | processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 200 | 201 | mapping_phone2id = json.load(open("./dict/vocab.json",'r')) 202 | mapping_id2phone = {v:k for k,v in mapping_phone2id.items()} 203 | 204 | 205 | 206 | if __name__ == "__main__": 207 | 208 | frameshift = 10 209 | 210 | if frameshift == 10: 211 | prepare_dataset = prepare_dataset_10ms 212 | # prepare_dataset = prepare_dataset_cv 213 | else: 214 | prepare_dataset = prepare_dataset_20ms 215 | 216 | # loading data 217 | # libris = load_from_disk('/shared/2/datasets/speech/librispeech_asr/librispeech_full') 218 | # libris = load_from_disk('/shared/2/datasets/speech/common_voice/common_voice_align') 219 | libris = load_from_disk('/shared/2/datasets/speech/librispeech_asr/librispeech_360_align') 220 | libris_train_prepared = libris.map(prepare_dataset,batched=False) 221 | 222 | 223 | # libris_train_prepared = libris_prepared.filter(lambda x: bool(re.search('train-clean-360',x['file']))) 224 | 225 | libris_val = load_from_disk('/shared/2/datasets/speech/librispeech_asr/librispeech_full_test') 226 | libris_val = libris_val.select([i for i in range(200)]) 227 | libris_val_prepared = libris_val.map(prepare_test_dataset_10ms,batched=False) 228 | 229 | 230 | # libris_test_prepared = libris_prepared.filter(lambda x: bool(re.search('test-clean',x['file']))) 231 | 232 | 233 | 234 | 235 | # data loader 236 | data_collator = DataCollatorClassificationWithPadding(processor=processor, padding=True) 237 | 238 | mode = 'base' 239 | 240 | # load model 241 | if mode == 'tiny': 242 | config = Wav2Vec2Config() 243 | config.num_attention_heads = 6 244 | config.hidden_size = 384 245 | config.num_hidden_layers = 6 246 | config.vocab_size = len(processor.tokenizer) 247 | model = Wav2Vec2ForFrameClassification(config) 248 | 249 | # load pretrained weights 250 | 251 | pretrained_model = Wav2Vec2ForFrameClassification.from_pretrained( 252 | "facebook/wav2vec2-base", 253 | gradient_checkpointing=True, 254 | pad_token_id=processor.tokenizer.pad_token_id, 255 | vocab_size = len(processor.tokenizer) 256 | ) 257 | 258 | layers = {'wav2vec2.feature_extractor.conv_layers.0.conv.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.weight', 'wav2vec2.feature_extractor.conv_layers.0.layer_norm.bias', 'wav2vec2.feature_extractor.conv_layers.1.conv.weight', 'wav2vec2.feature_extractor.conv_layers.2.conv.weight', 'wav2vec2.feature_extractor.conv_layers.3.conv.weight', 'wav2vec2.feature_extractor.conv_layers.4.conv.weight', 'wav2vec2.feature_extractor.conv_layers.5.conv.weight', 'wav2vec2.feature_extractor.conv_layers.6.conv.weight','quantizer.codevectors', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_q.bias'} 259 | pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in layers} 260 | del pretrained_model 261 | 262 | # update pretrained weights 263 | state_dict = model.state_dict() 264 | state_dict.update(pretrained_dict) 265 | model.load_state_dict(state_dict) 266 | 267 | elif mode == 'base': 268 | model = Wav2Vec2ForFrameClassification.from_pretrained( 269 | "facebook/wav2vec2-base", 270 | gradient_checkpointing=True, 271 | 
pad_token_id=processor.tokenizer.pad_token_id, 272 | vocab_size = len(processor.tokenizer) 273 | ) 274 | 275 | 276 | # freeze convolutional layers and set the stride of the last conv layer to 1 277 | # this increase the sampling frequency to 98 Hz 278 | model.wav2vec2.feature_extractor.conv_layers[6].conv.stride = (1,) 279 | model.config.conv_stride[-1] = 1 280 | model.freeze_feature_extractor() 281 | 282 | 283 | # training settings 284 | training_args = TrainingArguments( 285 | output_dir="/shared/2/projects/phone_segmentation/models/wav2vec2-base-FC-10ms-libris-iter3", 286 | group_by_length=True, 287 | per_device_train_batch_size=2, 288 | per_device_eval_batch_size=4, 289 | gradient_accumulation_steps=8, 290 | evaluation_strategy="steps", 291 | num_train_epochs=2, 292 | fp16=True, 293 | save_steps=500, 294 | eval_steps=500, 295 | logging_steps=500, 296 | learning_rate=3e-4, 297 | weight_decay=0.0001, 298 | warmup_steps=1000, 299 | save_total_limit=2, 300 | ) 301 | 302 | 303 | trainer = Trainer( 304 | model=model, 305 | data_collator=data_collator, 306 | args=training_args, 307 | compute_metrics=compute_metrics, 308 | train_dataset=libris_train_prepared, 309 | eval_dataset=libris_val_prepared, 310 | tokenizer=processor.feature_extractor, 311 | ) 312 | 313 | 314 | trainer.train() 315 | 316 | 317 | -------------------------------------------------------------------------------- /local/SA1.TXT: -------------------------------------------------------------------------------- 1 | 0 55911 She had your dark suit in greasy wash water all year. 2 | -------------------------------------------------------------------------------- /local/SA1.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | 0 5 | 3.47 6 | 7 | 2 8 | "IntervalTier" 9 | "phones" 10 | 0 11 | 3.47 12 | 35 13 | 0 14 | 0.1 15 | "[SIL]" 16 | 0.1 17 | 0.25 18 | "SH" 19 | 0.25 20 | 0.33 21 | "IY" 22 | 0.33 23 | 0.41 24 | "HH" 25 | 0.41 26 | 0.54 27 | "AE" 28 | 0.54 29 | 0.6 30 | "D" 31 | 0.6 32 | 0.67 33 | "Y" 34 | 0.67 35 | 0.75 36 | "R" 37 | 0.75 38 | 0.86 39 | "D" 40 | 0.86 41 | 0.96 42 | "AA" 43 | 0.96 44 | 1.01 45 | "R" 46 | 1.01 47 | 1.1 48 | "K" 49 | 1.1 50 | 1.24 51 | "S" 52 | 1.24 53 | 1.35 54 | "UW" 55 | 1.35 56 | 1.38 57 | "T" 58 | 1.38 59 | 1.49 60 | "IH" 61 | 1.49 62 | 1.58 63 | "N" 64 | 1.58 65 | 1.66 66 | "G" 67 | 1.66 68 | 1.73 69 | "R" 70 | 1.73 71 | 1.83 72 | "IY" 73 | 1.83 74 | 1.93 75 | "S" 76 | 1.93 77 | 2.01 78 | "IY" 79 | 2.01 80 | 2.12 81 | "W" 82 | 2.12 83 | 2.27 84 | "AA" 85 | 2.27 86 | 2.38 87 | "SH" 88 | 2.38 89 | 2.43 90 | "W" 91 | 2.43 92 | 2.56 93 | "AO" 94 | 2.56 95 | 2.64 96 | "T" 97 | 2.64 98 | 2.76 99 | "ER" 100 | 2.76 101 | 2.93 102 | "AO" 103 | 2.93 104 | 2.99 105 | "L" 106 | 2.99 107 | 3.11 108 | "Y" 109 | 3.11 110 | 3.22 111 | "IH" 112 | 3.22 113 | 3.43 114 | "R" 115 | 3.43 116 | 3.47 117 | "[SIL]" 118 | "IntervalTier" 119 | "words" 120 | 0 121 | 3.47 122 | 13 123 | 0 124 | 0.1 125 | "[SIL]" 126 | 0.1 127 | 0.33 128 | "she" 129 | 0.33 130 | 0.6 131 | "had" 132 | 0.6 133 | 0.75 134 | "your" 135 | 0.75 136 | 1.1 137 | "dark" 138 | 1.1 139 | 1.38 140 | "suit" 141 | 1.38 142 | 1.58 143 | "in" 144 | 1.58 145 | 2.01 146 | "greasy" 147 | 2.01 148 | 2.38 149 | "wash" 150 | 2.38 151 | 2.76 152 | "water" 153 | 2.76 154 | 2.99 155 | "all" 156 | 2.99 157 | 3.43 158 | "year" 159 | 3.43 160 | 3.47 161 | "[SIL]" 162 | -------------------------------------------------------------------------------- /local/SA1.WAV: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingjzhu/charsiu/13a69f2a22ca0c0962b75cc693399b0ae23a12c9/local/SA1.WAV -------------------------------------------------------------------------------- /local/SSB00050015.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | 0 5 | 5.42 6 | 7 | 2 8 | "IntervalTier" 9 | "phones" 10 | 0 11 | 5.42 12 | 31 13 | 0 14 | 0.5 15 | "[SIL]" 16 | 0.5 17 | 0.59 18 | "j" 19 | 0.59 20 | 0.8 21 | "ing1" 22 | 0.8 23 | 0.89 24 | "g" 25 | 0.89 26 | 1.07 27 | "uang3" 28 | 1.07 29 | 1.19 30 | "zh" 31 | 1.19 32 | 1.36 33 | "ou1" 34 | 1.36 35 | 1.54 36 | "r" 37 | 1.54 38 | 1.69 39 | "iii4" 40 | 1.69 41 | 1.83 42 | "b" 43 | 1.83 44 | 2.01 45 | "ao4" 46 | 2.01 47 | 2.17 48 | "b" 49 | 2.17 50 | 2.32 51 | "ao4" 52 | 2.32 53 | 2.43 54 | "d" 55 | 2.43 56 | 2.64 57 | "ao4" 58 | 2.64 59 | 2.81 60 | "h" 61 | 2.81 62 | 2.96 63 | "ou4" 64 | 2.96 65 | 3.13 66 | "ch" 67 | 3.13 68 | 3.24 69 | "eng2" 70 | 3.24 71 | 3.54 72 | "ui2" 73 | 3.54 74 | 3.68 75 | "l" 76 | 3.68 77 | 3.81 78 | "e5" 79 | 3.81 80 | 3.99 81 | "sh" 82 | 3.99 83 | 4.14 84 | "e4" 85 | 4.14 86 | 4.26 87 | "h" 88 | 4.26 89 | 4.43 90 | "ui4" 91 | 4.43 92 | 4.55 93 | "r" 94 | 4.55 95 | 4.63 96 | "e4" 97 | 4.63 98 | 4.73 99 | "d" 100 | 4.73 101 | 4.93 102 | "ian3" 103 | 4.93 104 | 5.42 105 | "[SIL]" 106 | "IntervalTier" 107 | "words" 108 | 0 109 | 5.42 110 | 17 111 | 0 112 | 0.5 113 | "[SIL]" 114 | 0.5 115 | 0.8 116 | "经" 117 | 0.8 118 | 1.07 119 | "广" 120 | 1.07 121 | 1.36 122 | "州" 123 | 1.36 124 | 1.69 125 | "日" 126 | 1.69 127 | 2.01 128 | "报" 129 | 2.01 130 | 2.32 131 | "报" 132 | 2.32 133 | 2.64 134 | "道" 135 | 2.64 136 | 2.96 137 | "后" 138 | 2.96 139 | 3.24 140 | "成" 141 | 3.24 142 | 3.54 143 | "为" 144 | 3.54 145 | 3.81 146 | "了" 147 | 3.81 148 | 4.14 149 | "社" 150 | 4.14 151 | 4.43 152 | "会" 153 | 4.43 154 | 4.63 155 | "热" 156 | 4.63 157 | 4.93 158 | "点" 159 | 4.93 160 | 5.42 161 | "[SIL]" 162 | -------------------------------------------------------------------------------- /local/SSB00050015_16k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingjzhu/charsiu/13a69f2a22ca0c0962b75cc693399b0ae23a12c9/local/SSB00050015_16k.wav -------------------------------------------------------------------------------- /local/SSB16240001.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | 0 5 | 1.71 6 | 7 | 1 8 | "IntervalTier" 9 | "phones" 10 | 0 11 | 1.71 12 | 12 13 | 0 14 | 0.28 15 | "[SIL]" 16 | 0.28 17 | 0.36 18 | "zh" 19 | 0.36 20 | 0.47 21 | "e4" 22 | 0.47 23 | 0.56 24 | "j" 25 | 0.56 26 | 0.64 27 | "ve2" 28 | 0.64 29 | 0.72 30 | "d" 31 | 0.72 32 | 0.8 33 | "ui4" 34 | 0.8 35 | 0.9 36 | "b" 37 | 0.9 38 | 0.98 39 | "u4" 40 | 0.98 41 | 1.11 42 | "x" 43 | 1.11 44 | 1.27 45 | "ing2" 46 | 1.27 47 | 1.71 48 | "[SIL]" 49 | -------------------------------------------------------------------------------- /local/SSB16240001_16k.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingjzhu/charsiu/13a69f2a22ca0c0962b75cc693399b0ae23a12c9/local/SSB16240001_16k.wav -------------------------------------------------------------------------------- /local/good_ali.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingjzhu/charsiu/13a69f2a22ca0c0962b75cc693399b0ae23a12c9/local/good_ali.pdf -------------------------------------------------------------------------------- /local/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lingjzhu/charsiu/13a69f2a22ca0c0962b75cc693399b0ae23a12c9/local/image.png -------------------------------------------------------------------------------- /local/sample.TextGrid: -------------------------------------------------------------------------------- 1 | File type = "ooTextFile" 2 | Object class = "TextGrid" 3 | 4 | 0 5 | 6.32 6 | 7 | 2 8 | "IntervalTier" 9 | "phones" 10 | 0 11 | 6.32 12 | 39 13 | 0 14 | 0.52 15 | "[SIL]" 16 | 0.52 17 | 0.59 18 | "e4" 19 | 0.59 20 | 0.68 21 | "rr" 22 | 0.68 23 | 0.78 24 | "x" 25 | 0.78 26 | 0.89 27 | "ian4" 28 | 0.89 29 | 0.98 30 | "ch" 31 | 0.98 32 | 1.04 33 | "eng2" 34 | 1.04 35 | 1.14 36 | "sh" 37 | 1.14 38 | 1.23 39 | "iii4" 40 | 1.23 41 | 1.35 42 | "g" 43 | 1.35 44 | 1.39 45 | "ou4" 46 | 1.39 47 | 1.51 48 | "r" 49 | 1.51 50 | 1.6 51 | "u4" 52 | 1.6 53 | 1.7 54 | "t" 55 | 1.7 56 | 1.76 57 | "u3" 58 | 1.76 59 | 1.87 60 | "d" 61 | 1.87 62 | 2.17 63 | "i4" 64 | 2.17 65 | 2.82 66 | "[SIL]" 67 | 2.82 68 | 2.92 69 | "e4" 70 | 2.92 71 | 3.05 72 | "rr" 73 | 3.05 74 | 3.18 75 | "q" 76 | 3.18 77 | 3.48 78 | "ian1" 79 | 3.48 80 | 3.68 81 | "[SIL]" 82 | 3.68 83 | 3.84 84 | "j" 85 | 3.84 86 | 3.91 87 | "iou3" 88 | 3.91 89 | 4.03 90 | "b" 91 | 4.03 92 | 4.31 93 | "ai3" 94 | 4.31 95 | 4.44 96 | "[SIL]" 97 | 4.44 98 | 4.58 99 | "u3" 100 | 4.58 101 | 4.67 102 | "sh" 103 | 4.67 104 | 4.79 105 | "iii2" 106 | 4.79 107 | 4.9 108 | "j" 109 | 4.9 110 | 5.01 111 | "iou3" 112 | 5.01 113 | 5.28 114 | "uan4" 115 | 5.28 116 | 5.39 117 | "p" 118 | 5.39 119 | 5.6 120 | "ing2" 121 | 5.6 122 | 5.61 123 | "m" 124 | 5.61 125 | 5.62 126 | "i3" 127 | 5.62 128 | 6.32 129 | "[SIL]" 130 | "IntervalTier" 131 | "words" 132 | 0 133 | 6.32 134 | 23 135 | 0 136 | 0.52 137 | "[SIL]" 138 | 0.52 139 | 0.68 140 | "二" 141 | 0.68 142 | 0.89 143 | "线" 144 | 0.89 145 | 1.04 146 | "城" 147 | 1.04 148 | 1.23 149 | "市" 150 | 1.23 151 | 1.39 152 | "购" 153 | 1.39 154 | 1.6 155 | "入" 156 | 1.6 157 | 1.76 158 | "土" 159 | 1.76 160 | 2.17 161 | "地" 162 | 2.17 163 | 2.82 164 | "[SIL]" 165 | 2.82 166 | 3.05 167 | "二" 168 | 3.05 169 | 3.48 170 | "千" 171 | 3.48 172 | 3.68 173 | "[SIL]" 174 | 3.68 175 | 3.91 176 | "九" 177 | 3.91 178 | 4.31 179 | "百" 180 | 4.31 181 | 4.44 182 | "[SIL]" 183 | 4.44 184 | 4.58 185 | "五" 186 | 4.58 187 | 4.79 188 | "十" 189 | 4.79 190 | 5.01 191 | "九" 192 | 5.01 193 | 5.28 194 | "万" 195 | 5.28 196 | 5.6 197 | "平" 198 | 5.6 199 | 5.62 200 | "米" 201 | 5.62 202 | 6.32 203 | "[SIL]" 204 | -------------------------------------------------------------------------------- /local/vocab-ctc.json: -------------------------------------------------------------------------------- 1 | {"[SIL]": 0, "NG": 1, "F": 2, "M": 3, "AE": 4, "R": 5, "UW": 6, "N": 7, "IY": 8, "AW": 9, "V": 10, "UH": 11, "OW": 12, "AA": 13, "ER": 14, "HH": 15, "Z": 16, "K": 17, "CH": 18, "W": 19, "EY": 20, "ZH": 21, "T": 22, "EH": 23, "Y": 24, "AH": 25, "B": 26, "P": 27, "TH": 28, "DH": 29, "AO": 30, "G": 31, "L": 32, "JH": 33, "OY": 34, "SH": 35, "D": 36, "AY": 37, "S": 38, "IH": 39, "[UNK]": 40, "[PAD]": 41} -------------------------------------------------------------------------------- /local/vocab.json: 
-------------------------------------------------------------------------------- 1 | {"[SIL]": 0, "NG": 1, "F": 2, "M": 3, "AE": 4, "R": 5, "UW": 6, "N": 7, "IY": 8, "AW": 9, "V": 10, "UH": 11, "OW": 12, "AA": 13, "ER": 14, "HH": 15, "Z": 16, "K": 17, "CH": 18, "W": 19, "EY": 20, "ZH": 21, "T": 22, "EH": 23, "Y": 24, "AH": 25, "B": 26, "P": 27, "TH": 28, "DH": 29, "AO": 30, "G": 31, "L": 32, "JH": 33, "OY": 34, "SH": 35, "D": 36, "AY": 37, "S": 38, "IH": 39, "[UNK]": 40, "[PAD]": 41} -------------------------------------------------------------------------------- /misc/data.md: -------------------------------------------------------------------------------- 1 | # Force-aligned Speech Datasets 2 | 3 | Here we release **TextGrids** for several datasets that have been force-aligned with [Charsiu Forced Aligner](https://github.com/lingjzhu/charsiu). We hope they will be helpful for your research. 4 | Forced alignment does not generate perfect alignments. **Use at your own discretion**. 5 | 6 | [English](data.md#alignments-for-english-datasets) 7 | [Mandarin](data.md#alignments-for-mandarin-speech-datasets) 8 | 9 | 10 | Please cite this paper if you use these alignments in your research projects. 11 | ``` 12 | @article{zhu2022charsiu, 13 | title={Phone-to-audio alignment without text: A Semi-supervised Approach}, 14 | author={Zhu, Jian and Zhang, Cong and Jurgens, David}, 15 | journal={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 16 | year={2022}, 17 | url={https://arxiv.org/abs/2110.03876}, 18 | } 19 | ``` 20 | 21 | ## Alignments for English datasets 22 | 23 | ### Textgrids 24 | You can find [all textgrids for the training set (~860k utterances) in this Google Drive folder](https://drive.google.com/drive/folders/1IF0WB5-8VXfaENtE4r5rehHHK8YFe61S?usp=sharing). It contains phone- and word-level alignments for the English subset of Common Voice (~2000 hours). The alignment was produced with `charsiu_forced_aligner` using the model `charsiu/en_w2v2_fc_10ms`. 25 | All filenames are matched. Only a few mismatched samples were discarded. 26 | 27 | ### Audio 28 | You can find the dataset at the [Common Voice Project](https://commonvoice.mozilla.org/en/datasets). 29 | 30 | The audio data can also be easily accessed through [the Common Voice page](https://huggingface.co/datasets/mozilla-foundation/common_voice_8_0) on the HuggingFace Hub. An account is needed for authentication. 31 | ``` 32 | from datasets import load_dataset 33 | 34 | dataset = load_dataset("mozilla-foundation/common_voice_8_0", "en", split='train', use_auth_token=True) 35 | ``` 36 | Note that ~80GB of RAM is needed to load the full dataset into memory. 37 | 38 | Please cite Common Voice if you use this dataset. 39 | ``` 40 | @inproceedings{commonvoice:2020, 41 | author = {Ardila, R. and Branson, M. and Davis, K. and Henretty, M. and Kohler, M. and Meyer, J. and Morais, R. and Saunders, L. and Tyers, F. M. and Weber, G.}, 42 | title = {Common Voice: A Massively-Multilingual Speech Corpus}, 43 | booktitle = {Proceedings of the 12th Conference on Language Resources and Evaluation (LREC 2020)}, 44 | pages = {4211--4215}, 45 | year = 2020 46 | } 47 | ``` 48 | 49 | The grapheme-to-phoneme conversion was done automatically with [`g2p_en`](https://github.com/Kyubyong/g2p). A minimal usage sketch is shown below.
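For reference, here is a minimal sketch of that grapheme-to-phoneme step in isolation, together with a `charsiu_forced_aligner` call that produces a TextGrid like the ones released here. This is an illustration rather than the exact batch pipeline used for the release: the file paths are placeholders, and it assumes the `src/` directory of this repository is on the Python path.
```
import sys
sys.path.append('src')

from g2p_en import G2p
from Charsiu import charsiu_forced_aligner

# grapheme-to-phoneme conversion: ARPAbet phones with stress digits,
# with a ' ' token between words
g2p = G2p()
print(g2p("She had your dark suit."))
# e.g. ['SH', 'IY1', ' ', 'HH', 'AE1', 'D', ' ', 'Y', 'AO1', 'R', ...]

# forced alignment with the same model used for this release;
# charsiu_forced_aligner runs g2p_en internally, so plain text is enough
charsiu = charsiu_forced_aligner(aligner='charsiu/en_w2v2_fc_10ms')
charsiu.serve(audio='example.wav',              # placeholder path to a 16 kHz wav file
              text='She had your dark suit.',
              save_to='example.TextGrid')       # phone and word tiers are written here
```
Please cite `g2p_en` if you use the phone labels.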
50 | ``` 51 | @misc{g2pE2019, 52 | author = {Park, Kyubyong and Kim, Jongseok}, 53 | title = {g2pE}, 54 | year = {2019}, 55 | publisher = {GitHub}, 56 | journal = {GitHub repository}, 57 | howpublished = {\url{https://github.com/Kyubyong/g2p}} 58 | } 59 | ``` 60 | 61 | ## Alignments for Mandarin speech datasets 62 | 63 | This repository contains phone- and word-level alignments for multiple Mandarin Chinese speech datasets, including MagicData (~755 hours), Aishell-1 (~150 hours), STCMDS (~100 hours), Datatang (~200 hours), THCHS-30 (~30 hours) and PrimeWords (~100 hours). 64 | 65 | ### Textgrids 66 | You can download all textgrids [here](https://drive.google.com/drive/folders/1IF0WB5-8VXfaENtE4r5rehHHK8YFe61S?usp=sharing). The forced alignment was done with `charsiu_forced_aligner` using the model `charsiu/zh_xlsr_fc_10ms`. Only Praat textgrid files are distributed. Sentences containing English letters or numbers were removed. Misaligned files were also discarded. 67 | 68 | The grapheme-to-phoneme conversion was done automatically with [`g2pM`](https://github.com/kakaobrain/g2pM). 69 | ``` 70 | @article{park2020g2pm, 71 | author={Park, Kyubyong and Lee, Seanie}, 72 | title = {A Neural Grapheme-to-Phoneme Conversion Package for Mandarin Chinese Based on a New Open Benchmark Dataset 73 | }, 74 | journal={Proc. Interspeech 2020}, 75 | url = {https://arxiv.org/abs/2004.03136}, 76 | year = {2020} 77 | } 78 | ``` 79 | 80 | 81 | 82 | 83 | ### Audio data 84 | The original audio data can be downloaded via OpenSLR. All filenames are matched. Please also cite the original datasets. 85 | 86 | Aishell-1 (~150 hours): https://openslr.org/33/ 87 | ``` 88 | @inproceedings{aishell_2017, 89 | title={AIShell-1: An Open-Source Mandarin Speech Corpus and A Speech Recognition Baseline}, 90 | author={Hui Bu and Jiayu Du and Xingyu Na and Bengu Wu and Hao Zheng}, 91 | booktitle={Oriental COCOSDA 2017}, 92 | pages={Submitted}, 93 | year={2017} 94 | } 95 | ``` 96 | 97 | MagicData (~755 hours): https://openslr.org/68/ 98 | ``` 99 | Please cite the corpus as "Magic Data Technology Co., Ltd., "http://www.imagicdatatech.com/index.php/home/dataopensource/data_info/id/101", 05/2019". 100 | ``` 101 | 102 | Datatang (~200 hours): https://openslr.org/62/ 103 | ``` 104 | Please cite the corpus as “aidatatang_200zh, a free Chinese Mandarin speech corpus by Beijing DataTang Technology Co., Ltd ( www.datatang.com )”. 105 | ``` 106 | STCMDS (~100 hours): https://openslr.org/38/ 107 | ``` 108 | Please cite the data as “ST-CMDS-20170001_1, Free ST Chinese Mandarin Corpus”.
109 | ``` 110 | PrimeWords (~100 hours): https://openslr.org/47/ 111 | ``` 112 | @misc{primewords_201801, 113 | title={Primewords Chinese Corpus Set 1}, 114 | author={Primewords Information Technology Co., Ltd.}, 115 | year={2018}, 116 | note={\url{https://www.primewords.cn}} 117 | } 118 | 119 | ``` 120 | THCHS-30 (~30 hours): https://openslr.org/18/ 121 | ``` 122 | @misc{THCHS30_2015, 123 | title={THCHS-30 : A Free Chinese Speech Corpus}, 124 | author={Dong Wang, Xuewei Zhang, Zhiyong Zhang}, 125 | year={2015}, 126 | url={http://arxiv.org/abs/1512.01882} 127 | } 128 | ``` 129 | -------------------------------------------------------------------------------- /src/Charsiu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import sys 6 | import torch 7 | from itertools import groupby 8 | sys.path.append('src/') 9 | import numpy as np 10 | #sys.path.insert(0,'src') 11 | from models import Wav2Vec2ForAttentionAlignment, Wav2Vec2ForFrameClassification, Wav2Vec2ForCTC 12 | from utils import seq2duration,forced_align,duration2textgrid,word2textgrid 13 | from processors import CharsiuPreprocessor_zh, CharsiuPreprocessor_en 14 | 15 | processors = {'zh':CharsiuPreprocessor_zh, 16 | 'en':CharsiuPreprocessor_en} 17 | 18 | class charsiu_aligner: 19 | 20 | def __init__(self, 21 | lang='en', 22 | sampling_rate=16000, 23 | device=None, 24 | recognizer=None, 25 | processor=None, 26 | resolution=0.01): 27 | 28 | self.lang = lang 29 | 30 | if processor is not None: 31 | self.processor = processor 32 | else: 33 | self.charsiu_processor = processors[self.lang]() 34 | 35 | 36 | 37 | self.resolution = resolution 38 | 39 | self.sr = sampling_rate 40 | 41 | self.recognizer = recognizer 42 | 43 | if device is None: 44 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 45 | else: 46 | self.device = device 47 | 48 | 49 | def _freeze_model(self): 50 | self.aligner.eval().to(self.device) 51 | if self.recognizer is not None: 52 | self.recognizer.eval().to(self.device) 53 | 54 | 55 | 56 | def align(self,audio,text): 57 | raise NotImplementedError() 58 | 59 | 60 | 61 | def serve(self,audio,save_to,output_format='variable',text=None): 62 | raise NotImplementedError() 63 | 64 | 65 | def _to_textgrid(self,phones,save_to): 66 | ''' 67 | Convert output tuples to a textgrid file 68 | 69 | Parameters 70 | ---------- 71 | phones : TYPE 72 | DESCRIPTION. 73 | 74 | Returns 75 | ------- 76 | None. 77 | 78 | ''' 79 | duration2textgrid(phones,save_path=save_to) 80 | print('Alignment output has been saved to %s'%(save_to)) 81 | 82 | 83 | 84 | def _to_tsv(self,phones,save_to): 85 | ''' 86 | Convert output tuples to a tab-separated file 87 | 88 | Parameters 89 | ---------- 90 | phones : TYPE 91 | DESCRIPTION. 92 | 93 | Returns 94 | ------- 95 | None. 
96 | 97 | ''' 98 | with open(save_to,'w') as f: 99 | for start,end,phone in phones: 100 | f.write('%s\t%s\t%s\n'%(start,end,phone)) 101 | print('Alignment output has been saved to %s'%(save_to)) 102 | 103 | 104 | 105 | 106 | 107 | class charsiu_forced_aligner(charsiu_aligner): 108 | 109 | def __init__(self, aligner, sil_threshold=4, **kwargs): 110 | super(charsiu_forced_aligner, self).__init__(**kwargs) 111 | self.aligner = Wav2Vec2ForFrameClassification.from_pretrained(aligner) 112 | self.sil_threshold = sil_threshold 113 | 114 | self._freeze_model() 115 | 116 | 117 | def align(self, audio, text): 118 | ''' 119 | Perform forced alignment 120 | 121 | Parameters 122 | ---------- 123 | audio : np.ndarray [shape=(n,)] 124 | time series of speech signal 125 | text : str 126 | The transcription 127 | 128 | Returns 129 | ------- 130 | A tuple of aligned phones in the form (start_time, end_time, phone) 131 | 132 | ''' 133 | audio = self.charsiu_processor.audio_preprocess(audio,sr=self.sr) 134 | audio = torch.Tensor(audio).unsqueeze(0).to(self.device) 135 | phones, words = self.charsiu_processor.get_phones_and_words(text) 136 | phone_ids = self.charsiu_processor.get_phone_ids(phones) 137 | 138 | with torch.no_grad(): 139 | out = self.aligner(audio) 140 | cost = torch.softmax(out.logits,dim=-1).detach().cpu().numpy().squeeze() 141 | 142 | 143 | sil_mask = self._get_sil_mask(cost) 144 | 145 | nonsil_idx = np.argwhere(sil_mask!=self.charsiu_processor.sil_idx).squeeze() 146 | if nonsil_idx is None: 147 | raise Exception("No speech detected! Please check the audio file!") 148 | 149 | aligned_phone_ids = forced_align(cost[nonsil_idx,:],phone_ids[1:-1]) 150 | 151 | aligned_phones = [self.charsiu_processor.mapping_id2phone(phone_ids[1:-1][i]) for i in aligned_phone_ids] 152 | 153 | pred_phones = self._merge_silence(aligned_phones,sil_mask) 154 | 155 | pred_phones = seq2duration(pred_phones,resolution=self.resolution) 156 | 157 | pred_words = self.charsiu_processor.align_words(pred_phones,phones,words) 158 | return pred_phones, pred_words 159 | 160 | 161 | def serve(self,audio,text,save_to,output_format='textgrid'): 162 | ''' 163 | A wrapper function for quick inference 164 | 165 | Parameters 166 | ---------- 167 | audio : TYPE 168 | DESCRIPTION. 169 | text : TYPE, optional 170 | DESCRIPTION. The default is None. 171 | output_format : str, optional 172 | Output phone-taudio alignment as a "tsv" or "textgrid" file. 173 | The default is 'textgrid'. 174 | 175 | Returns 176 | ------- 177 | None. 178 | 179 | ''' 180 | phones, words = self.align(audio,text) 181 | 182 | if output_format == 'tsv': 183 | if save_to.endswith('.tsv'): 184 | save_to_phone = save_to.replace('.tsv','_phone.tsv') 185 | save_to_word = save_to.replace('.tsv','_word.tsv') 186 | else: 187 | save_to_phone = save_to + '_phone.tsv' 188 | save_to_word = save_to + '_word.tsv' 189 | 190 | self._to_tsv(phones, save_to_phone) 191 | self._to_tsv(words, save_to_word) 192 | 193 | elif output_format == 'textgrid': 194 | self._to_textgrid(phones, words, save_to) 195 | else: 196 | raise Exception('Please specify the correct output format (tsv or textgird)!') 197 | 198 | def _to_textgrid(self,phones,words,save_to): 199 | ''' 200 | Convert output tuples to a textgrid file 201 | 202 | Parameters 203 | ---------- 204 | phones : TYPE 205 | DESCRIPTION. 206 | 207 | Returns 208 | ------- 209 | None. 
210 | 211 | ''' 212 | word2textgrid(phones,words,save_path=save_to) 213 | print('Alignment output has been saved to %s'%(save_to)) 214 | 215 | 216 | def _merge_silence(self,aligned_phones,sil_mask): 217 | # merge silent and non-silent intervals 218 | pred_phones = [] 219 | count = 0 220 | for i in sil_mask: 221 | if i==self.charsiu_processor.sil_idx: 222 | pred_phones.append('[SIL]') 223 | else: 224 | pred_phones.append(aligned_phones[count]) 225 | count += 1 226 | assert len(pred_phones) == len(sil_mask) 227 | return pred_phones 228 | 229 | 230 | 231 | def _get_sil_mask(self,cost): 232 | # single out silent intervals 233 | 234 | preds = np.argmax(cost,axis=-1) 235 | sil_mask = [] 236 | for key, group in groupby(preds): 237 | group = list(group) 238 | if (key==self.charsiu_processor.sil_idx and len(group)= 38 | 7.5 (Volta). 39 | """ 40 | 41 | processor: Wav2Vec2Processor 42 | padding: Union[bool, str] = True 43 | return_attention_mask: Optional[bool] = True 44 | max_length: Optional[int] = None 45 | max_length_labels: Optional[int] = None 46 | pad_to_multiple_of: Optional[int] = None 47 | pad_to_multiple_of_labels: Optional[int] = None 48 | 49 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 50 | # split inputs and labels since they have to be of different lenghts and need 51 | # different padding methods 52 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 53 | 54 | batch = self.processor.pad( 55 | input_features, 56 | padding=self.padding, 57 | max_length=self.max_length, 58 | pad_to_multiple_of=self.pad_to_multiple_of, 59 | return_attention_mask=self.return_attention_mask, 60 | return_tensors="pt", 61 | ) 62 | batch_size, raw_sequence_length = batch['input_values'].shape 63 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 64 | batch['mask_time_indices'] = _compute_mask_indices((batch_size, sequence_length), mask_prob=0.1, mask_length=2,device='cpu') 65 | 66 | 67 | return batch 68 | 69 | 70 | 71 | @dataclass 72 | class DataCollatorCTCWithPadding: 73 | """ 74 | Data collator that will dynamically pad the inputs received. 75 | Args: 76 | processor (:class:`~transformers.Wav2Vec2Processor`) 77 | The processor used for proccessing the data. 78 | padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`): 79 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 80 | among: 81 | * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single 82 | sequence if provided). 83 | * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the 84 | maximum acceptable input length for the model if that argument is not provided. 85 | * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of 86 | different lengths). 87 | max_length (:obj:`int`, `optional`): 88 | Maximum length of the ``input_values`` of the returned list and optionally padding length (see above). 89 | max_length_labels (:obj:`int`, `optional`): 90 | Maximum length of the ``labels`` returned list and optionally padding length (see above). 91 | pad_to_multiple_of (:obj:`int`, `optional`): 92 | If set will pad the sequence to a multiple of the provided value. 
93 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 94 | 7.5 (Volta). 95 | """ 96 | 97 | processor: Wav2Vec2Processor 98 | padding: Union[bool, str] = True 99 | return_attention_mask: Optional[bool] = True 100 | max_length: Optional[int] = None 101 | max_length_labels: Optional[int] = None 102 | pad_to_multiple_of: Optional[int] = None 103 | pad_to_multiple_of_labels: Optional[int] = None 104 | 105 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 106 | # split inputs and labels since they have to be of different lenghts and need 107 | # different padding methods 108 | input_features = [{"input_values": audio_preprocess(feature["input_values"])} for feature in features] 109 | label_features = [{"input_ids": feature["labels"]} for feature in features] 110 | 111 | batch = self.processor.pad( 112 | input_features, 113 | padding=self.padding, 114 | max_length=self.max_length, 115 | pad_to_multiple_of=self.pad_to_multiple_of, 116 | return_attention_mask=self.return_attention_mask, 117 | return_tensors="pt", 118 | ) 119 | with self.processor.as_target_processor(): 120 | labels_batch = self.processor.pad( 121 | label_features, 122 | padding=self.padding, 123 | max_length=self.max_length_labels, 124 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 125 | return_tensors="pt", 126 | ) 127 | 128 | # replace padding with -100 to ignore loss correctly 129 | labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) 130 | 131 | batch["labels"] = labels 132 | 133 | return batch 134 | 135 | 136 | @dataclass 137 | class SpeechCollatorWithPadding: 138 | 139 | processor: Wav2Vec2Processor 140 | padding: Union[bool, str] = True 141 | return_attention_mask: Optional[bool] = True 142 | max_length: Optional[int] = None 143 | max_length_labels: Optional[int] = None 144 | pad_to_multiple_of: Optional[int] = None 145 | pad_to_multiple_of_labels: Optional[int] = None 146 | 147 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 148 | # split inputs and labels since they have to be of different lenghts and need 149 | # different padding methods 150 | 151 | # get phone features 152 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["text"]))} for feature in features] 153 | text_len = [len(i['input_ids']) for i in label_features] 154 | 155 | with self.processor.as_target_processor(): 156 | labels_batch = self.processor.pad( 157 | label_features, 158 | padding=self.padding, 159 | max_length=self.max_length_labels, 160 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 161 | return_attention_mask=self.return_attention_mask, 162 | return_tensors="pt", 163 | ) 164 | 165 | 166 | # get speech features 167 | input_features = [{"input_values": audio_preprocess(feature["file"])} for feature in features] 168 | mel_len = [model._get_feat_extract_output_lengths(len(i['input_values'])) for i in input_features] 169 | batch = self.processor.pad( 170 | input_features, 171 | padding=self.padding, 172 | max_length=self.max_length, 173 | pad_to_multiple_of=self.pad_to_multiple_of, 174 | return_attention_mask=self.return_attention_mask, 175 | return_tensors="pt", 176 | ) 177 | batch_size, raw_sequence_length = batch['input_values'].shape 178 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 179 | 180 | mask_prob = torch.randint(size=(1,),low=5,high=40)/100 181 | batch['mask_time_indices'] = 
_compute_mask_indices((batch_size, sequence_length), mask_prob=mask_prob, mask_length=2,device='cpu') 182 | batch['frame_len'] = torch.tensor(mel_len) 183 | batch["text_len"] = torch.tensor(text_len) 184 | batch['labels'] = labels_batch["input_ids"]#.masked_fill(labels_batch.attention_mask.ne(1), -100) 185 | batch['labels_attention_mask'] = labels_batch['attention_mask'] 186 | 187 | return batch 188 | 189 | @dataclass 190 | class CVCollatorWithPadding: 191 | 192 | processor: Wav2Vec2Processor 193 | padding: Union[bool, str] = True 194 | return_attention_mask: Optional[bool] = True 195 | max_length: Optional[int] = None 196 | max_length_labels: Optional[int] = None 197 | pad_to_multiple_of: Optional[int] = None 198 | pad_to_multiple_of_labels: Optional[int] = None 199 | 200 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 201 | # split inputs and labels since they have to be of different lenghts and need 202 | # different padding methods 203 | 204 | # get phone features 205 | label_features = [{"input_ids": get_phone_ids(get_phones(feature["sentence"]))} for feature in features] 206 | text_len = [len(i['input_ids']) for i in label_features] 207 | 208 | with self.processor.as_target_processor(): 209 | labels_batch = self.processor.pad( 210 | label_features, 211 | padding=self.padding, 212 | max_length=self.max_length_labels, 213 | pad_to_multiple_of=self.pad_to_multiple_of_labels, 214 | return_attention_mask=self.return_attention_mask, 215 | return_tensors="pt", 216 | ) 217 | 218 | 219 | # get speech features 220 | input_features = [{"input_values": audio_preprocess(feature["path"].replace('.mp3','.wav'))} for feature in features] 221 | mel_len = [model._get_feat_extract_output_lengths(len(i['input_values'])) for i in input_features] 222 | batch = self.processor.pad( 223 | input_features, 224 | padding=self.padding, 225 | max_length=self.max_length, 226 | pad_to_multiple_of=self.pad_to_multiple_of, 227 | return_attention_mask=self.return_attention_mask, 228 | return_tensors="pt", 229 | ) 230 | batch_size, raw_sequence_length = batch['input_values'].shape 231 | sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length) 232 | 233 | mask_prob = torch.randint(size=(1,),low=5,high=20)/100 234 | batch['mask_time_indices'] = _compute_mask_indices((batch_size, sequence_length), mask_prob=mask_prob, mask_length=2,device='cpu') 235 | batch['frame_len'] = torch.tensor(mel_len) 236 | batch["text_len"] = torch.tensor(text_len) 237 | batch['labels'] = labels_batch["input_ids"]#.masked_fill(labels_batch.attention_mask.ne(1), -100) 238 | batch['labels_attention_mask'] = labels_batch['attention_mask'] 239 | 240 | return batch 241 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import random 5 | import torch 6 | import numpy as np 7 | from scipy.signal import find_peaks 8 | from collections import defaultdict, Counter 9 | 10 | 11 | def evaluate_overlap(evaluation_pairs): 12 | 13 | hits = 0 14 | counts = 0 15 | for targets,preds in tqdm(evaluation_pairs): 16 | assert len(targets)==len(preds) 17 | hits += sum(np.array(targets)==np.array(preds)) 18 | counts += len(targets) 19 | 20 | return hits/counts 21 | 22 | 23 | ''' 24 | Code for precision, recall, F1, R-value was adapted from unsupseg: https://github.com/felixkreuk/UnsupSeg 25 | ''' 26 
| 27 | def get_metrics(precision_counter, recall_counter, pred_counter, gt_counter): 28 | eps = 1e-7 29 | 30 | precision = precision_counter / (pred_counter + eps) 31 | recall = recall_counter / (gt_counter + eps) 32 | f1 = 2 * (precision * recall) / (precision + recall + eps) 33 | 34 | os = recall / (precision + eps) - 1 35 | r1 = np.sqrt((1 - recall) ** 2 + os ** 2) 36 | r2 = (-os + recall - 1) / (np.sqrt(2)) 37 | rval = 1 - (np.abs(r1) + np.abs(r2)) / 2 38 | 39 | return precision, recall, f1, rval 40 | 41 | 42 | def get_stats(y, y_ids, yhat, yhat_ids, tolerance=0.02): 43 | 44 | precision_counter = 0 45 | recall_counter = 0 46 | pred_counter = 0 47 | gt_counter = 0 48 | 49 | for yhat_i, yhat_id in zip(yhat,yhat_ids): 50 | diff = np.abs(y - yhat_i) 51 | min_dist = diff.min() 52 | min_pos = np.argmin(diff) 53 | intersect = y_ids[min_pos].intersection(yhat_id) 54 | if len(intersect)>0: 55 | precision_counter += (min_dist <= tolerance) 56 | 57 | for y_i,y_id in zip(y,y_ids): 58 | diff = np.abs(yhat - y_i) 59 | min_dist = diff.min() 60 | min_pos = np.argmin(diff) 61 | intersect = yhat_ids[min_pos].intersection(y_id) 62 | if len(intersect)>0: 63 | recall_counter += (min_dist <= tolerance) 64 | 65 | pred_counter += len(yhat) 66 | gt_counter += len(y) 67 | 68 | p, r, f1, rval = get_metrics(precision_counter, 69 | recall_counter, 70 | pred_counter, 71 | gt_counter) 72 | return p, r, f1, rval 73 | 74 | def get_all_stats(evaluation_pairs,tolerance=0.02): 75 | 76 | precision_counter = 0 77 | recall_counter = 0 78 | pred_counter = 0 79 | gt_counter = 0 80 | 81 | for (y,y_ids), (yhat, yhat_ids) in tqdm(evaluation_pairs): 82 | 83 | for yhat_i, yhat_id in zip(yhat,yhat_ids): 84 | diff = np.abs(y - yhat_i) 85 | min_dist = diff.min() 86 | min_pos = np.argmin(diff) 87 | intersect = y_ids[min_pos].intersection(yhat_id) 88 | if len(intersect)>0: 89 | precision_counter += (min_dist <= tolerance) 90 | 91 | for y_i,y_id in zip(y,y_ids): 92 | diff = np.abs(yhat - y_i) 93 | min_dist = diff.min() 94 | min_pos = np.argmin(diff) 95 | intersect = yhat_ids[min_pos].intersection(y_id) 96 | if len(intersect)>0: 97 | recall_counter += (min_dist <= tolerance) 98 | 99 | pred_counter += len(yhat) 100 | gt_counter += len(y) 101 | 102 | p, r, f1, rval = get_metrics(precision_counter, 103 | recall_counter, 104 | pred_counter, 105 | gt_counter) 106 | return p, r, f1, rval 107 | 108 | 109 | 110 | def get_all_stats_boundary_only(evaluation_pairs,tolerance=0.02): 111 | 112 | precision_counter = 0 113 | recall_counter = 0 114 | pred_counter = 0 115 | gt_counter = 0 116 | 117 | for (y,y_ids), (yhat, yhat_ids) in tqdm(evaluation_pairs): 118 | 119 | for yhat_i, yhat_id in zip(yhat,yhat_ids): 120 | diff = np.abs(y - yhat_i) 121 | min_dist = diff.min() 122 | min_pos = np.argmin(diff) 123 | precision_counter += (min_dist <= tolerance) 124 | 125 | for y_i,y_id in zip(y,y_ids): 126 | diff = np.abs(yhat - y_i) 127 | min_dist = diff.min() 128 | min_pos = np.argmin(diff) 129 | recall_counter += (min_dist <= tolerance) 130 | 131 | pred_counter += len(yhat) 132 | gt_counter += len(y) 133 | 134 | p, r, f1, rval = get_metrics(precision_counter, 135 | recall_counter, 136 | pred_counter, 137 | gt_counter) 138 | return p, r, f1, rval 139 | 140 | 141 | def detect_peaks(xi,prominence=0.1, width=None, distance=None): 142 | """detect peaks of next_frame_classifier 143 | 144 | Arguments: 145 | x {Array} -- a sequence of cosine distances 146 | """ 147 | 148 | # shorten to actual length 149 | xmin, xmax = xi.min(), xi.max() 150 | xi = (xi - xmin) / (xmax - 
xmin) 151 | peaks, _ = find_peaks(xi, prominence=prominence, width=width, distance=distance) 152 | 153 | if len(peaks) == 0: 154 | peaks = np.array([len(xi)-1]) 155 | 156 | 157 | return peaks 158 | 159 | 160 | 161 | 162 | def detect_peaks_in_batch(x, prominence=0.1, width=None, distance=None): 163 | """detect peaks of next_frame_classifier 164 | 165 | Arguments: 166 | x {Array} -- batch of coscine distances per time 167 | """ 168 | out = [] 169 | 170 | for xi in x: 171 | if type(xi) == torch.Tensor: 172 | xi = xi.cpu().detach().numpy() # shorten to actual length 173 | xmin, xmax = xi.min(), xi.max() 174 | xi = (xi - xmin) / (xmax - xmin) 175 | peaks, _ = find_peaks(xi, prominence=prominence, width=width, distance=distance) 176 | 177 | if len(peaks) == 0: 178 | peaks = np.array([len(xi)-1]) 179 | 180 | out.append(peaks) 181 | 182 | return out 183 | 184 | 185 | class PrecisionRecallMetric: 186 | def __init__(self): 187 | self.precision_counter = 0 188 | self.recall_counter = 0 189 | self.pred_counter = 0 190 | self.gt_counter = 0 191 | self.eps = 1e-5 192 | self.data = [] 193 | self.tolerance = 0.02 194 | self.prominence_range = np.arange(0, 0.15, 0.1) 195 | self.width_range = [None, 1] 196 | self.distance_range = [None, 1] 197 | self.resolution = 1/49 198 | 199 | def get_metrics(self, precision_counter, recall_counter, pred_counter, gt_counter): 200 | EPS = 1e-7 201 | 202 | precision = precision_counter / (pred_counter + self.eps) 203 | recall = recall_counter / (gt_counter + self.eps) 204 | f1 = 2 * (precision * recall) / (precision + recall + self.eps) 205 | 206 | os = recall / (precision + EPS) - 1 207 | r1 = np.sqrt((1 - recall) ** 2 + os ** 2) 208 | r2 = (-os + recall - 1) / (np.sqrt(2)) 209 | rval = 1 - (np.abs(r1) + np.abs(r2)) / 2 210 | 211 | return precision, recall, f1, rval 212 | 213 | def zero(self): 214 | self.data = [] 215 | 216 | def update(self, seg, pos_pred,): 217 | for seg_i, pos_pred_i in zip(seg, pos_pred): 218 | self.data.append((seg_i, pos_pred_i)) 219 | 220 | def get_stats_and_search_params(self, width=None, prominence=None, distance=None): 221 | print(f"calculating metrics using {len(self.data)} entries") 222 | max_rval = -float("inf") 223 | best_params = None 224 | segs = list(map(lambda x: x[0], self.data)) 225 | yhats = list(map(lambda x: x[1], self.data)) 226 | 227 | width_range = self.width_range 228 | distance_range = self.distance_range 229 | prominence_range = self.prominence_range 230 | 231 | # when testing, we would override the search with specific values from validation 232 | if prominence is not None: 233 | width_range = [width] 234 | distance_range = [distance] 235 | prominence_range = [prominence] 236 | 237 | for width in width_range: 238 | for prominence in prominence_range: 239 | for distance in distance_range: 240 | precision_counter = 0 241 | recall_counter = 0 242 | pred_counter = 0 243 | gt_counter = 0 244 | peaks = detect_peaks_in_batch(yhats, 245 | prominence=prominence, 246 | width=width, 247 | distance=distance) 248 | 249 | for (y, yhat) in zip(segs, peaks): 250 | for yhat_i in yhat: 251 | min_dist = np.abs(y - yhat_i*self.resolution).min() 252 | precision_counter += (min_dist <= self.tolerance) 253 | for y_i in y: 254 | min_dist = np.abs(yhat*self.resolution - y_i).min() 255 | recall_counter += (min_dist <= self.tolerance) 256 | pred_counter += len(yhat) 257 | gt_counter += len(y) 258 | 259 | p, r, f1, rval = self.get_metrics(precision_counter, 260 | recall_counter, 261 | pred_counter, 262 | gt_counter) 263 | if rval > max_rval: 264 | 
max_rval = rval 265 | best_params = width, prominence, distance 266 | out = (p, r, f1, rval) 267 | self.zero() 268 | print(f"best peak detection params: {best_params} (width, prominence, distance)") 269 | return out, best_params 270 | 271 | def get_stats(self, y, yhat): 272 | 273 | precision_counter = 0 274 | recall_counter = 0 275 | pred_counter = 0 276 | gt_counter = 0 277 | 278 | 279 | for yhat_i in yhat: 280 | min_dist = np.abs(y - yhat_i).min() 281 | precision_counter += (min_dist <= self.tolerance) 282 | for y_i in y: 283 | min_dist = np.abs(yhat - y_i).min() 284 | recall_counter += (min_dist <= self.tolerance) 285 | pred_counter += len(yhat) 286 | gt_counter += len(y) 287 | 288 | p, r, f1, rval = self.get_metrics(precision_counter, 289 | recall_counter, 290 | pred_counter, 291 | gt_counter) 292 | return p, r, f1, rval 293 | -------------------------------------------------------------------------------- /src/processors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import re 6 | import numpy as np 7 | from itertools import groupby, chain 8 | import soundfile as sf 9 | import librosa.core 10 | import unicodedata 11 | from builtins import str as unicode 12 | from nltk.tokenize import TweetTokenizer 13 | word_tokenize = TweetTokenizer().tokenize 14 | 15 | from g2p_en import G2p 16 | from g2p_en.expand import normalize_numbers 17 | from g2pM import G2pM 18 | from transformers import Wav2Vec2CTCTokenizer,Wav2Vec2FeatureExtractor, Wav2Vec2Processor 19 | 20 | 21 | 22 | class CharsiuPreprocessor: 23 | 24 | def __init__(self): 25 | pass 26 | 27 | 28 | def get_phones_and_words(self): 29 | raise NotImplementedError 30 | 31 | 32 | def get_phone_ids(self): 33 | raise NotImplementedError 34 | 35 | 36 | def mapping_phone2id(self,phone): 37 | ''' 38 | Convert a phone to a numerical id 39 | 40 | Parameters 41 | ---------- 42 | phone : str 43 | A phonetic symbol 44 | 45 | Returns 46 | ------- 47 | int 48 | A one-hot id for the input phone 49 | 50 | ''' 51 | return self.processor.tokenizer.convert_tokens_to_ids(phone) 52 | 53 | def mapping_id2phone(self,idx): 54 | ''' 55 | Convert a numerical id to a phone 56 | 57 | Parameters 58 | ---------- 59 | idx : int 60 | A one-hot id for a phone 61 | 62 | Returns 63 | ------- 64 | str 65 | A phonetic symbol 66 | 67 | ''' 68 | 69 | return self.processor.tokenizer.convert_ids_to_tokens(idx) 70 | 71 | 72 | def audio_preprocess(self,audio,sr=16000): 73 | ''' 74 | Load and normalize audio 75 | If the sampling rate is incompatible with models, the input audio will be resampled. 76 | 77 | Parameters 78 | ---------- 79 | path : str 80 | The path to the audio 81 | sr : int, optional 82 | Audio sampling rate, either 16000 or 32000. The default is 16000. 
83 | 84 | Returns 85 | ------- 86 | torch.Tensor [(n,)] 87 | A list of audio sample as an one dimensional torch tensor 88 | 89 | ''' 90 | if type(audio)==str: 91 | if sr == 16000: 92 | features,fs = sf.read(audio) 93 | assert fs == 16000 94 | else: 95 | features, _ = librosa.core.load(audio,sr=sr) 96 | elif isinstance(audio, np.ndarray): 97 | features = audio 98 | else: 99 | raise Exception('The input must be a path or a numpy array!') 100 | return self.processor(features, sampling_rate=16000,return_tensors='pt').input_values.squeeze() 101 | 102 | ''' 103 | English g2p processor 104 | ''' 105 | class CharsiuPreprocessor_en(CharsiuPreprocessor): 106 | 107 | def __init__(self): 108 | 109 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_en_cmu') 110 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 111 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 112 | self.g2p = G2p() 113 | self.sil = '[SIL]' 114 | self.sil_idx = self.mapping_phone2id(self.sil) 115 | # self.punctuation = set('.,!?') 116 | self.punctuation = set() 117 | 118 | def get_phones_and_words(self,sen): 119 | ''' 120 | Convert texts to phone sequence 121 | 122 | Parameters 123 | ---------- 124 | sen : str 125 | A str of input sentence 126 | 127 | Returns 128 | ------- 129 | sen_clean : list 130 | A list of phone sequence without stress marks 131 | sen : list 132 | A list of phone sequence with stress marks 133 | 134 | 135 | xxxxx should sen_clean be deleted? 136 | 137 | ''' 138 | 139 | phones = self.g2p(sen) 140 | words = self._get_words(sen) 141 | 142 | phones = list(tuple(g) for k,g in groupby(phones, key=lambda x: x != ' ') if k) 143 | 144 | aligned_phones = [] 145 | aligned_words = [] 146 | for p,w in zip(phones,words): 147 | if re.search(r'\w+\d?',p[0]): 148 | aligned_phones.append(p) 149 | aligned_words.append(w) 150 | elif p in self.punctuation: 151 | aligned_phones.append((self.sil,)) 152 | aligned_words.append(self.sil) 153 | 154 | assert len(aligned_words) == len(aligned_phones) 155 | 156 | return aligned_phones, aligned_words 157 | 158 | assert len(words) == len(phones) 159 | 160 | return phones, words 161 | 162 | 163 | 164 | def get_phone_ids(self,phones,append_silence=True): 165 | ''' 166 | Convert phone sequence to ids 167 | 168 | Parameters 169 | ---------- 170 | phones : list 171 | A list of phone sequence 172 | append_silence : bool, optional 173 | Whether silence is appended at the beginning and the end of the sequence. 174 | The default is True. 175 | 176 | Returns 177 | ------- 178 | ids: list 179 | A list of one-hot representations of phones 180 | 181 | ''' 182 | phones = list(chain.from_iterable(phones)) 183 | ids = [self.mapping_phone2id(re.sub(r'\d','',p)) for p in phones] 184 | 185 | # append silence at the beginning and the end 186 | if append_silence: 187 | if ids[0]!=self.sil_idx: 188 | ids = [self.sil_idx]+ids 189 | if ids[-1]!=self.sil_idx: 190 | ids.append(self.sil_idx) 191 | return ids 192 | 193 | 194 | 195 | def _get_words(self,text): 196 | ''' 197 | from G2P_en 198 | https://github.com/Kyubyong/g2p/blob/master/g2p_en/g2p.py 199 | 200 | Parameters 201 | ---------- 202 | sen : TYPE 203 | DESCRIPTION. 204 | 205 | Returns 206 | ------- 207 | words : TYPE 208 | DESCRIPTION. 
209 | 210 | ''' 211 | 212 | text = unicode(text) 213 | text = normalize_numbers(text) 214 | text = ''.join(char for char in unicodedata.normalize('NFD', text) 215 | if unicodedata.category(char) != 'Mn') # Strip accents 216 | text = text.lower() 217 | text = re.sub("[^ a-z'.,?!\-]", "", text) 218 | text = text.replace("i.e.", "that is") 219 | text = text.replace("e.g.", "for example") 220 | 221 | # tokenization 222 | words = word_tokenize(text) 223 | 224 | return words 225 | 226 | def align_words(self, preds, phones, words): 227 | 228 | words_rep = [w for ph,w in zip(phones,words) for p in ph] 229 | phones_rep = [re.sub(r'\d','',p) for ph,w in zip(phones,words) for p in ph] 230 | assert len(words_rep)==len(phones_rep) 231 | 232 | # match each phone to its word 233 | word_dur = [] 234 | count = 0 235 | for dur in preds: 236 | if dur[-1] == '[SIL]': 237 | word_dur.append((dur,'[SIL]')) 238 | else: 239 | while dur[-1] != phones_rep[count]: 240 | count += 1 241 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word) 242 | 243 | # merge phone-to-word alignment to derive word duration 244 | words = [] 245 | for key, group in groupby(word_dur, lambda x: x[-1]): 246 | group = list(group) 247 | entry = (group[0][0][0],group[-1][0][1],key) 248 | words.append(entry) 249 | 250 | return words 251 | 252 | 253 | ''' 254 | Mandarin g2p processor 255 | ''' 256 | 257 | 258 | class CharsiuPreprocessor_zh(CharsiuPreprocessor_en): 259 | 260 | def __init__(self): 261 | tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('charsiu/tokenizer_zh_pinyin') 262 | feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False) 263 | self.processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) 264 | self.g2p = G2pM() 265 | self.sil = "[SIL]" 266 | self.sil_idx = self.mapping_phone2id(self.sil) 267 | #self.punctuation = set('.,!?。,!?、') 268 | self.punctuation = set() 269 | # Pinyin tables 270 | self.consonant_list = set(['b', 'p', 'm', 'f', 'd', 't', 'n', 'l', 'g', 'k', 271 | 'h', 'j', 'q', 'x', 'zh', 'ch', 'sh', 'r', 'z', 272 | 'c', 's']) 273 | 274 | self.transform_dict = {'ju':'jv', 'qu':'qv', 'xu':'xv','jue':'jve', 275 | 'que':'qve', 'xue':'xve','quan':'qvan', 276 | 'xuan':'xvan','juan':'jvan', 277 | 'qun':'qvn','xun':'xvn', 'jun':'jvn', 278 | 'yuan':'van', 'yue':'ve', 'yun':'vn', 279 | 'you':'iou', 'yan':'ian', 'yin':'in', 280 | 'wa':'ua', 'wo':'uo', 'wai':'uai', 281 | 'weng':'ueng', 'wang':'uang','wu':'u', 282 | 'yu':'v','yi':'i','yo':'io','ya':'ia', 'ye':'ie', 283 | 'yao':'iao','yang':'iang', 'ying':'ing', 'yong':'iong', 284 | 'yvan':'van', 'yve':'ve', 'yvn':'vn', 285 | 'wa':'ua', 'wo':'uo', 'wai':'uai', 286 | 'wei':'ui', 'wan':'uan', 'wen':'un', 287 | 'weng':'ueng', 'wang':'uang','yv':'v', 288 | 'wuen':'un','wuo':'uo','wuang':'uang', 289 | 'wuan':'uan','wua':'ua','wuai':'uai', 290 | 'zhi':'zhiii','chi':'chiii','shi':'shiii', 291 | 'zi':'zii','ci':'cii','si':'sii'} 292 | self.er_mapping ={'er1':('e1','rr'),'er2':('e2','rr'),'er3':('e3','rr'),'er4':('e4','rr'), 293 | 'er5':('e5','rr'),'r5':('e5','rr')} 294 | self.rhyme_mapping = {'iu1':'iou1','iu2':'iou2','iu3':'iou3','iu4':'iou4','iu5':'iou5', 295 | 'u:e1':'ve1','u:e2':'ve2','u:e3':'ve3','u:e4':'ve4','u:e5':'ve5', 296 | 'u:1':'v1','u:2':'v2','u:3':'v3','u:4':'v4','u:5':'v5', 297 | 'ueng1':('u1','eng1'),'ueng2':('u2','eng2'),'ueng3':('u3','eng3'), 298 | 'ueng4':('u4','eng4'),'ueng5':('u5','eng5'),'io5':('i5','o5'), 299 | 
'io4':('i4','o4'),'io1':('i1','o1')} 300 | 301 | def get_phones_and_words(self,sen): 302 | ''' 303 | Convert texts to phone sequence 304 | 305 | Parameters 306 | ---------- 307 | sen : str 308 | A str of input sentence 309 | 310 | Returns 311 | ------- 312 | sen_clean : list 313 | A list of phone sequence without stress marks 314 | sen : list 315 | A list of phone sequence with stress marks 316 | 317 | xxxxx should sen_clean be removed? 318 | ''' 319 | 320 | phones = self.g2p(sen) 321 | 322 | aligned_phones = [] 323 | aligned_words = [] 324 | for p,w in zip(phones,sen): 325 | if re.search(r'\w+:?\d',p): 326 | aligned_phones.append(self._separate_syllable(self.transform_dict.get(p[:-1],p[:-1])+p[-1])) 327 | aligned_words.append(w) 328 | elif p in self.punctuation: 329 | aligned_phones.append((self.sil,)) 330 | aligned_words.append(self.sil) 331 | 332 | assert len(aligned_phones)==len(aligned_words) 333 | return aligned_phones, aligned_words 334 | 335 | 336 | def get_phone_ids(self,phones,append_silence=True): 337 | ''' 338 | Convert phone sequence to ids 339 | 340 | Parameters 341 | ---------- 342 | phones : list 343 | A list of phone sequence 344 | append_silence : bool, optional 345 | Whether silence is appended at the beginning and the end of the sequence. 346 | The default is True. 347 | 348 | Returns 349 | ------- 350 | ids: list 351 | A list of one-hot representations of phones 352 | 353 | ''' 354 | phones = list(chain.from_iterable(phones)) 355 | ids = [self.mapping_phone2id(p) for p in phones] 356 | 357 | # append silence at the beginning and the end 358 | if append_silence: 359 | if ids[0]!=self.sil_idx: 360 | ids = [self.sil_idx]+ids 361 | if ids[-1]!=self.sil_idx: 362 | ids.append(self.sil_idx) 363 | return ids 364 | 365 | 366 | def _separate_syllable(self,syllable): 367 | """ 368 | seprate syllable to consonant + ' ' + vowel 369 | 370 | Parameters 371 | ---------- 372 | syllable : xxxxx TYPE 373 | xxxxx DESCRIPTION. 374 | 375 | Returns 376 | ------- 377 | syllable: xxxxx TYPE 378 | xxxxxx DESCRIPTION. 
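Examples (illustrative, based on the consonant list and mapping tables defined in __init__)
--------
_separate_syllable('zhong1') -> ('zh', 'ong1')
_separate_syllable('er2')    -> ('e2', 'rr')
_separate_syllable('a1')     -> ('a1',)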
379 | 380 | """ 381 | 382 | assert syllable[-1].isdigit() 383 | if syllable == 'ri4': 384 | return ('r','iii4') 385 | if syllable[:-1] == 'ueng' or syllable[:-1] == 'io': 386 | return self.rhyme_mapping.get(syllable,syllable) 387 | if syllable in self.er_mapping.keys(): 388 | return self.er_mapping[syllable] 389 | if syllable[0:2] in self.consonant_list: 390 | #return syllable[0:2].encode('utf-8'),syllable[2:].encode('utf-8') 391 | return syllable[0:2], self.rhyme_mapping.get(syllable[2:],syllable[2:]) 392 | elif syllable[0] in self.consonant_list: 393 | #return syllable[0].encode('utf-8'),syllable[1:].encode('utf-8') 394 | return syllable[0], self.rhyme_mapping.get(syllable[1:],syllable[1:]) 395 | else: 396 | #return (syllable.encode('utf-8'),) 397 | return (syllable,) 398 | 399 | 400 | def align_words(self, preds, phones, words): 401 | 402 | words_rep = [w+str(i) for i,(ph,w) in enumerate(zip(phones,words)) for p in ph] 403 | phones_rep = [p for ph,w in zip(phones,words) for p in ph] 404 | assert len(words_rep)==len(phones_rep) 405 | 406 | # match each phone to its word 407 | word_dur = [] 408 | count = 0 409 | for dur in preds: 410 | if dur[-1] == '[SIL]': 411 | word_dur.append((dur,'[SIL]')) 412 | else: 413 | while dur[-1] != phones_rep[count]: 414 | count += 1 415 | if count >= len(phones_rep): 416 | break 417 | word_dur.append((dur,words_rep[count])) #((start,end,phone),word) 418 | 419 | # merge phone-to-word alignment to derive word duration 420 | words = [] 421 | for key, group in groupby(word_dur, lambda x: x[-1]): 422 | group = list(group) 423 | entry = (group[0][0][0],group[-1][0][1],re.sub(r'\d','',key)) 424 | words.append(entry) 425 | 426 | return words 427 | 428 | 429 | 430 | if __name__ == '__main__': 431 | ''' 432 | Testing functions 433 | ''' 434 | 435 | processor = CharsiuPreprocessor_zh() 436 | phones, words = processor.get_phones_and_words("鱼香肉丝、王道椒香鸡腿和川蜀鸡翅。") 437 | print(phones) 438 | print(words) 439 | ids = processor.get_phone_ids(phones) 440 | print(ids) 441 | 442 | processor = CharsiuPreprocessor_en() 443 | phones, words = processor.get_phones_and_words("I’m playing octopath right now!") 444 | print(phones) 445 | print(words) 446 | ids = processor.get_phone_ids(phones) 447 | print(ids) 448 | 449 | 450 | 451 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | import numpy as np 6 | 7 | import re 8 | from praatio import textgrid 9 | from itertools import groupby 10 | from librosa.sequence import dtw 11 | 12 | 13 | 14 | def ctc2duration(phones,resolution=0.01): 15 | """ 16 | xxxxx convert ctc to duration 17 | 18 | Parameters 19 | ---------- 20 | phones : list 21 | A list of phone sequence 22 | resolution : float, optional 23 | The resolution of xxxxx. The default is 0.01. 24 | 25 | Returns 26 | ------- 27 | merged : list 28 | xxxxx A list of duration values. 
29 | 30 | """ 31 | 32 | counter = 0 33 | out = [] 34 | for p,group in groupby(phones): 35 | length = len(list(group)) 36 | out.append((round(counter*resolution,2),round((counter+length)*resolution,2),p)) 37 | counter += length 38 | 39 | merged = [] 40 | for i, (s,e,p) in enumerate(out): 41 | if i==0 and p=='[PAD]': 42 | merged.append((s,e,'[SIL]')) 43 | elif p=='[PAD]': 44 | merged.append((out[i-1][0],e,out[i-1][2])) 45 | elif i==len(out)-1: 46 | merged.append((s,e,p)) 47 | return merged 48 | 49 | 50 | def seq2duration(phones,resolution=0.01): 51 | """ 52 | xxxxx convert phone sequence to duration 53 | 54 | Parameters 55 | ---------- 56 | phones : list 57 | A list of phone sequence 58 | resolution : float, optional 59 | The resolution of xxxxx. The default is 0.01. 60 | 61 | Returns 62 | ------- 63 | out : list 64 | xxxxx A list of duration values. 65 | 66 | """ 67 | 68 | counter = 0 69 | out = [] 70 | for p,group in groupby(phones): 71 | length = len(list(group)) 72 | out.append((round(counter*resolution,2),round((counter+length)*resolution,2),p)) 73 | counter += length 74 | return out 75 | 76 | 77 | def duration2textgrid(duration_seq,save_path=None): 78 | """ 79 | Save duration values to textgrids 80 | 81 | Parameters 82 | ---------- 83 | duration_seq : list 84 | xxxxx A list of duration values. 85 | save_path : str, optional 86 | The path to save the TextGrid files. The default is None. 87 | 88 | Returns 89 | ------- 90 | tg : TextGrid file?? str?? xxxxx? 91 | A textgrid object containing duration information. 92 | 93 | """ 94 | 95 | tg = textgrid.Textgrid() 96 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1]) 97 | tg.addTier(phoneTier) 98 | if save_path: 99 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False) 100 | return tg 101 | 102 | 103 | def word2textgrid(duration_seq,word_seq,save_path=None): 104 | """ 105 | Save duration values to textgrids 106 | 107 | Parameters 108 | ---------- 109 | duration_seq : list 110 | xxxxx A list of duration values. 111 | save_path : str, optional 112 | The path to save the TextGrid files. The default is None. 113 | 114 | Returns 115 | ------- 116 | tg : TextGrid file?? str?? xxxxx? 117 | A textgrid object containing duration information. 118 | 119 | """ 120 | 121 | tg = textgrid.Textgrid() 122 | phoneTier = textgrid.IntervalTier('phones', duration_seq, 0, duration_seq[-1][1]) 123 | tg.addTier(phoneTier) 124 | wordTier = textgrid.IntervalTier('words', word_seq, 0, word_seq[-1][1]) 125 | tg.addTier(wordTier) 126 | if save_path: 127 | tg.save(save_path,format="short_textgrid", includeBlankSpaces=False) 128 | return tg 129 | 130 | 131 | 132 | def get_boundaries(phone_seq): 133 | """ 134 | Get time of phone boundaries 135 | 136 | Parameters 137 | ---------- 138 | phone_seq : list xxxx? 139 | A list of phone sequence. 140 | 141 | Returns 142 | ------- 143 | timings: A list of time stamps 144 | symbols: A list of phone symbols 145 | 146 | """ 147 | 148 | boundaries = defaultdict(set) 149 | for s,e,p in phone_seq: 150 | boundaries[s].update([p.upper()]) 151 | # boundaries[e].update([p.upper()+'_e']) 152 | timings = np.array(list(boundaries.keys())) 153 | symbols = list(boundaries.values()) 154 | return (timings,symbols) 155 | 156 | 157 | def check_textgrid_duration(textgrid,duration): 158 | """ 159 | Check whether the duration of a textgrid file equals to 'duration'. 160 | If not, replace duration of the textgrid file. 
161 | 162 | Parameters 163 | ---------- 164 | textgrid : .TextGrid object 165 | A .TextGrid object. 166 | duration : float 167 | A given length of time. 168 | 169 | Returns 170 | ------- 171 | textgrid : .TextGrid object 172 | A modified/unmodified textgrid. 173 | 174 | """ 175 | 176 | 177 | endtime = textgrid.tierDict['phones'].entryList[-1].end 178 | if not endtime==duration: 179 | last = textgrid.tierDict['phones'].entryList.pop() 180 | textgrid.tierDict['phones'].entryList.append(last._replace(end=duration)) 181 | 182 | return textgrid 183 | 184 | 185 | def textgrid_to_labels(phones,duration,resolution): 186 | """ 187 | 188 | 189 | Parameters 190 | ---------- 191 | phones : list 192 | A list of phone sequence 193 | resolution : float, optional 194 | The resolution of xxxxx. The default is 0.01. 195 | duration : float 196 | A given length of time. 197 | 198 | 199 | Returns 200 | ------- 201 | labels : list 202 | A list of phone labels. 203 | 204 | """ 205 | 206 | labels = [] 207 | clock = 0.0 208 | 209 | for i, (s,e,p) in enumerate(phones): 210 | 211 | assert clock >= s 212 | while clock <= e: 213 | labels.append(p) 214 | clock += resolution 215 | 216 | # if more than half of the current frame is outside the current phone 217 | # we'll label it as the next phone 218 | if np.abs(clock-e) > resolution/2: 219 | labels[-1] = phones[min(len(phones)-1,i+1)][2] 220 | 221 | # if the final time interval is longer than the total duration 222 | # we will chop off this frame 223 | if clock-duration > resolution/2: 224 | labels.pop() 225 | 226 | return labels 227 | 228 | def remove_null_and_numbers(labels): 229 | """ 230 | Remove labels which are null, noise, or numbers. 231 | 232 | Parameters 233 | ---------- 234 | labels : list 235 | A list of text labels. 236 | 237 | Returns 238 | ------- 239 | out : list 240 | A list of new labels. 241 | 242 | """ 243 | 244 | out = [] 245 | noises = set(['SPN','NSN','LAU']) 246 | for l in labels: 247 | l = re.sub(r'\d+','',l) 248 | l = l.upper() 249 | if l == '' or l == 'SIL': 250 | l = '[SIL]' 251 | if l == 'SP': 252 | l = '[SIL]' 253 | if l in noises: 254 | l = '[UNK]' 255 | out.append(l) 256 | return out 257 | 258 | 259 | def insert_sil(phones): 260 | """ 261 | Insert silences. 262 | 263 | Parameters 264 | ---------- 265 | phones : list 266 | A list of phone sequence 267 | 268 | Returns 269 | ------- 270 | out : list 271 | A list of new labels. 272 | 273 | """ 274 | 275 | out = [] 276 | for i,(s,e,p) in enumerate(phones): 277 | 278 | if out: 279 | if out[-1][1]!=s: 280 | out.append((out[-1][1],s,'[SIL]')) 281 | out.append((s,e,p)) 282 | return out 283 | 284 | 285 | def forced_align(cost, phone_ids): 286 | 287 | """ 288 | Force align text to audio. 289 | 290 | Parameters 291 | ---------- 292 | cost : float xxxxx 293 | xxxxx. 294 | phone_ids : list 295 | A list of phone IDs. 296 | 297 | Returns 298 | ------- 299 | align_id : list 300 | A list of IDs for aligned phones. 301 | 302 | """ 303 | 304 | D,align = dtw(C=-cost[:,phone_ids], 305 | step_sizes_sigma=np.array([[1, 1], [1, 0]])) 306 | 307 | align_seq = [-1 for i in range(max(align[:,0])+1)] 308 | for i in list(align): 309 | # print(align) 310 | if align_seq[i[0]]