├── .github └── workflows │ └── python.yml ├── .gitignore ├── LICENSE ├── README.md ├── doc ├── language.md └── performance │ ├── README.md │ ├── epitran.md │ ├── supervised.md │ └── zsl.md ├── requirements.txt ├── sample.txt ├── setup.py ├── test └── test_tokenizer.py └── transphone ├── __init__.py ├── bin ├── __init__.py ├── download_model.py ├── eval_epitran.py ├── eval_g2p.py ├── eval_zsl_g2p.py ├── g2p.py ├── tokenize.py ├── train_g2p.py └── update_model.py ├── config.py ├── data ├── __init__.py └── exp │ ├── 042801_base.yml │ └── __init__.py ├── g2p.py ├── lang ├── __init__.py ├── base_tokenizer.py ├── cmn │ ├── __init__.py │ ├── normalizer.py │ └── tokenizer.py ├── eng │ ├── __init__.py │ ├── normalizer.py │ └── tokenizer.py ├── epitran_tokenizer.py ├── g2p_tokenizer.py └── jpn │ ├── __init__.py │ ├── conv_table.py │ ├── jaconv.py │ ├── kana2phoneme.py │ ├── normalizer.py │ └── tokenizer.py ├── model ├── __init__.py ├── checkpoint_utils.py ├── dataset.py ├── ensemble.py ├── grapheme.py ├── loader.py ├── lstm.py ├── transformer.py ├── utils.py └── vocab.py ├── run.py ├── tokenizer.py └── utils.py /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI Test 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: 3.9 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install flake8 pytest 27 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 28 | python setup.py develop 29 | - name: Test with pytest 30 | run: | 31 | pytest 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | /data 132 | transphone/output 133 | transphone/runs 134 | transphone/pretrained 135 | .idea 136 | scripts/ 137 | notes/ 138 | transphone/research 139 | transphone/data/decode 140 | transphone/data/model 141 | transphone/data/sandbox 142 | transphone/data/exp/private 143 | transphone/data/jobs 144 | .DS_Store 145 | doc/research -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2012-2022 Scott Chacon and others 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining 4 | a copy of this software and associated documentation files (the 5 | "Software"), to deal in the Software without restriction, including 6 | without limitation the rights to use, copy, modify, merge, publish, 7 | distribute, sublicense, and/or sell copies of the Software, and to 8 | permit persons to whom the Software is furnished to do so, subject to 9 | the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be 12 | included in all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 15 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 16 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 17 | NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE 18 | LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 19 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION 20 | WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # transphone 2 | 3 | ![CI Test](https://github.com/xinjli/transphone/actions/workflows/python.yml/badge.svg) 4 | 5 | `transphone` is a multilingual grapheme-to-phoneme conversion toolkit derived from our paper: [Zero-shot Learning for Grapheme to Phoneme Conversion with Language Ensemble](https://aclanthology.org/2022.findings-acl.166/). 6 | 7 | It provides approximated phoneme tokenizers and a G2P model for 7546 languages registered in the [Glottolog database](https://glottolog.org/glottolog/language). You can see the full list of supported languages in [the language doc](./doc/language.md). 8 | 9 | ## Install 10 | 11 | transphone is available from pip: 12 | 13 | ```bash 14 | pip install transphone 15 | ``` 16 | 17 | Alternatively, you can clone this repository and install it from source: 18 | 19 | ```bash 20 | python setup.py install 21 | ``` 22 | 23 | ## Tokenizer Usage 24 | 25 | The tokenizer converts a string into the target language's phonemes. By default, it combines a few approaches to decide the pronunciation of a word in the target language: 26 | 27 | - **lexicon-based**: it first looks up the pronunciation in a lexicon (from Wiktionary, cmudict, and other sources); currently around 1k languages have at least some entries. 28 | - **transducer-based**: it uses rule-based transducers from [epitran](https://github.com/dmort27/epitran) for several languages where they offer good accuracy and speed. 29 | - **g2p-based**: it falls back to the G2P model described in the next section. 30 | 31 | ### python interface 32 | 33 | You can use it from python as follows: 34 | 35 | ```python 36 | In [1]: from transphone import read_tokenizer 37 | 38 | # use a 2-char or 3-char ISO id to specify your target language 39 | In [2]: eng = read_tokenizer('eng') 40 | 41 | # tokenize a string of text into a list of phonemes 42 | In [3]: lst = eng.tokenize('hello world') 43 | 44 | In [4]: lst 45 | Out[4]: ['h', 'ʌ', 'l', 'o', 'w', 'w', 'ɹ̩', 'l', 'd'] 46 | 47 | In [5]: ids = eng.convert_tokens_to_ids(lst) 48 | 49 | In [6]: ids 50 | Out[6]: [7, 36, 11, 14, 21, 21, 33, 11, 3] 51 | 52 | In [7]: eng.convert_ids_to_tokens(ids) 53 | Out[7]: ['h', 'ʌ', 'l', 'o', 'w', 'w', 'ɹ̩', 'l', 'd'] 54 | 55 | In [8]: jpn = read_tokenizer('jpn') 56 | 57 | In [9]: jpn.tokenize('こんにちは世界') 58 | Out[9]: ['k', 'o', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', 's', 'e', 'k', 'a', 'i'] 59 | 60 | In [10]: cmn = read_tokenizer('cmn') 61 | 62 | In [11]: cmn.tokenize('你好世界') 63 | Out[11]: ['n', 'i', 'x', 'a', 'o', 'ʂ', 'ɻ̩', 't͡ɕ', 'i', 'e'] 64 | 65 | In [12]: deu = read_tokenizer('deu') 66 | 67 | In [13]: deu.tokenize('Hallo Welt') 68 | Out[13]: ['h', 'a', 'l', 'o', 'v', 'e', 'l', 't'] 69 | 70 | ``` 71 | 72 | ### command line interface 73 | 74 | A command line tool is also available: 75 | 76 | ```bash 77 | # compute the pronunciation of every word in the input file 78 | $ python -m transphone.run --lang eng --input sample.txt 79 | h ɛ l o ʊ 80 | w ə l d 81 | t ɹ æ n s f ə ʊ n 82 | 83 | # by specifying the combine flag, you can get word + pronunciation per line 84 | $ python -m transphone.run --lang eng --input sample.txt --combine=True 85 | hello h ɛ l o ʊ 86 | world w ə l d 87 | transphone t ɹ æ n s f ə ʊ n 88 | ``` 89 | 90 | ## G2P Backend Usage 91 | 92 | The tokenizer in the previous section uses the G2P model as one of its backend options. You can also use the G2P model directly.
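To make the relationship concrete, here is a minimal sketch of the per-word fallback order described in the tokenizer section above (lexicon lookup first, G2P prediction as the fallback). The helper name is hypothetical and purely illustrative — it is not this repo's API (see `transphone/tokenizer.py` for the actual logic):

```python
# Conceptual sketch only, not the actual transphone implementation.
# It illustrates the backend order described above:
# 1. lexicon-based lookup, 2. g2p-based prediction as the fallback.

def pronounce_word(word, lexicon, g2p_model, lang_id):
    if word in lexicon:
        # lexicon-based: an exact lexicon match wins when the word is known
        return lexicon[word]
    # g2p-based: otherwise predict phonemes with the neural G2P model
    return g2p_model.inference(word, lang_id)

# hypothetical usage:
# lexicon = {'hello': ['h', 'ʌ', 'l', 'o', 'w']}
# pronounce_word('transphone', lexicon, model, 'eng')
```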
93 | 94 | ### python interface 95 | 96 | A simple python usage is as follows: 97 | 98 | ```python 99 | In [1]: from transphone.g2p import read_g2p 100 | 101 | # read a pretrained model. It will download the pretrained model automatically into repo_root/data/model 102 | In [2]: model = read_g2p() 103 | 104 | # to infer the pronunciation of a word, specify the language with its ISO 639-3 id 105 | # for any of the pretrained languages (~900 languages), it uses the pretrained model without approximation 106 | In [3]: model.inference('transphone', 'eng') 107 | Out[3]: ['t', 'ɹ', 'æ', 'n', 's', 'f', 'ə', 'ʊ', 'n'] 108 | 109 | # if the specified language is not available, it will be approximated using the nearest languages 110 | # in this case, aaa (the Ghotuo language) is not one of the training languages, so we fetch the 10 nearest languages to approximate it 111 | In [4]: model.inference('transphone', 'aaa') 112 | lang aaa is not available directly, use ['bin', 'bja', 'bkh', 'bvx', 'dua', 'eto', 'gwe', 'ibo', 'kam', 'kik'] instead 113 | Out[4]: ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 114 | 115 | # to gain deeper insight, you can also set the debug flag to see the output of each language 116 | In [5]: model.inference('transphone', 'aaa', debug=True) 117 | bin ['s', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 118 | bja ['s', 'l', 'a', 'n', 's', 'f', 'o', 'n'] 119 | bkh ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 120 | bvx ['t', 'r', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 121 | dua ['t', 'r', 'n', 's', 'f', 'n'] 122 | eto ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 123 | gwe ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 124 | ibo ['t', 'l', 'a', 'n', 's', 'p', 'o', 'n', 'e'] 125 | kam ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 126 | kik ['t', 'l', 'a', 'n', 's', 'f', 'ɔ', 'n', 'ɛ'] 127 | Out[5]: ['t', 'l', 'a', 'n', 's', 'f', 'o', 'n', 'e'] 128 | ``` 129 | 130 | ### Pretrained Models 131 | 132 | The pretrained models roughly follow our paper accepted at `Findings of ACL 2022`: [Zero-shot Learning for Grapheme to Phoneme Conversion with Language Ensemble](https://aclanthology.org/2022.findings-acl.166/). 133 | 134 | You can see the G2P evaluation over 1k languages in the [performance doc](./doc/performance/README.md). 135 | 136 | Note that this is the pure G2P evaluation on unseen words. The tokenizer combines other existing resources (e.g., lexicons) as well, so the tokenizer's performance is expected to be much better than this. 137 | 138 | | model | # supported languages | supervised language PER | zero-shot language PER | description | 139 | |:--------------------:|:---------------------:|:-----------------------:|:----------------------:|:------------------------:| 140 | | 042801_base (latest) | ~8k | 13% | 31% | based on our work at [1] | 141 | 142 | ### Training 143 | 144 | We also provide the training code for G2P. You can reproduce the pretrained model using `transphone.bin.train_g2p`. 145 | 146 | ## Epitran backend 147 | 148 | This repo also provides a wrapper around a customized version of [epitran](https://github.com/dmort27/epitran). For a few languages, the tokenizer uses epitran as the backend for accuracy and speed. 149 | 150 | You can also use epitran directly as follows: 151 | 152 | ```python 153 | # assuming read_epitran_tokenizer is exported by the epitran wrapper module 154 | In [1]: from transphone.lang.epitran_tokenizer import read_epitran_tokenizer 155 | 156 | In [2]: tokenizer = read_epitran_tokenizer('spa', use_lexicon=False) 157 | 158 | In [3]: tokenizer.tokenize('hola') 159 | Out[3]: ['o', 'l', 'a'] 160 | ``` 161 | 162 | ## Reference 163 | 164 | - [1] Li, Xinjian, et al. "Zero-shot Learning for Grapheme to Phoneme Conversion with Language Ensemble."
Findings of the Association for Computational Linguistics: ACL 2022. 2022. 165 | - [2] Li, Xinjian, et al. "Phone Inventories and Recognition for Every Language." LREC 2022. 2022. -------------------------------------------------------------------------------- /doc/performance/README.md: -------------------------------------------------------------------------------- 1 | # Performance 2 | 3 | This directory contains the G2P test evaluation on around 1k languages. 4 | 5 | - [supervised testing performance](./supervised.md) on 276 languages 6 | - [zero-shot testing performance](./zsl.md) on 613 languages (no overlap with the 276) 7 | 8 | As baselines, we also compare with: 9 | 10 | - [epitran](./epitran.md) on 61 languages 11 | -------------------------------------------------------------------------------- /doc/performance/epitran.md: -------------------------------------------------------------------------------- 1 | # epitran evaluation 2 | 3 | This page contains the evaluation of epitran on the languages it supports, tested on the same test set as the supervised model. 4 | 5 | Note that it might not be fair to compare epitran with our model: our model is trained on the Wiktionary training set while epitran is built from other sources, so our model's training set is consistent with the test set but epitran's is not. This is included only for reference. 6 | 7 | This table can be reproduced by `transphone.bin.eval_epitran`. 8 | 9 | | language | phoneme error rate | phonological distance | 10 | |----------|--------------------|-----------------------| 11 | | tur | 0.045 | 0.032 | 12 | | vie | 0.308 | 0.151 | 13 | | aar | 0.222 | 0.026 | 14 | | got | 0.851 | 0.391 | 15 | | swa | 0.036 | 0.036 | 16 | | swe | 0.263 | 0.050 | 17 | | amh | 0.239 | 0.112 | 18 | | hat | 0.025 | 0.004 | 19 | | tam | 0.198 | 0.110 | 20 | | mal | 0.357 | 0.106 | 21 | | hin | 0.317 | 0.080 | 22 | | mar | 0.202 | 0.056 | 23 | | tel | 0.467 | 0.199 | 24 | | ara | 0.359 | 0.218 | 25 | | mlt | 0.169 | 0.053 | 26 | | tgk | 0.074 | 0.044 | 27 | | mon | 0.403 | 0.145 | 28 | | tgl | 0.283 | 0.081 | 29 | | hun | 0.030 | 0.011 | 30 | | msa | 0.178 | 0.025 | 31 | | tha | 0.208 | 0.072 | 32 | | ben | 0.217 | 0.082 | 33 | | ilo | 0.199 | 0.092 | 34 | | mya | 0.239 | 0.136 | 35 | | ind | 0.125 | 0.017 | 36 | | nan | 0.373 | 0.373 | 37 | | ita | 0.152 | 0.042 | 38 | | jam | 0.438 | 0.208 | 39 | | nld | 0.256 | 0.075 | 40 | | cat | 0.292 | 0.057 | 41 | | nya | 0.142 | 0.135 | 42 | | tuk | 0.144 | 0.029 | 43 | | ceb | 0.095 | 0.067 | 44 | | ces | 0.338 | 0.017 | 45 | | ori | 0.064 | 0.036 | 46 | | ckb | 0.297 | 0.112 | 47 | | kat | 0.005 | 0.000 | 48 | | pan | 0.857 | 0.383 | 49 | | kaz | 0.985 | 0.369 | 50 | | uig | 0.276 | 0.033 | 51 | | csb | 0.259 | 0.159 | 52 | | ukr | 0.368 | 0.064 | 53 | | deu | 0.335 | 0.071 | 54 | | pol | 0.050 | 0.046 | 55 | | urd | 0.522 | 0.328 | 56 | | kbd | 0.195 | 0.129 | 57 | | por | 0.427 | 0.081 | 58 | | ron | 0.023 | 0.002 | 59 | | fas | 0.398 | 0.255 | 60 | | kir | 1.029 | 0.402 | 61 | | rus | 0.255 | 0.061 | 62 | | xho | 0.418 | 0.120 | 63 | | yor | 0.092 | 0.059 | 64 | | kmr | 0.050 | 0.009 | 65 | | yue | 0.197 | 0.043 | 66 | | zha | 0.079 | 0.021 | 67 | | lao | 0.419 | 0.260 | 68 | | spa | 0.058 | 0.043 | 69 | | zul | 0.090 | 0.036 | 70 | | sqi | 0.041 | 0.017 | 71 | | lij | 0.767 | 0.173 | -------------------------------------------------------------------------------- /doc/performance/supervised.md: 
-------------------------------------------------------------------------------- 1 | # supervised testing performance 2 | 3 | The following 276 languages have at least 50 entries in the Wiktionary. These languages are trained in a supervised fashion. 4 | 5 | They are tested on the last 25 unseen words (the other 25 were used as dev). You should be able to reproduce this table by using `transphone.bin.eval_g2p`. 6 | 7 | - The average phoneme error rate is **0.13**. 8 | - The average phonological distance is **0.05**. 9 | 10 | | language | phoneme error rate | phonological distance | 11 | |----------|--------------------|-----------------------| 12 | | aar | 0.006 | 0.006 | 13 | | abk | 0.266 | 0.162 | 14 | | acw | 0.182 | 0.122 | 15 | | ady | 0.103 | 0.052 | 16 | | afb | 0.174 | 0.091 | 17 | | afr | 0.051 | 0.013 | 18 | | ain | 0.033 | 0.021 | 19 | | ajp | 0.111 | 0.073 | 20 | | akk | 0.163 | 0.045 | 21 | | ale | 0.014 | 0.003 | 22 | | alr | 0.020 | 0.020 | 23 | | ang | 0.016 | 0.000 | 24 | | aot | 0.008 | 0.008 | 25 | | apw | 0.184 | 0.074 | 26 | | arb | 0.105 | 0.052 | 27 | | arc | 0.232 | 0.138 | 28 | | ary | 0.083 | 0.057 | 29 | | arz | 0.244 | 0.109 | 30 | | asm | 0.033 | 0.010 | 31 | | ast | 0.089 | 0.037 | 32 | | ayl | 0.138 | 0.055 | 33 | | azg | 0.125 | 0.102 | 34 | | azj | 0.076 | 0.042 | 35 | | bak | 0.267 | 0.267 | 36 | | bam | 0.487 | 0.087 | 37 | | ban | 0.057 | 0.023 | 38 | | bbl | 0.192 | 0.048 | 39 | | bcl | 0.117 | 0.072 | 40 | | bdq | 0.027 | 0.027 | 41 | | bel | 0.000 | 0.000 | 42 | | ben | 0.093 | 0.028 | 43 | | bod | 0.111 | 0.054 | 44 | | bre | 0.325 | 0.079 | 45 | | bsq | 0.857 | 0.350 | 46 | | bul | 0.056 | 0.012 | 47 | | cab | 0.195 | 0.052 | 48 | | cat | 0.009 | 0.001 | 49 | | cbn | 0.198 | 0.039 | 50 | | ceb | 0.089 | 0.055 | 51 | | ces | 0.000 | 0.000 | 52 | | chb | 0.062 | 0.035 | 53 | | che | 0.158 | 0.050 | 54 | | cho | 0.182 | 0.123 | 55 | | chv | 0.393 | 0.042 | 56 | | ckb | 0.068 | 0.052 | 57 | | cnk | 0.000 | 0.000 | 58 | | cop | 0.127 | 0.091 | 59 | | cor | 0.112 | 0.013 | 60 | | cos | 0.150 | 0.026 | 61 | | crk | 0.011 | 0.011 | 62 | | crx | 0.900 | 0.156 | 63 | | csb | 0.301 | 0.105 | 64 | | cym | 0.092 | 0.016 | 65 | | dan | 0.181 | 0.067 | 66 | | deu | 0.030 | 0.018 | 67 | | dhv | 0.321 | 0.108 | 68 | | dlm | 0.034 | 0.004 | 69 | | dng | 0.193 | 0.079 | 70 | | dsb | 0.007 | 0.000 | 71 | | dum | 0.111 | 0.058 | 72 | | dzo | 0.274 | 0.151 | 73 | | egl | 0.186 | 0.013 | 74 | | egy | 0.137 | 0.067 | 75 | | ekk | 0.038 | 0.014 | 76 | | ell | 0.010 | 0.003 | 77 | | eng | 0.114 | 0.052 | 78 | | enm | 0.103 | 0.039 | 79 | | eus | 0.028 | 0.013 | 80 | | ewe | 0.201 | 0.070 | 81 | | fao | 0.062 | 0.022 | 82 | | fin | 0.000 | 0.000 | 83 | | fra | 0.017 | 0.013 | 84 | | fro | 0.191 | 0.049 | 85 | | frr | 0.171 | 0.049 | 86 | | fry | 0.128 | 0.031 | 87 | | gla | 0.141 | 0.034 | 88 | | gle | 0.091 | 0.021 | 89 | | glg | 0.104 | 0.003 | 90 | | glv | 0.156 | 0.055 | 91 | | gml | 0.147 | 0.063 | 92 | | goh | 0.174 | 0.044 | 93 | | got | 0.018 | 0.009 | 94 | | grc | 0.000 | 0.000 | 95 | | gsw | 0.075 | 0.020 | 96 | | guj | 0.189 | 0.055 | 97 | | gur | 0.392 | 0.125 | 98 | | hat | 0.074 | 0.040 | 99 | | haw | 0.161 | 0.148 | 100 | | hbs | 0.010 | 0.000 | 101 | | heb | 0.183 | 0.055 | 102 | | hif | 0.299 | 0.098 | 103 | | hin | 0.006 | 0.006 | 104 | | hrx | 0.069 | 0.041 | 105 | | hts | 0.046 | 0.034 | 106 | | hun | 0.000 | 0.000 | 107 | | huu | 0.037 | 0.037 | 108 | | hye | 0.028 | 0.010 | 109 | | ilo | 0.075 | 0.025 | 110 | | ind | 0.054 | 0.012 | 111 | | inh | 0.276 | 0.072 | 112
| | isl | 0.043 | 0.008 | 113 | | ita | 0.019 | 0.001 | 114 | | izh | 0.129 | 0.026 | 115 | | jam | 0.158 | 0.060 | 116 | | jje | 0.098 | 0.076 | 117 | | kal | 0.021 | 0.017 | 118 | | kan | 0.178 | 0.075 | 119 | | kas | 0.089 | 0.044 | 120 | | kat | 0.000 | 0.000 | 121 | | kaz | 0.098 | 0.016 | 122 | | kbd | 0.126 | 0.072 | 123 | | kgj | 0.475 | 0.118 | 124 | | khb | 0.131 | 0.060 | 125 | | khw | 0.356 | 0.140 | 126 | | kik | 0.023 | 0.016 | 127 | | kir | 0.041 | 0.010 | 128 | | kld | 0.049 | 0.006 | 129 | | kmr | 0.000 | 0.000 | 130 | | knn | 0.255 | 0.032 | 131 | | kor | 0.407 | 0.396 | 132 | | kpv | 0.047 | 0.025 | 133 | | krl | 0.131 | 0.087 | 134 | | kxd | 0.013 | 0.002 | 135 | | lad | 0.126 | 0.041 | 136 | | lao | 0.084 | 0.046 | 137 | | lat | 0.000 | 0.000 | 138 | | lav | 0.105 | 0.033 | 139 | | lcp | 0.080 | 0.051 | 140 | | lic | 0.760 | 0.371 | 141 | | lif | 0.038 | 0.018 | 142 | | lij | 0.063 | 0.034 | 143 | | lim | 0.185 | 0.066 | 144 | | lit | 0.183 | 0.017 | 145 | | liv | 0.078 | 0.013 | 146 | | lmo | 0.168 | 0.094 | 147 | | lmy | 0.000 | 0.000 | 148 | | lsi | 0.036 | 0.014 | 149 | | ltg | 0.059 | 0.002 | 150 | | ltz | 0.215 | 0.041 | 151 | | lwl | 0.065 | 0.047 | 152 | | mah | 0.078 | 0.012 | 153 | | mai | 1.250 | 0.560 | 154 | | mak | 0.000 | 0.000 | 155 | | mal | 0.299 | 0.096 | 156 | | mar | 0.016 | 0.002 | 157 | | mdf | 0.135 | 0.024 | 158 | | mfe | 0.061 | 0.054 | 159 | | mga | 0.169 | 0.035 | 160 | | mic | 0.048 | 0.043 | 161 | | mkd | 0.000 | 0.000 | 162 | | mlt | 0.056 | 0.030 | 163 | | mnc | 0.000 | 0.000 | 164 | | mnw | 0.577 | 0.264 | 165 | | mon | 0.209 | 0.059 | 166 | | mvi | 0.248 | 0.134 | 167 | | mww | 0.056 | 0.023 | 168 | | mya | 0.116 | 0.040 | 169 | | nan | 0.043 | 0.043 | 170 | | nap | 0.129 | 0.021 | 171 | | nav | 0.076 | 0.030 | 172 | | nci | 0.025 | 0.017 | 173 | | nds | 0.178 | 0.082 | 174 | | nep | 0.000 | 0.000 | 175 | | new | 0.056 | 0.021 | 176 | | nhg | 0.032 | 0.019 | 177 | | nhn | 0.189 | 0.167 | 178 | | nhx | 0.018 | 0.018 | 179 | | niv | 0.436 | 0.026 | 180 | | nld | 0.050 | 0.006 | 181 | | nmy | 0.085 | 0.045 | 182 | | nno | 0.099 | 0.048 | 183 | | nob | 0.066 | 0.012 | 184 | | non | 0.058 | 0.018 | 185 | | nrf | 0.290 | 0.114 | 186 | | nya | 0.007 | 0.007 | 187 | | oci | 0.086 | 0.016 | 188 | | ofs | 0.144 | 0.020 | 189 | | olo | 0.155 | 0.019 | 190 | | ory | 0.014 | 0.007 | 191 | | osp | 0.056 | 0.019 | 192 | | osx | 0.079 | 0.018 | 193 | | ota | 0.331 | 0.105 | 194 | | pan | 0.227 | 0.055 | 195 | | pao | 0.640 | 0.217 | 196 | | pau | 0.218 | 0.063 | 197 | | pbv | 0.356 | 0.134 | 198 | | pcc | 0.271 | 0.103 | 199 | | pdc | 0.049 | 0.038 | 200 | | pes | 0.188 | 0.057 | 201 | | phl | 0.007 | 0.007 | 202 | | pjt | 0.060 | 0.044 | 203 | | plt | 0.088 | 0.063 | 204 | | pms | 0.021 | 0.001 | 205 | | pol | 0.000 | 0.000 | 206 | | pon | 0.512 | 0.244 | 207 | | por | 0.178 | 0.045 | 208 | | pox | 0.008 | 0.008 | 209 | | ppl | 0.041 | 0.035 | 210 | | pus | 0.157 | 0.079 | 211 | | quc | 0.076 | 0.005 | 212 | | raw | 0.054 | 0.001 | 213 | | ron | 0.029 | 0.008 | 214 | | rup | 0.052 | 0.026 | 215 | | rus | 0.014 | 0.001 | 216 | | rys | 0.121 | 0.067 | 217 | | ryu | 0.317 | 0.190 | 218 | | sah | 0.046 | 0.022 | 219 | | san | 0.034 | 0.012 | 220 | | sce | 0.146 | 0.058 | 221 | | scn | 0.153 | 0.058 | 222 | | sco | 0.121 | 0.043 | 223 | | sei | 0.640 | 0.129 | 224 | | sga | 0.033 | 0.005 | 225 | | shn | 0.060 | 0.060 | 226 | | sid | 0.027 | 0.027 | 227 | | slk | 0.044 | 0.005 | 228 | | slv | 0.058 | 0.008 | 229 | | sme | 0.050 | 0.003 | 230 | | sms | 
0.137 | 0.074 | 231 | | snd | 0.273 | 0.103 | 232 | | spa | 0.004 | 0.001 | 233 | | sqi | 0.068 | 0.044 | 234 | | srn | 0.048 | 0.035 | 235 | | sro | 0.056 | 0.008 | 236 | | stq | 0.061 | 0.005 | 237 | | svm | 0.053 | 0.018 | 238 | | swa | 0.007 | 0.001 | 239 | | swe | 0.073 | 0.017 | 240 | | syc | 0.242 | 0.192 | 241 | | syl | 0.051 | 0.019 | 242 | | tam | 0.006 | 0.006 | 243 | | tel | 0.047 | 0.040 | 244 | | tft | 0.188 | 0.042 | 245 | | tgk | 0.050 | 0.019 | 246 | | tgl | 0.054 | 0.036 | 247 | | tha | 0.018 | 0.018 | 248 | | tkl | 0.000 | 0.000 | 249 | | ton | 0.000 | 0.000 | 250 | | tpw | 0.191 | 0.162 | 251 | | tsn | 0.390 | 0.107 | 252 | | tuk | 0.121 | 0.016 | 253 | | tur | 0.040 | 0.009 | 254 | | twf | 0.055 | 0.055 | 255 | | tyv | 0.066 | 0.016 | 256 | | tzm | 0.021 | 0.021 | 257 | | tzo | 0.025 | 0.017 | 258 | | ugo | 0.171 | 0.042 | 259 | | uig | 0.164 | 0.052 | 260 | | ukr | 0.021 | 0.001 | 261 | | unm | 0.328 | 0.126 | 262 | | urd | 0.199 | 0.082 | 263 | | vec | 0.113 | 0.007 | 264 | | vie | 0.019 | 0.010 | 265 | | wau | 0.040 | 0.008 | 266 | | wiy | 0.047 | 0.015 | 267 | | wlo | 0.222 | 0.101 | 268 | | wol | 0.576 | 0.186 | 269 | | wuh | 0.402 | 0.276 | 270 | | xho | 0.025 | 0.011 | 271 | | xsl | 0.505 | 0.190 | 272 | | xug | 0.484 | 0.133 | 273 | | ybi | 0.000 | 0.000 | 274 | | ycl | 0.143 | 0.049 | 275 | | yid | 0.077 | 0.016 | 276 | | yoi | 0.274 | 0.152 | 277 | | yor | 0.038 | 0.024 | 278 | | yrk | 0.503 | 0.187 | 279 | | yua | 0.154 | 0.013 | 280 | | yue | 0.079 | 0.057 | 281 | | yux | 0.067 | 0.014 | 282 | | zgh | 0.115 | 0.115 | 283 | | zha | 0.070 | 0.044 | 284 | | zlm | 0.057 | 0.015 | 285 | | zom | 0.119 | 0.049 | 286 | | zul | 0.005 | 0.005 | 287 | | zza | 0.015 | 0.015 | 288 | 289 | -------------------------------------------------------------------------------- /doc/performance/zsl.md: -------------------------------------------------------------------------------- 1 | # zero-shot testing performance 2 | 3 | The following 613 languages have at least 1 entry in the Wiktionary but fewer than 50 entries. These languages are not seen during training. The zero-shot predictions are based on the ensemble approach proposed in our work.
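To illustrate the ensemble idea, the sketch below combines per-language G2P hypotheses with simple per-position majority voting. This is a minimal conceptual sketch under that voting assumption, with a hypothetical helper name — it is not this repo's actual API (the real combination logic appears to live in `transphone/model/ensemble.py`):

```python
from collections import Counter

def ensemble_g2p(word, model, nearest_langs):
    # run the supervised G2P once per nearby language (cf. model.inference
    # in the README); each hypothesis is a list of phonemes
    hyps = [model.inference(word, lang) for lang in nearest_langs]
    # keep only the hypotheses of the most common length,
    # then vote phoneme by phoneme across aligned positions
    best_len = Counter(len(h) for h in hyps).most_common(1)[0][0]
    candidates = [h for h in hyps if len(h) == best_len]
    return [Counter(col).most_common(1)[0][0] for col in zip(*candidates)]

# hypothetical usage:
# ensemble_g2p('transphone', model, ['bin', 'bja', 'bkh', 'bvx', 'dua'])
```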
4 | You should be able to reproduce this table by using `transphone.bin.eval_zsl_g2p` 5 | 6 | - the average phoneme error rate is **0.31** 7 | - the average phonological distance is **0.13** 8 | 9 | | language | phoneme error rate | phonological distance | 10 | |----------|--------------------|-----------------------| 11 | | aau | 0.304 | 0.154 | 12 | | aax | 0.125 | 0.125 | 13 | | abd | 0.111 | 0.111 | 14 | | abe | 0.171 | 0.034 | 15 | | abq | 0.297 | 0.074 | 16 | | abs | 0.250 | 0.034 | 17 | | abt | 0.571 | 0.249 | 18 | | abv | 0.500 | 0.034 | 19 | | aby | 0.083 | 0.014 | 20 | | ace | 0.433 | 0.143 | 21 | | acm | 0.213 | 0.117 | 22 | | acn | 1.333 | 0.822 | 23 | | acv | 0.667 | 0.345 | 24 | | aeb | 0.291 | 0.139 | 25 | | aek | 0.000 | 0.000 | 26 | | aem | 0.224 | 0.047 | 27 | | agg | 0.385 | 0.072 | 28 | | agm | 0.385 | 0.109 | 29 | | agn | 0.500 | 0.500 | 30 | | agu | 0.000 | 0.000 | 31 | | aib | 0.200 | 0.014 | 32 | | aii | 0.385 | 0.174 | 33 | | aio | 0.300 | 0.085 | 34 | | aji | 0.081 | 0.057 | 35 | | aka | 0.325 | 0.037 | 36 | | akj | 0.571 | 0.369 | 37 | | akl | 0.283 | 0.217 | 38 | | alc | 0.000 | 0.000 | 39 | | alq | 0.416 | 0.062 | 40 | | alt | 0.227 | 0.021 | 41 | | ami | 0.143 | 0.012 | 42 | | amm | 0.235 | 0.039 | 43 | | amn | 0.091 | 0.091 | 44 | | amp | 0.000 | 0.000 | 45 | | ams | 0.625 | 0.191 | 46 | | amw | 0.241 | 0.131 | 47 | | anc | 0.333 | 0.023 | 48 | | ane | 0.500 | 0.181 | 49 | | ank | 0.389 | 0.126 | 50 | | anm | 0.000 | 0.000 | 51 | | anq | 0.193 | 0.055 | 52 | | aoc | 0.000 | 0.000 | 53 | | aoi | 0.333 | 0.172 | 54 | | apc | 0.336 | 0.177 | 55 | | ape | 0.165 | 0.062 | 56 | | api | 0.714 | 0.209 | 57 | | apm | 0.000 | 0.000 | 58 | | apn | 0.400 | 0.243 | 59 | | apr | 0.000 | 0.000 | 60 | | apy | 0.300 | 0.010 | 61 | | aqc | 0.400 | 0.131 | 62 | | aqz | 0.333 | 0.333 | 63 | | are | 0.600 | 0.166 | 64 | | arg | 0.075 | 0.022 | 65 | | arn | 0.196 | 0.051 | 66 | | arp | 0.600 | 0.245 | 67 | | arq | 0.190 | 0.133 | 68 | | arr | 0.000 | 0.000 | 69 | | ars | 0.500 | 0.272 | 70 | | arw | 0.000 | 0.000 | 71 | | ava | 0.471 | 0.248 | 72 | | avd | 0.426 | 0.148 | 73 | | ave | 1.143 | 0.622 | 74 | | awk | 0.333 | 0.057 | 75 | | aym | 0.064 | 0.031 | 76 | | ayp | 0.286 | 0.150 | 77 | | azz | 0.167 | 0.167 | 78 | | bap | 1.000 | 0.069 | 79 | | bar | 0.185 | 0.038 | 80 | | bbh | 1.000 | 1.000 | 81 | | bbk | 0.167 | 0.076 | 82 | | bbo | 1.000 | 0.073 | 83 | | bca | 1.000 | 0.741 | 84 | | bch | 0.333 | 0.029 | 85 | | bci | 0.250 | 0.054 | 86 | | bdd | 0.167 | 0.049 | 87 | | bdg | 0.200 | 0.200 | 88 | | bdh | 0.000 | 0.000 | 89 | | bdo | 0.200 | 0.200 | 90 | | bdy | 0.000 | 0.000 | 91 | | bea | 0.000 | 0.000 | 92 | | beg | 0.333 | 0.121 | 93 | | beu | 0.065 | 0.018 | 94 | | bfm | 0.000 | 0.000 | 95 | | bft | 0.562 | 0.205 | 96 | | bhi | 0.000 | 0.000 | 97 | | bhl | 0.000 | 0.000 | 98 | | bhm | 0.333 | 0.333 | 99 | | bin | 0.286 | 0.143 | 100 | | bis | 0.000 | 0.000 | 101 | | bja | 0.000 | 0.000 | 102 | | bjb | 0.000 | 0.000 | 103 | | bjn | 0.000 | 0.000 | 104 | | bkd | 0.200 | 0.003 | 105 | | bkh | 0.000 | 0.000 | 106 | | bkj | 1.500 | 1.017 | 107 | | bkq | 0.000 | 0.000 | 108 | | blc | 0.571 | 0.274 | 109 | | blf | 0.250 | 0.026 | 110 | | bmh | 0.333 | 0.333 | 111 | | bnn | 0.000 | 0.000 | 112 | | bns | 0.500 | 0.017 | 113 | | boa | 0.714 | 0.235 | 114 | | boj | 0.000 | 0.000 | 115 | | bor | 0.252 | 0.079 | 116 | | bpy | 0.149 | 0.056 | 117 | | bqi | 0.344 | 0.118 | 118 | | bra | 0.283 | 0.123 | 119 | | brg | 0.071 | 0.013 | 120 | | brh | 1.333 | 0.747 | 121 | | brx | 0.231 | 
0.156 | 122 | | bsh | 0.600 | 0.298 | 123 | | bsk | 0.341 | 0.157 | 124 | | bsw | 0.000 | 0.000 | 125 | | bua | 0.243 | 0.088 | 126 | | bug | 0.250 | 0.105 | 127 | | bvx | 0.333 | 0.333 | 128 | | bwa | 0.250 | 0.250 | 129 | | bwp | 0.125 | 0.125 | 130 | | bwr | 0.197 | 0.123 | 131 | | bxd | 0.900 | 0.621 | 132 | | byx | 0.143 | 0.143 | 133 | | cag | 0.218 | 0.089 | 134 | | cak | 0.400 | 0.040 | 135 | | cal | 0.200 | 0.014 | 136 | | caq | 0.857 | 0.390 | 137 | | cbv | 0.333 | 0.057 | 138 | | ccc | 0.190 | 0.060 | 139 | | ccm | 0.400 | 0.064 | 140 | | cdm | 0.500 | 0.336 | 141 | | cgk | 0.500 | 0.441 | 142 | | cha | 0.127 | 0.044 | 143 | | chc | 0.250 | 0.134 | 144 | | chg | 0.200 | 0.200 | 145 | | chk | 0.344 | 0.117 | 146 | | chl | 0.167 | 0.089 | 147 | | chm | 0.143 | 0.009 | 148 | | chp | 0.000 | 0.000 | 149 | | chu | 0.000 | 0.000 | 150 | | chy | 0.444 | 0.201 | 151 | | cia | 0.000 | 0.000 | 152 | | cic | 0.435 | 0.092 | 153 | | cim | 0.000 | 0.000 | 154 | | cjs | 0.085 | 0.008 | 155 | | ckt | 0.367 | 0.028 | 156 | | ckv | 0.250 | 0.097 | 157 | | clk | 0.200 | 0.200 | 158 | | cog | 0.436 | 0.074 | 159 | | cpg | 0.200 | 0.103 | 160 | | cre | 0.231 | 0.080 | 161 | | crg | 0.592 | 0.245 | 162 | | crh | 0.500 | 0.026 | 163 | | cri | 0.113 | 0.034 | 164 | | cro | 0.549 | 0.209 | 165 | | csi | 0.000 | 0.000 | 166 | | csm | 0.203 | 0.119 | 167 | | cta | 0.500 | 0.500 | 168 | | ctg | 0.500 | 0.283 | 169 | | ctp | 0.381 | 0.257 | 170 | | cts | 0.000 | 0.000 | 171 | | cup | 0.158 | 0.046 | 172 | | cuq | 0.333 | 0.333 | 173 | | cwg | 0.000 | 0.000 | 174 | | dak | 0.190 | 0.074 | 175 | | dal | 0.250 | 0.173 | 176 | | ddo | 0.192 | 0.087 | 177 | | des | 0.000 | 0.000 | 178 | | din | 0.000 | 0.000 | 179 | | div | 1.250 | 1.026 | 180 | | dlg | 0.297 | 0.066 | 181 | | dni | 0.000 | 0.000 | 182 | | dta | 0.505 | 0.162 | 183 | | dua | 0.137 | 0.067 | 184 | | duf | 0.282 | 0.148 | 185 | | duk | 0.500 | 0.131 | 186 | | duo | 0.046 | 0.015 | 187 | | dus | 0.077 | 0.077 | 188 | | dyu | 0.000 | 0.000 | 189 | | eme | 0.286 | 0.030 | 190 | | ero | 0.333 | 0.034 | 191 | | eve | 0.276 | 0.021 | 192 | | evn | 0.324 | 0.041 | 193 | | ext | 0.079 | 0.038 | 194 | | fax | 0.500 | 0.026 | 195 | | fay | 0.222 | 0.073 | 196 | | fij | 0.312 | 0.101 | 197 | | fkv | 0.064 | 0.034 | 198 | | fos | 0.200 | 0.003 | 199 | | frm | 0.493 | 0.174 | 200 | | frp | 0.410 | 0.196 | 201 | | fud | 0.000 | 0.000 | 202 | | ful | 0.400 | 0.194 | 203 | | fur | 0.203 | 0.012 | 204 | | gaa | 0.221 | 0.081 | 205 | | gac | 0.083 | 0.003 | 206 | | gag | 0.222 | 0.056 | 207 | | gah | 0.400 | 0.221 | 208 | | gbb | 0.750 | 0.219 | 209 | | gdq | 0.405 | 0.077 | 210 | | geh | 0.500 | 0.254 | 211 | | gil | 0.176 | 0.064 | 212 | | gin | 0.238 | 0.133 | 213 | | giw | 1.000 | 1.000 | 214 | | gld | 0.000 | 0.000 | 215 | | gmh | 0.293 | 0.068 | 216 | | goi | 0.333 | 0.057 | 217 | | gor | 0.000 | 0.000 | 218 | | gqn | 0.111 | 0.058 | 219 | | gri | 0.000 | 0.000 | 220 | | grt | 0.086 | 0.058 | 221 | | gub | 0.111 | 0.011 | 222 | | gug | 0.342 | 0.115 | 223 | | gul | 0.400 | 0.207 | 224 | | gun | 0.571 | 0.240 | 225 | | gup | 0.295 | 0.124 | 226 | | gut | 0.000 | 0.000 | 227 | | gvf | 0.000 | 0.000 | 228 | | gvj | 0.000 | 0.000 | 229 | | gvp | 0.333 | 0.023 | 230 | | gwc | 0.205 | 0.057 | 231 | | gwe | 0.400 | 0.214 | 232 | | gwi | 0.000 | 0.000 | 233 | | haa | 1.333 | 0.836 | 234 | | hac | 0.000 | 0.000 | 235 | | hai | 0.333 | 0.072 | 236 | | hau | 0.390 | 0.089 | 237 | | hdn | 0.000 | 0.000 | 238 | | hil | 0.091 | 0.068 | 239 | | hoi | 0.500 | 0.040 | 240 | 
| hop | 0.500 | 0.124 | 241 | | hot | 0.000 | 0.000 | 242 | | hsb | 0.194 | 0.030 | 243 | | hto | 0.333 | 0.218 | 244 | | hup | 0.000 | 0.000 | 245 | | hux | 0.333 | 0.218 | 246 | | huz | 0.576 | 0.164 | 247 | | hvk | 0.000 | 0.000 | 248 | | iba | 0.217 | 0.107 | 249 | | ibo | 0.194 | 0.092 | 250 | | idb | 0.000 | 0.000 | 251 | | ing | 0.462 | 0.109 | 252 | | ipk | 0.000 | 0.000 | 253 | | irk | 0.500 | 0.119 | 254 | | ivv | 0.200 | 0.034 | 255 | | iws | 0.000 | 0.000 | 256 | | jaa | 0.216 | 0.151 | 257 | | jav | 0.375 | 0.027 | 258 | | jbt | 0.750 | 0.328 | 259 | | jdt | 0.478 | 0.115 | 260 | | jiv | 0.333 | 0.097 | 261 | | jra | 0.278 | 0.160 | 262 | | jur | 0.333 | 0.175 | 263 | | jut | 0.372 | 0.152 | 264 | | kaa | 0.286 | 0.078 | 265 | | kab | 0.222 | 0.100 | 266 | | kac | 0.267 | 0.037 | 267 | | kam | 0.192 | 0.081 | 268 | | kap | 0.145 | 0.034 | 269 | | kaw | 0.000 | 0.000 | 270 | | kay | 0.176 | 0.069 | 271 | | kca | 0.225 | 0.038 | 272 | | kcx | 0.000 | 0.000 | 273 | | kdd | 0.000 | 0.000 | 274 | | kea | 0.244 | 0.008 | 275 | | kgg | 0.619 | 0.413 | 276 | | kgp | 0.380 | 0.143 | 277 | | kht | 0.000 | 0.000 | 278 | | khv | 0.116 | 0.036 | 279 | | kij | 0.000 | 0.000 | 280 | | kim | 0.200 | 0.024 | 281 | | kin | 0.233 | 0.115 | 282 | | kip | 0.206 | 0.009 | 283 | | kjb | 0.250 | 0.048 | 284 | | kjg | 0.000 | 0.000 | 285 | | kjh | 0.219 | 0.027 | 286 | | kjl | 0.167 | 0.017 | 287 | | kjp | 0.727 | 0.638 | 288 | | kjz | 0.000 | 0.000 | 289 | | kky | 0.136 | 0.053 | 290 | | kla | 0.400 | 0.210 | 291 | | kls | 0.375 | 0.092 | 292 | | kmc | 0.615 | 0.435 | 293 | | kmg | 0.333 | 0.034 | 294 | | kmj | 1.000 | 0.103 | 295 | | kmv | 0.250 | 0.041 | 296 | | kne | 0.200 | 0.021 | 297 | | knx | 0.333 | 0.333 | 298 | | koi | 0.059 | 0.010 | 299 | | kos | 0.423 | 0.284 | 300 | | koy | 0.667 | 0.356 | 301 | | kpy | 0.250 | 0.015 | 302 | | kqy | 0.333 | 0.006 | 303 | | krc | 0.600 | 0.055 | 304 | | kre | 0.429 | 0.038 | 305 | | kru | 1.000 | 0.241 | 306 | | ksi | 0.000 | 0.000 | 307 | | kue | 0.250 | 0.088 | 308 | | kum | 0.375 | 0.045 | 309 | | kuu | 0.333 | 0.023 | 310 | | kwa | 1.333 | 0.718 | 311 | | kwk | 0.556 | 0.109 | 312 | | kxm | 0.500 | 0.069 | 313 | | kxo | 0.241 | 0.040 | 314 | | kyq | 0.200 | 0.200 | 315 | | kzg | 0.364 | 0.114 | 316 | | lac | 0.222 | 0.113 | 317 | | laq | 0.667 | 0.023 | 318 | | lbe | 0.091 | 0.003 | 319 | | led | 0.500 | 0.013 | 320 | | lez | 1.000 | 0.368 | 321 | | lha | 0.333 | 0.011 | 322 | | lhu | 1.429 | 0.608 | 323 | | lin | 0.000 | 0.000 | 324 | | lkt | 0.298 | 0.089 | 325 | | lld | 0.292 | 0.106 | 326 | | lmd | 0.750 | 0.026 | 327 | | lmn | 0.333 | 0.333 | 328 | | lnd | 0.122 | 0.045 | 329 | | loe | 0.286 | 0.286 | 330 | | lou | 0.071 | 0.038 | 331 | | lrc | 0.200 | 0.200 | 332 | | lti | 0.400 | 0.207 | 333 | | lug | 0.200 | 0.200 | 334 | | luo | 0.407 | 0.083 | 335 | | lus | 0.103 | 0.043 | 336 | | lut | 0.333 | 0.207 | 337 | | lzz | 0.121 | 0.050 | 338 | | mam | 0.000 | 0.000 | 339 | | maz | 0.750 | 0.267 | 340 | | mcm | 0.000 | 0.000 | 341 | | mco | 0.778 | 0.358 | 342 | | mei | 0.667 | 0.095 | 343 | | meo | 0.273 | 0.077 | 344 | | mep | 0.368 | 0.177 | 345 | | mer | 0.000 | 0.000 | 346 | | met | 0.074 | 0.038 | 347 | | mey | 0.667 | 0.391 | 348 | | mfh | 0.327 | 0.127 | 349 | | mgp | 0.600 | 0.203 | 350 | | mhj | 0.000 | 0.000 | 351 | | mhn | 0.150 | 0.017 | 352 | | mhu | 0.400 | 0.207 | 353 | | mia | 0.167 | 0.003 | 354 | | mih | 0.000 | 0.000 | 355 | | min | 0.000 | 0.000 | 356 | | miq | 0.000 | 0.000 | 357 | | miw | 0.200 | 0.014 | 358 | | mjc | 0.000 
| 0.000 | 359 | | mlm | 0.625 | 0.498 | 360 | | mlp | 0.000 | 0.000 | 361 | | mlv | 1.000 | 0.159 | 362 | | mmc | 0.750 | 0.272 | 363 | | mne | 0.400 | 0.207 | 364 | | mng | 0.143 | 0.002 | 365 | | mnk | 0.420 | 0.097 | 366 | | mns | 0.472 | 0.178 | 367 | | moe | 0.571 | 0.022 | 368 | | moh | 0.257 | 0.177 | 369 | | mop | 0.667 | 0.351 | 370 | | mos | 0.165 | 0.042 | 371 | | mri | 0.438 | 0.071 | 372 | | msn | 0.119 | 0.073 | 373 | | mtd | 0.667 | 0.454 | 374 | | mtq | 0.130 | 0.012 | 375 | | mup | 0.000 | 0.000 | 376 | | mus | 0.500 | 0.128 | 377 | | mwl | 0.333 | 0.047 | 378 | | mxb | 0.212 | 0.080 | 379 | | mxi | 0.386 | 0.174 | 380 | | myp | 0.222 | 0.144 | 381 | | myv | 0.098 | 0.004 | 382 | | mzn | 0.424 | 0.160 | 383 | | mzq | 0.400 | 0.297 | 384 | | naq | 0.000 | 0.000 | 385 | | nau | 0.091 | 0.005 | 386 | | naz | 0.190 | 0.075 | 387 | | ncb | 2.000 | 1.134 | 388 | | nch | 0.191 | 0.068 | 389 | | ncj | 0.333 | 0.107 | 390 | | ngu | 0.168 | 0.052 | 391 | | nhe | 0.196 | 0.124 | 392 | | nhm | 0.158 | 0.073 | 393 | | nht | 0.000 | 0.000 | 394 | | nhv | 0.286 | 0.121 | 395 | | nhw | 0.118 | 0.037 | 396 | | nhy | 0.143 | 0.027 | 397 | | nia | 0.173 | 0.009 | 398 | | niq | 0.000 | 0.000 | 399 | | niu | 0.111 | 0.021 | 400 | | njo | 1.000 | 0.552 | 401 | | njz | 0.222 | 0.127 | 402 | | nlv | 0.178 | 0.066 | 403 | | nmc | 0.339 | 0.219 | 404 | | nnp | 0.667 | 0.368 | 405 | | nod | 0.000 | 0.000 | 406 | | noe | 0.000 | 0.000 | 407 | | nor | 0.435 | 0.121 | 408 | | ntj | 0.400 | 0.229 | 409 | | nus | 0.000 | 0.000 | 410 | | nuz | 0.214 | 0.070 | 411 | | nxn | 0.500 | 0.208 | 412 | | nxq | 0.957 | 0.403 | 413 | | oca | 0.250 | 0.250 | 414 | | odt | 0.333 | 0.006 | 415 | | ofu | 0.500 | 0.013 | 416 | | oge | 0.041 | 0.015 | 417 | | okn | 0.562 | 0.275 | 418 | | ole | 0.333 | 0.103 | 419 | | omc | 0.111 | 0.002 | 420 | | omx | 0.684 | 0.324 | 421 | | onu | 0.000 | 0.000 | 422 | | ood | 0.333 | 0.161 | 423 | | oon | 0.000 | 0.000 | 424 | | orv | 0.333 | 0.170 | 425 | | oss | 0.290 | 0.140 | 426 | | ote | 0.299 | 0.071 | 427 | | otq | 0.750 | 0.272 | 428 | | ott | 0.750 | 0.272 | 429 | | ovd | 0.474 | 0.144 | 430 | | owl | 0.429 | 0.028 | 431 | | pag | 0.103 | 0.040 | 432 | | pal | 0.222 | 0.133 | 433 | | pam | 0.105 | 0.059 | 434 | | pap | 0.143 | 0.012 | 435 | | pav | 0.158 | 0.144 | 436 | | pcd | 0.667 | 0.362 | 437 | | pdo | 0.400 | 0.297 | 438 | | pdt | 0.387 | 0.119 | 439 | | peh | 0.750 | 0.302 | 440 | | pei | 0.000 | 0.000 | 441 | | pis | 0.000 | 0.000 | 442 | | pli | 0.478 | 0.113 | 443 | | plk | 0.222 | 0.027 | 444 | | ply | 0.667 | 0.345 | 445 | | pnr | 0.250 | 0.147 | 446 | | pnw | 0.286 | 0.107 | 447 | | poi | 0.318 | 0.220 | 448 | | pos | 0.000 | 0.000 | 449 | | pre | 0.194 | 0.038 | 450 | | prg | 0.198 | 0.021 | 451 | | prk | 0.833 | 0.441 | 452 | | pro | 0.167 | 0.021 | 453 | | psi | 0.429 | 0.190 | 454 | | pua | 0.204 | 0.144 | 455 | | puw | 0.091 | 0.091 | 456 | | pwa | 0.000 | 0.000 | 457 | | pwn | 0.444 | 0.207 | 458 | | rad | 0.320 | 0.281 | 459 | | rah | 0.146 | 0.044 | 460 | | rak | 0.000 | 0.000 | 461 | | ram | 0.324 | 0.091 | 462 | | rap | 0.167 | 0.167 | 463 | | rel | 0.500 | 0.181 | 464 | | rgn | 0.245 | 0.102 | 465 | | rhg | 0.625 | 0.442 | 466 | | rif | 0.152 | 0.066 | 467 | | rjs | 0.133 | 0.007 | 468 | | rki | 0.206 | 0.090 | 469 | | rkt | 0.111 | 0.111 | 470 | | rme | 0.273 | 0.090 | 471 | | rmf | 0.128 | 0.061 | 472 | | rmo | 0.333 | 0.129 | 473 | | rmq | 0.000 | 0.000 | 474 | | rmt | 0.393 | 0.094 | 475 | | roh | 0.396 | 0.129 | 476 | | rom | 0.200 | 0.022 | 477 
| | rpn | 0.400 | 0.028 | 478 | | rue | 0.137 | 0.005 | 479 | | ruo | 0.200 | 0.078 | 480 | | ruq | 0.170 | 0.009 | 481 | | sac | 0.000 | 0.000 | 482 | | sag | 0.000 | 0.000 | 483 | | scb | 0.154 | 0.011 | 484 | | scl | 0.605 | 0.254 | 485 | | sdh | 0.174 | 0.011 | 486 | | sdn | 0.000 | 0.000 | 487 | | sea | 0.077 | 0.008 | 488 | | see | 0.417 | 0.170 | 489 | | sgs | 0.400 | 0.081 | 490 | | shs | 0.833 | 0.303 | 491 | | shy | 0.000 | 0.000 | 492 | | sjd | 0.250 | 0.004 | 493 | | sje | 0.376 | 0.106 | 494 | | sjt | 0.000 | 0.000 | 495 | | skc | 0.167 | 0.011 | 496 | | skr | 0.333 | 0.023 | 497 | | sma | 0.250 | 0.129 | 498 | | smk | 0.000 | 0.000 | 499 | | smn | 0.250 | 0.056 | 500 | | smo | 0.125 | 0.016 | 501 | | sna | 0.250 | 0.022 | 502 | | som | 0.897 | 0.574 | 503 | | spp | 1.000 | 0.338 | 504 | | squ | 0.143 | 0.039 | 505 | | ssf | 0.400 | 0.060 | 506 | | ssw | 0.143 | 0.143 | 507 | | stp | 0.625 | 0.522 | 508 | | str | 0.423 | 0.143 | 509 | | stw | 1.000 | 0.522 | 510 | | sun | 0.800 | 0.441 | 511 | | suy | 0.419 | 0.092 | 512 | | sva | 0.221 | 0.062 | 513 | | swb | 0.138 | 0.002 | 514 | | swi | 0.333 | 0.333 | 515 | | szl | 0.246 | 0.076 | 516 | | szw | 0.143 | 0.143 | 517 | | taa | 0.615 | 0.191 | 518 | | tae | 0.333 | 0.006 | 519 | | tah | 0.000 | 0.000 | 520 | | taj | 0.000 | 0.000 | 521 | | tat | 0.133 | 0.014 | 522 | | tay | 0.000 | 0.000 | 523 | | taz | 1.000 | 0.034 | 524 | | tbk | 0.000 | 0.000 | 525 | | tet | 0.250 | 0.041 | 526 | | tfr | 0.000 | 0.000 | 527 | | tji | 0.571 | 0.170 | 528 | | tjs | 0.455 | 0.303 | 529 | | tkn | 0.316 | 0.213 | 530 | | tkw | 0.000 | 0.000 | 531 | | tli | 0.600 | 0.228 | 532 | | top | 0.474 | 0.137 | 533 | | tpi | 0.040 | 0.005 | 534 | | tqw | 0.429 | 0.177 | 535 | | trc | 0.333 | 0.333 | 536 | | trp | 0.320 | 0.127 | 537 | | trq | 1.000 | 0.505 | 538 | | tru | 0.511 | 0.352 | 539 | | trw | 1.000 | 0.069 | 540 | | tsd | 0.360 | 0.171 | 541 | | tsi | 0.235 | 0.178 | 542 | | tts | 0.000 | 0.000 | 543 | | tvl | 0.000 | 0.000 | 544 | | txb | 0.125 | 0.004 | 545 | | txu | 0.272 | 0.109 | 546 | | tyj | 0.106 | 0.003 | 547 | | uar | 0.250 | 0.250 | 548 | | ubl | 0.200 | 0.003 | 549 | | uby | 0.000 | 0.000 | 550 | | ude | 0.114 | 0.037 | 551 | | udm | 0.000 | 0.000 | 552 | | ulc | 0.197 | 0.040 | 553 | | umu | 0.625 | 0.203 | 554 | | unr | 0.400 | 0.034 | 555 | | urb | 0.167 | 0.109 | 556 | | urt | 0.250 | 0.250 | 557 | | uur | 0.000 | 0.000 | 558 | | uzb | 0.317 | 0.095 | 559 | | vam | 0.250 | 0.250 | 560 | | vav | 0.357 | 0.038 | 561 | | vls | 0.177 | 0.071 | 562 | | vma | 0.258 | 0.111 | 563 | | vmb | 0.000 | 0.000 | 564 | | vot | 1.000 | 0.017 | 565 | | vro | 0.111 | 0.040 | 566 | | wam | 0.357 | 0.121 | 567 | | war | 0.067 | 0.050 | 568 | | wbl | 1.500 | 0.828 | 569 | | wbp | 0.314 | 0.122 | 570 | | wer | 0.067 | 0.005 | 571 | | wgy | 0.400 | 0.229 | 572 | | whg | 0.000 | 0.000 | 573 | | win | 0.500 | 0.500 | 574 | | wln | 0.555 | 0.259 | 575 | | woe | 0.375 | 0.284 | 576 | | wrg | 0.068 | 0.049 | 577 | | wth | 0.091 | 0.091 | 578 | | wwo | 0.000 | 0.000 | 579 | | wym | 0.346 | 0.097 | 580 | | xaa | 0.500 | 0.293 | 581 | | xal | 0.293 | 0.145 | 582 | | xav | 0.684 | 0.247 | 583 | | xbc | 0.535 | 0.288 | 584 | | xcl | 0.000 | 0.000 | 585 | | xdk | 0.294 | 0.115 | 586 | | xer | 0.417 | 0.032 | 587 | | xfa | 0.182 | 0.053 | 588 | | xib | 0.000 | 0.000 | 589 | | xkz | 0.632 | 0.195 | 590 | | xmf | 0.060 | 0.004 | 591 | | xmz | 0.400 | 0.283 | 592 | | xng | 0.750 | 0.293 | 593 | | xnr | 1.500 | 1.052 | 594 | | xok | 0.000 | 0.000 | 595 | | xpo | 
0.224 | 0.198 | 596 | | xra | 0.500 | 0.032 | 597 | | xsm | 0.167 | 0.006 | 598 | | xsr | 0.786 | 0.387 | 599 | | xsy | 0.091 | 0.091 | 600 | | xtm | 0.273 | 0.174 | 601 | | xto | 0.667 | 0.667 | 602 | | yag | 0.133 | 0.067 | 603 | | yai | 0.600 | 0.100 | 604 | | yap | 0.667 | 0.125 | 605 | | ybh | 0.333 | 0.333 | 606 | | yej | 0.385 | 0.385 | 607 | | ykg | 0.118 | 0.057 | 608 | | yly | 0.714 | 0.200 | 609 | | yox | 0.667 | 0.380 | 610 | | yrl | 0.429 | 0.429 | 611 | | ysn | 1.000 | 1.000 | 612 | | yuj | 0.250 | 0.250 | 613 | | yur | 0.667 | 0.060 | 614 | | yuy | 0.281 | 0.072 | 615 | | zag | 0.250 | 0.009 | 616 | | zai | 0.175 | 0.063 | 617 | | zbt | 0.286 | 0.286 | 618 | | zku | 0.444 | 0.280 | 619 | | zoh | 0.333 | 0.011 | 620 | | zor | 0.000 | 0.000 | 621 | | zos | 0.300 | 0.040 | 622 | | zpq | 0.444 | 0.188 | 623 | | ztm | 0.333 | 0.029 | 624 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy 2 | numpy 3 | panphon 4 | torch 5 | editdistance 6 | requests 7 | tqdm 8 | pyyaml 9 | phonepiece 10 | epitran 11 | unidecode 12 | mecab-python3 13 | cmudict -------------------------------------------------------------------------------- /sample.txt: -------------------------------------------------------------------------------- 1 | hello 2 | world 3 | transphone 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup,find_packages 2 | 3 | requirements=[r.strip() for r in open("requirements.txt").readlines()] 4 | 5 | setup( 6 | name='transphone', 7 | version='1.5.3', 8 | description='a multilingual g2p/p2g model', 9 | author='Xinjian Li', 10 | author_email='xinjianl@cs.cmu.edu', 11 | url="https://github.com/xinjli/transphone", 12 | packages=find_packages(), 13 | package_data={'': ['*.csv', '*.tsv', '*.yml']}, 14 | install_requires=requirements, 15 | ) 16 | -------------------------------------------------------------------------------- /test/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from transphone.tokenizer import read_tokenizer 3 | 4 | 5 | class TestTokenizer(unittest.TestCase): 6 | 7 | def test_eng_tokenizer(self): 8 | eng = read_tokenizer('eng') 9 | 10 | self.assertEqual(eng.tokenize_words('hello world'), ['hello', 'world']) 11 | self.assertEqual(eng.tokenize_words('hello world!! 
'), ['hello', 'world']) 12 | self.assertEqual(eng.tokenize('hello world'), ['h', 'ʌ', 'l', 'o', 'w', 'w', 'ɹ̩', 'l', 'd']) 13 | 14 | def test_jpn_tokenizer(self): 15 | jpn = read_tokenizer('jpn') 16 | self.assertEqual(jpn.tokenize_words('こんにちは世界'), ['こんにちは', '世界']) 17 | self.assertEqual(jpn.tokenize('こんにちは世界'), ['k', 'o', 'N', 'n', 'i', 'ch', 'i', 'w', 'a', 's', 'e', 'k', 'a', 'i']) 18 | self.assertEqual(jpn.tokenize('2023年'), ['n', 'i', 's', 'e', 'N', 'n', 'i', 'j', 'u', 'u', 's', 'a', 'N', 'n', 'e', 'N']) 19 | self.assertEqual(jpn.tokenize('UTFとABC。'), ['y', 'uː', 't', 'iː', 'e', 'f', 'u', 't', 'o', 'eː', 'b', 'iː', 'sh', 'iː']) 20 | 21 | def test_spa_tokenizer(self): 22 | 23 | spa = read_tokenizer('spa') 24 | self.assertEqual(spa.tokenize('hola hola'), ['o', 'l', 'a', 'o', 'l', 'a']) 25 | self.assertEqual(spa.tokenize('español'), ['e', 's', 'p', 'a', 'ɲ', 'o', 'l']) 26 | 27 | spa = read_tokenizer('spa', use_lexicon=False) 28 | self.assertEqual(spa.tokenize('hola hola'), ['o', 'l', 'a', 'o', 'l', 'a']) 29 | self.assertEqual(spa.tokenize('español'), ['e', 's', 'p', 'a', 'ɲ', 'o', 'l']) 30 | 31 | def test_fra_tokenizer(self): 32 | 33 | fra = read_tokenizer('fra') 34 | self.assertEqual(fra.tokenize('français'), ['f', 'ʁ', 'ɑ̃', 's', 'ɛ']) 35 | 36 | fra = read_tokenizer('fra', use_lexicon=False) 37 | self.assertEqual(fra.tokenize('français'), ['f', 'ʁ', 'ɑ̃', 's', 'ɛ']) 38 | 39 | def test_deu_tokenizer(self): 40 | 41 | deu = read_tokenizer('deu') 42 | self.assertEqual(deu.tokenize('Deutsche'), ['d', 'o', 'i', 't͡ʃ', 'ə']) 43 | 44 | deu = read_tokenizer('deu', use_lexicon=False) 45 | self.assertEqual(deu.tokenize('Deutsche'),['d', 'o', 'i', 't͡ʃ', 'ə']) 46 | 47 | def test_ita_tokenizer(self): 48 | ita = read_tokenizer('ita') 49 | self.assertEqual(ita.tokenize('Italia'), ['i', 't', 'a', 'l', 'j', 'a']) 50 | 51 | # g2p is slightly different here 52 | ita = read_tokenizer('ita', use_lexicon=False) 53 | self.assertEqual(ita.tokenize('Italia'), ['i', 't', 'a', 'l', 'i', 'a']) 54 | 55 | def test_tur_tokenizer(self): 56 | 57 | tur = read_tokenizer('tur', use_lexicon=False) 58 | self.assertEqual(tur.tokenize('Türkçe'), ['t', 'y', 'ɾ', 'k', 't͡ʃ', 'e']) 59 | 60 | tur = read_tokenizer('tur') 61 | self.assertEqual(tur.tokenize('Türkçe'), ['t', 'y', 'ɾ', 'k', 't͡ʃ', 'e']) 62 | -------------------------------------------------------------------------------- /transphone/__init__.py: -------------------------------------------------------------------------------- 1 | from transphone.tokenizer import read_tokenizer -------------------------------------------------------------------------------- /transphone/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/bin/__init__.py -------------------------------------------------------------------------------- /transphone/bin/download_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import tarfile 3 | from urllib.request import urlopen 4 | import io 5 | import argparse 6 | import os 7 | from transphone.config import TransphoneConfig 8 | from transphone.model.utils import resolve_model_name 9 | 10 | 11 | def download_model(model_name=None, alt_model_path=None): 12 | 13 | if model_name is None: 14 | model_name = 'latest' 15 | if alt_model_path: 16 | model_dir = alt_model_path 17 | else: 18 | model_dir = TransphoneConfig.data_path / 'model' 19 | 
model_dir.mkdir(parents=True, exist_ok=True) 20 | 21 | model_name = resolve_model_name(model_name) 22 | 23 | if not (model_dir / model_name).exists(): 24 | 25 | try: 26 | url = 'https://github.com/xinjli/transphone/releases/download/v1.0/' + model_name + '.tar.gz' 27 | print("downloading model ", model_name) 28 | print("from: ", url) 29 | print("to: ", str(model_dir)) 30 | print("please wait...") 31 | resp = urlopen(url) 32 | compressed_files = io.BytesIO(resp.read()) 33 | files = tarfile.open(fileobj=compressed_files) 34 | files.extractall(str(model_dir)) 35 | 36 | except Exception as e: 37 | print("Error: could not download the model", e) 38 | # clean up the partially extracted model directory (if any) so a retry starts fresh 39 | if (model_dir / model_name).exists(): 40 | (model_dir / model_name).rmdir() 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | parser = argparse.ArgumentParser('a utility to download pretrained models') 46 | parser.add_argument('-m', '--model', default='latest', help='specify which model to download. A list of downloadable models is available on Github') 47 | 48 | args = parser.parse_args() 49 | 50 | download_model(args.model) -------------------------------------------------------------------------------- /transphone/bin/eval_epitran.py: -------------------------------------------------------------------------------- 1 | from transphone.tokenizer import raw_epitran_dict 2 | from transphone.lang.epitran_tokenizer import read_raw_epitran_tokenizer 3 | from phonepiece.lexicon import read_lexicon 4 | from phonepiece.distance import phonological_distance 5 | import editdistance 6 | from transphone.model.checkpoint_utils import * 7 | import argparse 8 | 9 | def eval_epitran(langs=None, exclude_langs=None): 10 | 11 | exp_dir = TransphoneConfig.data_path / 'decode' / 'epitran' 12 | exp_dir.mkdir(exist_ok=True, parents=True) 13 | 14 | log_w = open(exp_dir / f'result.md', 'w') 15 | 16 | log_w.write('| language | phoneme error rate | phonological distance |\n') 17 | log_w.write('|----------|--------------------|-----------------------|\n') 18 | 19 | if langs is not None and len(langs) != 0: 20 | target_langs = langs 21 | elif exclude_langs is not None and len(exclude_langs) != 0: 22 | target_langs = [] 23 | for lang in list(raw_epitran_dict.keys()): 24 | if lang not in exclude_langs: 25 | target_langs.append(lang) 26 | else: 27 | target_langs = list(raw_epitran_dict.keys()) 28 | 29 | tot_cer = 0 30 | tot_fer = 0 31 | tot_csum = 0 32 | 33 | for lang in target_langs: 34 | cer = 0 35 | fer = 0 36 | csum = 0 37 | 38 | print("processing ", lang) 39 | 40 | epitran_id = raw_epitran_dict[lang] 41 | 42 | try: 43 | lexicon = read_lexicon(lang) 44 | except Exception: 45 | print("could not read lexicon of ", lang) 46 | continue 47 | 48 | if len(lexicon) < 50: 49 | print("skipping ", lang) 50 | continue 51 | 52 | w = open(exp_dir / f'{lang}.txt', 'w') 53 | 54 | # collect all lexicon entries; only the last 25 are evaluated below 55 | word2phoneme_lst = list(lexicon.word2phoneme.items()) 56 | 57 | try: 58 | epitran_tokenizer = read_raw_epitran_tokenizer(epitran_id, use_lexicon=False) 59 | except Exception: 60 | print('failed to read epitran', lang) 61 | continue 62 | 63 | # evaluate on the last 25 entries, matching the supervised test split 64 | for grapheme_str, phonemes in word2phoneme_lst[-25:]: 65 | hyp_phonemes = epitran_tokenizer.tokenize(grapheme_str) 66 | w.write(f'lang: {lang}\n') 67 | w.write(f'inp : {grapheme_str}\n') 68 | w.write(f'ref : {" ".join(phonemes)}\n') 69 | w.write(f'hyp : {" ".join(hyp_phonemes)}\n') 70 | 71 | cur_cer = editdistance.distance(phonemes, hyp_phonemes) 72 | cur_fer = phonological_distance(phonemes, hyp_phonemes) 73 | cur_sum = len(phonemes) 74 | w.write(f'cer: {cur_cer/cur_sum}\n')
75 | w.write(f'fer: {cur_fer/cur_sum}\n\n') 76 | 77 | cer += cur_cer 78 | fer += cur_fer 79 | csum += cur_sum 80 | 81 | w.close() 82 | 83 | tot_cer += cer 84 | tot_fer += fer 85 | tot_csum += csum 86 | 87 | log_w.write(f'| {lang} | {cer / csum:.3f} | {fer / csum:.3f} |\n') 88 | 89 | print('lang :', lang) 90 | print("val cer: ", cer / csum) 91 | print("val fer: ", fer / csum) 92 | 93 | print('all') 94 | print("val cer: ", tot_cer / tot_csum) 95 | print("val fer: ", tot_fer / tot_csum) 96 | log_w.close() 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser(description='eval epitran g2p') 101 | parser.add_argument('--lang', type=str, help='language') 102 | parser.add_argument('--exclude_lang', type=str, help='excluded language') 103 | 104 | args = parser.parse_args() 105 | 106 | lang = set() 107 | exclude_lang = set() 108 | 109 | if args.lang is not None: 110 | lang.update(args.lang.split(',')) 111 | if args.exclude_lang is not None: 112 | exclude_lang.update(args.exclude_lang.split(',')) 113 | 114 | eval_epitran(lang, exclude_lang) 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /transphone/bin/eval_g2p.py: -------------------------------------------------------------------------------- 1 | from transphone.model.dataset import read_test_dataset 2 | from phonepiece.distance import phonological_distance 3 | import editdistance 4 | import tqdm 5 | from transphone.model.checkpoint_utils import * 6 | import argparse 7 | from transphone.g2p import read_g2p 8 | 9 | 10 | def get_decode_dir(exp_name, checkpoint, ensemble): 11 | 12 | decode_dir = TransphoneConfig.data_path / 'decode' / 'dev' 13 | 14 | perf = str(checkpoint).split('/')[-1].split('.')[1]  # normalize Path to str before splitting 15 | 16 | # create test directory 17 | exp_dir = decode_dir / f"{exp_name}-{perf}-{ensemble}" 18 | return exp_dir 19 | 20 | 21 | def eval_test(exp, checkpoint, device, ensemble): 22 | 23 | exp_dir = get_decode_dir(exp, checkpoint, ensemble) 24 | exp_dir.mkdir(exist_ok=True, parents=True) 25 | 26 | test_grapheme_lst, test_phoneme_lst, test_lang_lst = read_test_dataset(exp) 27 | print("test size: ", len(test_lang_lst)) 28 | 29 | model = read_g2p(exp, device, checkpoint=checkpoint) 30 | 31 | tot_err = 0 32 | tot_dst = 0 33 | tot_csum = 0 34 | 35 | exp_dir = get_decode_dir(exp, checkpoint, ensemble) 36 | exp_dir.mkdir(exist_ok=True, parents=True) 37 | log_w = open(exp_dir / 'result.md', 'w') 38 | log_w.write('| language | phoneme error rate | phonological distance |\n') 39 | log_w.write('|----------|--------------------|-----------------------|\n') 40 | 41 | for grapheme_lst, phonemes_lst, lang_id in tqdm.tqdm(zip(test_grapheme_lst, test_phoneme_lst, test_lang_lst)): 42 | 43 | err = 0 44 | dst = 0 45 | csum = 0 46 | 47 | w = open(exp_dir / f'{lang_id}.txt', 'w') 48 | 49 | for grapheme, phonemes in zip(grapheme_lst, phonemes_lst): 50 | 51 | predicted = model.inference_word_batch(grapheme, lang_id=lang_id, num_lang=1, force_approximate=False) 52 | 53 | w.write(f'lang: {lang_id}\n') 54 | w.write(f'inp : {grapheme}\n') 55 | w.write(f'ref : {" ".join(phonemes)}\n') 56 | w.write(f'hyp : {" ".join(predicted)}\n') 57 | cur_err = editdistance.distance(phonemes, predicted) 58 | cur_dst = phonological_distance(phonemes, predicted) 59 | cur_sum = len(phonemes) 60 | w.write(f'err: {cur_err / cur_sum}\n') 61 | w.write(f'dst: {cur_dst / cur_sum}\n\n') 62 | err += cur_err 63 | dst += cur_dst 64 | csum += cur_sum 65 | 66 | w.close() 67 | 68 | tot_err += err 69 | tot_dst += dst 70 | tot_csum +=
csum 71 | 72 | log_w.write(f'| {lang_id} | {err / csum:.3f} | {dst / csum:.3f} |\n') 73 | 74 | print('lang :', lang_id) 75 | print("val cer: ", err / csum) 76 | print("val dst: ", dst / csum) 77 | 78 | print('all') 79 | print("val err: ", tot_err / tot_csum) 80 | print("val dst: ", tot_dst / tot_csum) 81 | log_w.close() 82 | 83 | 84 | if __name__ == '__main__': 85 | 86 | parser = argparse.ArgumentParser(description='eval g2p') 87 | parser.add_argument('--exp', type=str, help='exp') 88 | parser.add_argument('--checkpoint', type=str, help='checkpoint') 89 | parser.add_argument('--lang', type=str) 90 | parser.add_argument('--device', type=str, default='cuda') 91 | parser.add_argument('--ensemble', type=int, default=1) 92 | 93 | args = parser.parse_args() 94 | 95 | checkpoint = args.checkpoint 96 | exp = args.exp 97 | device = args.device 98 | ensemble = args.ensemble 99 | 100 | if exp is None: 101 | exp = Path(checkpoint).parent.stem 102 | 103 | if checkpoint is None: 104 | model_path = TransphoneConfig.data_path / 'model' / exp 105 | if (model_path / "model.pt").exists(): 106 | checkpoint = model_path / "model.pt" 107 | else: 108 | target_model = find_topk_models(model_path)[0] 109 | print("using model ", target_model) 110 | checkpoint = target_model 111 | 112 | eval_test(exp, checkpoint, device, ensemble) 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /transphone/bin/eval_zsl_g2p.py: -------------------------------------------------------------------------------- 1 | from transphone.model.dataset import read_zsl_dataset 2 | from transphone.model.utils import read_model_config 3 | from transphone.g2p import read_g2p 4 | import editdistance 5 | import tqdm 6 | from transphone.model.checkpoint_utils import * 7 | from phonepiece.distance import phonological_distance 8 | import argparse 9 | 10 | 11 | def get_decode_dir(exp_name, checkpoint, ensemble): 12 | 13 | decode_dir = TransphoneConfig.data_path / 'decode' / 'zsl' 14 | 15 | perf = str(checkpoint).split('/')[-1].split('.')[1] 16 | 17 | # create test directory 18 | exp_dir = decode_dir / f"{exp_name}-{perf}-{ensemble}" 19 | return exp_dir 20 | 21 | 22 | def eval_zsl_test(exp, checkpoint, ensemble, device): 23 | 24 | exp_dir = get_decode_dir(exp, checkpoint, ensemble) 25 | exp_dir.mkdir(exist_ok=True, parents=True) 26 | 27 | test_grapheme_lst, test_phoneme_lst, test_lang_lst = read_zsl_dataset(exp) 28 | print("test size: ", len(test_lang_lst)) 29 | 30 | model = read_g2p(exp, device, checkpoint=checkpoint) 31 | 32 | tot_err = 0 33 | tot_dst = 0 34 | tot_csum = 0 35 | 36 | exp_dir = get_decode_dir(exp, checkpoint, ensemble) 37 | exp_dir.mkdir(exist_ok=True, parents=True) 38 | log_w = open(exp_dir / 'result.md', 'w') 39 | log_w.write('| language | phoneme error rate | phonological distance |\n') 40 | log_w.write('|----------|--------------------|-----------------------|\n') 41 | 42 | for grapheme_lst, phonemes_lst, lang_id in tqdm.tqdm(zip(test_grapheme_lst, test_phoneme_lst, test_lang_lst)): 43 | 44 | err = 0 45 | dst = 0 46 | csum = 0 47 | 48 | w = open(exp_dir / f'{lang_id}.txt', 'w') 49 | 50 | for grapheme, phonemes in zip(grapheme_lst, phonemes_lst): 51 | 52 | predicted = model.inference_word_batch(grapheme, lang_id=lang_id, num_lang=ensemble, force_approximate=True) 53 | 54 | w.write(f'lang: {lang_id}\n') 55 | w.write(f'inp : {grapheme}\n') 56 | w.write(f'ref : {" ".join(phonemes)}\n') 57 | w.write(f'hyp : {" ".join(predicted)}\n') 58 | cur_err = editdistance.distance(phonemes, 
predicted) 59 | cur_dst = phonological_distance(phonemes, predicted) 60 | cur_sum = len(phonemes) 61 | w.write(f'err: {cur_err / cur_sum}\n') 62 | w.write(f'dst: {cur_dst / cur_sum}\n\n') 63 | err += cur_err 64 | dst += cur_dst 65 | csum += cur_sum 66 | 67 | w.close() 68 | 69 | tot_err += err 70 | tot_dst += dst 71 | tot_csum += csum 72 | 73 | if csum > 0: 74 | log_w.write(f'| {lang_id} | {err / csum:.3f} | {dst / csum:.3f} |\n') 75 | print('lang :', lang_id) 76 | print("val cer: ", err / csum) 77 | print("val dst: ", dst / csum) 78 | 79 | print('all') 80 | print("val err: ", tot_err / tot_csum) 81 | print("val dst: ", tot_dst / tot_csum) 82 | log_w.close() 83 | 84 | 85 | if __name__ == '__main__': 86 | 87 | parser = argparse.ArgumentParser(description='eval zero-shot learning g2p') 88 | parser.add_argument('--exp', type=str, help='exp') 89 | parser.add_argument('--checkpoint', type=str, help='checkpoint') 90 | parser.add_argument('--device', type=str, default='cuda') 91 | parser.add_argument('--ensemble', type=int, default=10) 92 | 93 | args = parser.parse_args() 94 | 95 | checkpoint = args.checkpoint 96 | exp = args.exp 97 | device = args.device 98 | ensemble = args.ensemble 99 | 100 | 101 | if exp is None: 102 | exp = Path(checkpoint).parent.stem 103 | 104 | if checkpoint is None: 105 | model_path = TransphoneConfig.data_path / 'model' / exp 106 | if (model_path / "model.pt").exists(): 107 | checkpoint = model_path / "model.pt" 108 | else: 109 | target_model = find_topk_models(model_path)[0] 110 | print("using model ", target_model) 111 | checkpoint = target_model 112 | 113 | eval_zsl_test(exp, checkpoint, ensemble, device) 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /transphone/bin/g2p.py: -------------------------------------------------------------------------------- 1 | from transphone.bin.download_model import download_model 2 | from transphone.model.utils import resolve_model_name 3 | from transphone.tokenizer import read_tokenizer 4 | from pathlib import Path 5 | from transphone.g2p import read_g2p 6 | import argparse 7 | import tqdm 8 | 9 | def run_g2p(model_name, lang, input, output, checkpoint=None, file_format='text', combine=False, ensemble=10, force_approximate=False): 10 | 11 | # download specified model automatically if no model exists 12 | download_model(model_name) 13 | 14 | # create model 15 | model = read_g2p(model_name, checkpoint=checkpoint) 16 | 17 | # output file descriptor 18 | output_fd = None 19 | if output != 'stdout': 20 | output_fd = open(output, 'w', encoding='utf-8') 21 | 22 | # input file/path 23 | input_path = Path(input) 24 | 25 | for line in tqdm.tqdm(open(input_path, 'r').readlines(), disable=output=='stdout'): 26 | fields = line.strip().split() 27 | 28 | utt_id = None 29 | 30 | if file_format == 'text': 31 | text = ' '.join(fields) 32 | else: 33 | text = ' '.join(fields[1:]) 34 | utt_id = fields[0] 35 | 36 | phonemes = model.inference_word_batch(text, lang_id=lang, num_lang=ensemble, force_approximate=force_approximate) 37 | 38 | line_output = ' '.join(phonemes) 39 | 40 | if combine: 41 | line_output = text + '\t' + line_output 42 | 43 | if utt_id is not None: 44 | line_output = utt_id + ' ' + line_output 45 | 46 | if output_fd: 47 | output_fd.write(line_output + '\n') 48 | else: 49 | print(line_output) 50 | 51 | if output_fd: 52 | output_fd.close() 53 | 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | parser = argparse.ArgumentParser('running transphone g2p model') 59 | 
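# example invocation (a sketch; the flags are defined just below, and sample.txt is the
# text file shipped in the repository root):
#   python -m transphone.bin.g2p -i sample.txt -l eng -o output.txt
# reads one utterance per line and writes the predicted phoneme sequence per line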
parser.add_argument('-m', '--model', type=str, default='latest', 60 | help='specify which model to use. default is to use the latest local model') 61 | parser.add_argument('-l', '--lang', type=str, default='eng', 62 | help='specify which language to convert into phonemes. default is eng') 63 | parser.add_argument('-i', '--input', type=str, required=True, help='specify your input text file') 64 | parser.add_argument('-o', '--output', type=str, default='stdout', 65 | help='specify output file. the default will be stdout') 66 | parser.add_argument('-f', '--format', type=str, default='text', help='kaldi or text') 67 | parser.add_argument('-c', '--combine', action='store_true', 68 | help='write outputs by including both grapheme inputs and phonemes in the same line, delimited by a tab') 69 | 70 | args = parser.parse_args() 71 | 72 | # resolve model's name 73 | model_name = resolve_model_name(args.model) 74 | 75 | # format 76 | file_format = args.format 77 | 78 | if args.combine: 79 | assert file_format == 'text' 80 | 81 | run_g2p(model_name, args.lang, args.input, args.output, file_format=file_format, combine=args.combine) -------------------------------------------------------------------------------- /transphone/bin/tokenize.py: -------------------------------------------------------------------------------- 1 | from transphone.bin.download_model import download_model 2 | from transphone.model.utils import resolve_model_name 3 | from transphone.tokenizer import read_tokenizer 4 | from pathlib import Path 5 | import argparse 6 | import tqdm 7 | 8 | def tokenize(model_name, lang, input, output, file_format='text', combine=False): 9 | 10 | # download specified model automatically if no model exists 11 | download_model(model_name) 12 | 13 | # create model 14 | tokenizer = read_tokenizer(lang, g2p_model=model_name) 15 | 16 | # output file descriptor 17 | output_fd = None 18 | if output != 'stdout': 19 | output_fd = open(output, 'w', encoding='utf-8') 20 | 21 | # input file/path 22 | input_path = Path(input) 23 | 24 | for line in tqdm.tqdm(open(input_path, 'r').readlines(), disable=output=='stdout'): 25 | fields = line.strip().split() 26 | 27 | utt_id = None 28 | 29 | if file_format == 'text': 30 | text = ' '.join(fields) 31 | else: 32 | text = ' '.join(fields[1:]) 33 | utt_id = fields[0] 34 | 35 | phonemes = tokenizer.tokenize(text) 36 | line_output = ' '.join(phonemes) 37 | 38 | if combine: 39 | line_output = text + '\t' + line_output 40 | 41 | if utt_id is not None: 42 | line_output = utt_id + ' ' + line_output 43 | 44 | if output_fd: 45 | output_fd.write(line_output + '\n') 46 | else: 47 | print(line_output) 48 | 49 | if output_fd: 50 | output_fd.close() 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | 56 | parser = argparse.ArgumentParser('running transphone tokenizer') 57 | parser.add_argument('-m', '--model', type=str, default='latest', 58 | help='specify which model to use. default is to use the latest local model') 59 | parser.add_argument('-l', '--lang', type=str, default='eng', 60 | help='specify which language tokenizer to use. default is eng') 61 | parser.add_argument('-i', '--input', type=str, required=True, help='specify your input text file') 62 | parser.add_argument('-o', '--output', type=str, default='stdout', 63 | help='specify output file. 
the default will be stdout') 64 | parser.add_argument('-f', '--format', type=str, default='text', help='kaldi or text') 65 | parser.add_argument('-c', '--combine', action='store_true', 66 | help='write outputs by including both grapheme inputs and phonemes in the same line, delimited by a tab') 67 | 68 | args = parser.parse_args() 69 | 70 | # resolve model's name 71 | model_name = resolve_model_name(args.model) 72 | 73 | # format 74 | file_format = args.format 75 | 76 | if args.combine: 77 | assert file_format == 'text' 78 | 79 | tokenize(model_name, args.lang, args.input, args.output, file_format, args.combine) -------------------------------------------------------------------------------- /transphone/bin/train_g2p.py: -------------------------------------------------------------------------------- 1 | from transphone.model.loader import read_loader 2 | from transphone.model.dataset import read_multilingual_dataset 3 | from transphone.model.transformer import TransformerG2P 4 | from transphone.model.utils import read_model_config 5 | import torch.nn as nn 6 | from torch.utils.tensorboard import SummaryWriter 7 | import editdistance 8 | import tqdm 9 | from transphone.model.checkpoint_utils import * 10 | import argparse 11 | 12 | 13 | def train(exp, checkpoint): 14 | 15 | config = read_model_config(exp) 16 | 17 | model_dir = TransphoneConfig.data_path / "model" / exp 18 | model_dir.mkdir(parents=True, exist_ok=True) 19 | 20 | train_dataset, test_dataset = read_multilingual_dataset() 21 | SRC_VOCAB_SIZE = len(train_dataset.grapheme_vocab)+1 22 | TGT_VOCAB_SIZE = len(train_dataset.phoneme_vocab)+1 23 | 24 | print("src vocab size ", SRC_VOCAB_SIZE) 25 | print("tgt vocab size ", TGT_VOCAB_SIZE) 26 | 27 | train_dataset.grapheme_vocab.write(model_dir / 'grapheme.vocab') 28 | train_dataset.phoneme_vocab.write(model_dir / 'phoneme.vocab') 29 | 30 | EMB_SIZE = config.embed_size 31 | NHEAD = config.num_head 32 | FFN_HID_DIM = config.hidden_size 33 | NUM_ENCODER_LAYERS = config.num_encoder 34 | NUM_DECODER_LAYERS = config.num_decoder 35 | torch.manual_seed(0) 36 | 37 | model = TransformerG2P(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 38 | NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).cuda() 39 | 40 | 41 | for p in model.parameters(): 42 | if p.dim() > 1: 43 | nn.init.xavier_uniform_(p) 44 | 45 | if checkpoint is not None: 46 | torch_load(model, checkpoint) 47 | 48 | opt = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9) 49 | 50 | print('training size: ', len(train_dataset)) 51 | print("testing size: ", len(test_dataset)) 52 | 53 | train_loader = read_loader(train_dataset, batch_size=64) 54 | test_loader = read_loader(test_dataset, batch_size=1) 55 | 56 | epoch = 400 57 | writer = SummaryWriter() 58 | 59 | iteration = 0 60 | 61 | best_cer = 2.0 62 | 63 | for i in range(epoch): 64 | 65 | it = iter(train_loader) 66 | 67 | loss_sum = 0 68 | it_cnt = len(it) 69 | 70 | for batch in tqdm.tqdm(it): 71 | 72 | x,y = batch 73 | #print(x) 74 | #print(y) 75 | #print(torch.max(x).item(), torch.max(y).item()) 76 | 77 | x = x.cuda() 78 | y = y.cuda() 79 | 80 | opt.zero_grad() 81 | loss = model.train_step(x, y) 82 | loss.backward() 83 | opt.step() 84 | 85 | loss_val = loss.item() 86 | iteration += 1 87 | #print(iteration, ' ', loss_val) 88 | writer.add_scalar('Loss/Train', loss_val, iteration) 89 | loss_sum += loss_val 90 | 91 | print("epoch ", i, " loss ", loss_sum/it_cnt) 92 | 93 | test_it = iter(test_loader) 94 | 95 | cer = 0 96 | csum = 0 97 | wer = 0 98 | wsum = 0 99 |
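# every 5 epochs: decode the held-out set, log CER/WER to tensorboard, and save a
# checkpoint whenever the validation CER improves (see the block below)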
100 | 101 | 102 | if i % 5 != 0: 103 | continue 104 | 105 | for x,y in test_it: 106 | x = x.cuda() 107 | y = y.cuda() 108 | predicted = model.inference(x) 109 | y = y.squeeze(0).tolist() 110 | 111 | dist = editdistance.eval(predicted, y) 112 | cer += dist 113 | csum += len(y) 114 | 115 | if dist != 0: 116 | wer += 1 117 | wsum += 1 118 | 119 | writer.add_scalar('CER/Test', cer/csum, i) 120 | writer.add_scalar('WER/Test', wer/wsum, i) 121 | 122 | print("val cer: ", cer/csum) 123 | print("val wer: ", wer/wsum) 124 | 125 | cer = cer/csum 126 | 127 | if cer <= best_cer: 128 | best_cer = cer 129 | model_path = model_dir / f"model_{best_cer:0.6f}.pt" 130 | torch_save(model, model_path) 131 | 132 | 133 | if __name__ == '__main__': 134 | parser = argparse.ArgumentParser(description='train g2p') 135 | parser.add_argument('--exp', type=str, help='exp name') 136 | parser.add_argument('--checkpoint', type=str, help='checkpoint') 137 | 138 | args = parser.parse_args() 139 | 140 | exp = args.exp 141 | checkpoint = args.checkpoint 142 | 143 | train(exp, checkpoint) 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /transphone/bin/update_model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import tarfile 3 | from urllib.request import urlopen 4 | import io 5 | import argparse 6 | import os 7 | import shutil 8 | from transphone.config import TransphoneConfig 9 | from transphone.model.utils import resolve_model_name 10 | 11 | def update_model(model_name=None, alt_model_path=None): 12 | 13 | if model_name is None: 14 | model_name = 'latest' 15 | if alt_model_path: 16 | model_dir = alt_model_path 17 | else: 18 | model_dir = TransphoneConfig.data_path / 'model' 19 | model_dir.mkdir(parents=True, exist_ok=True) 20 | 21 | model_name = resolve_model_name(model_name) 22 | 23 | if (model_dir / model_name).exists(): 24 | print("deleting previous version: ", model_dir / model_name) 25 | shutil.rmtree(str(model_dir / model_name)) 26 | 27 | try: 28 | url = 'https://github.com/xinjli/transphone/releases/download/v1.0/' + model_name + '.tar.gz' 29 | print("re-downloading model ", model_name) 30 | print("from: ", url) 31 | print("to: ", str(model_dir)) 32 | print("please wait...") 33 | resp = urlopen(url) 34 | compressed_files = io.BytesIO(resp.read()) 35 | files = tarfile.open(fileobj=compressed_files) 36 | files.extractall(str(model_dir)) 37 | 38 | except Exception as e: 39 | print("Error: could not download the model", e) 40 | shutil.rmtree(str(model_dir / model_name), ignore_errors=True) 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | parser = argparse.ArgumentParser('a utility to update transphone model') 46 | parser.add_argument('-m', '--model', default='latest', help='specify which model to download. 
A list of downloadable models is available on Github') 47 | 48 | args = parser.parse_args() 49 | 50 | update_model(args.model) -------------------------------------------------------------------------------- /transphone/config.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import logging 3 | import torch 4 | 5 | class TransphoneConfig: 6 | 7 | root_path = Path(__file__).parent.parent 8 | data_path = Path(__file__).parent / 'data' 9 | lang_path = data_path / 'lang' 10 | logger = logging.getLogger('transphone') 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | -------------------------------------------------------------------------------- /transphone/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/data/__init__.py -------------------------------------------------------------------------------- /transphone/data/exp/042801_base.yml: -------------------------------------------------------------------------------- 1 | model: g2p 2 | num_encoder: 4 3 | num_decoder: 4 4 | embed_size: 512 5 | hidden_size: 512 6 | num_head: 8 7 | -------------------------------------------------------------------------------- /transphone/data/exp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/data/exp/__init__.py -------------------------------------------------------------------------------- /transphone/g2p.py: -------------------------------------------------------------------------------- 1 | from phonepiece.lang import normalize_lang_id 2 | from transphone.model.checkpoint_utils import torch_load 3 | from transphone.model.transformer import TransformerG2P 4 | from transphone.model.ensemble import ensemble 5 | from transphone.bin.download_model import download_model 6 | from transphone.config import TransphoneConfig 7 | from transphone.model.checkpoint_utils import find_topk_models 8 | from transphone.model.vocab import Vocab 9 | from transphone.model.utils import read_model_config 10 | from transphone.utils import Singleton 11 | from phonepiece.lang import read_tree 12 | from phonepiece.inventory import read_inventory 13 | from pathlib import Path 14 | import torch 15 | import unidecode 16 | from itertools import groupby 17 | from transphone.model.utils import resolve_model_name 18 | 19 | 20 | def read_g2p(model_name='latest', device=None, checkpoint=None): 21 | 22 | if device is not None: 23 | if isinstance(device, str): 24 | assert device in ['cpu', 'cuda'] 25 | TransphoneConfig.device = device 26 | elif isinstance(device, int): 27 | if device == -1: 28 | TransphoneConfig.device = 'cpu' 29 | else: 30 | TransphoneConfig.device = f'cuda:{device}' 31 | 32 | else: 33 | assert isinstance(device, torch.device) 34 | TransphoneConfig.device = device.type 35 | 36 | model_name = resolve_model_name(model_name) 37 | cache_path = None 38 | 39 | if checkpoint is None: 40 | model_path = TransphoneConfig.data_path / 'model' / model_name 41 | 42 | # if not exists, we try to download the model 43 | if not model_path.exists(): 44 | download_model(model_name) 45 | 46 | if not model_path.exists(): 47 | raise ValueError(f"could not download or read {model_name} model") 48 | 49 | if (model_path / "model.pt").exists(): 50 |
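# (prefer the pinned model.pt when present; otherwise fall back to the best-scoring checkpoint returned by find_topk_models below)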
checkpoint = model_path / "model.pt" 51 | else: 52 | checkpoint = find_topk_models(model_path)[0] 53 | 54 | cache_path = model_path / 'cache' 55 | 56 | if not (model_path / 'cache').exists(): 57 | cache_path.mkdir(parents=True, exist_ok=True) 58 | 59 | config = read_model_config(model_name) 60 | 61 | model = G2P(checkpoint, cache_path, config) 62 | 63 | return model 64 | 65 | 66 | class G2P(metaclass=Singleton): 67 | 68 | def __init__(self, checkpoint, cache_path, config): 69 | 70 | self.model_path = Path(checkpoint).parent 71 | self.grapheme_vocab = Vocab.read(self.model_path / 'grapheme.vocab') 72 | self.phoneme_vocab = Vocab.read(self.model_path / 'phoneme.vocab') 73 | self.config = config 74 | self.checkpoint = checkpoint 75 | self.cache_path = cache_path 76 | 77 | # setup available languages 78 | self.supervised_langs = [] 79 | for word in self.grapheme_vocab.words[2:]: 80 | if len(word) == 5 and word[0] == '<' and word[-1] == '>': 81 | self.supervised_langs.append(word[1:-1]) 82 | 83 | # cache to find proper supervised language 84 | self.lang_map = {} 85 | 86 | # inventory 87 | self.lang2inv = {} 88 | 89 | # lang2cache 90 | self.lang2cache = {} 91 | 92 | # tree to estimate language's similarity 93 | self.tree = read_tree() 94 | self.tree.setup_target_langs(self.supervised_langs) 95 | self.supervised_langs = set(self.supervised_langs) 96 | 97 | SRC_VOCAB_SIZE = len(self.grapheme_vocab)+1 98 | TGT_VOCAB_SIZE = len(self.phoneme_vocab)+1 99 | 100 | EMB_SIZE = config.embed_size 101 | NHEAD = config.num_head 102 | FFN_HID_DIM = config.hidden_size 103 | NUM_ENCODER_LAYERS = config.num_encoder 104 | NUM_DECODER_LAYERS = config.num_decoder 105 | 106 | torch.manual_seed(0) 107 | 108 | self.model = TransformerG2P(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, 109 | NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).to(TransphoneConfig.device) 110 | 111 | torch_load(self.model, self.checkpoint) 112 | 113 | 114 | def get_target_langs(self, lang_id, num_lang=10, verbose=False, force_approximate=False): 115 | 116 | if lang_id in self.lang_map: 117 | target_langs = self.lang_map[lang_id] 118 | else: 119 | 120 | if force_approximate or lang_id not in self.supervised_langs: 121 | target_langs = self.tree.get_nearest_langs(lang_id, num_lang) 122 | if verbose: 123 | print("lang ", lang_id, " is not available directly, use ", target_langs, " instead") 124 | self.lang_map[lang_id] = target_langs 125 | else: 126 | self.lang_map[lang_id] = [lang_id] 127 | target_langs = [lang_id] 128 | 129 | return target_langs 130 | 131 | def inference_word(self, word, lang_id='eng', num_lang=10, verbose=False, force_approximate=False): 132 | 133 | target_langs = self.get_target_langs(lang_id, num_lang, verbose, force_approximate) 134 | 135 | phones_lst = [] 136 | 137 | for target_lang_id in target_langs: 138 | lang_tag = '<' + target_lang_id + '>' 139 | 140 | graphemes = [lang_tag]+[w.lower() for w in list(word)] 141 | 142 | grapheme_ids = [] 143 | for grapheme in graphemes: 144 | if grapheme not in self.grapheme_vocab: 145 | 146 | # romanize chars not available in training languages 147 | romans = list(unidecode.unidecode(grapheme)) 148 | 149 | if verbose: 150 | print("WARNING: not found grapheme ", grapheme, " in vocab. 
use ", romans, " instead") 151 | 152 | for roman in romans: 153 | 154 | # discard special chars such as $ 155 | if roman in self.grapheme_vocab: 156 | grapheme_ids.append(self.grapheme_vocab.atoi(roman)) 157 | continue 158 | grapheme_ids.append(self.grapheme_vocab.atoi(grapheme)) 159 | 160 | x = torch.LongTensor(grapheme_ids).unsqueeze(0).to(TransphoneConfig.device) 161 | 162 | phone_ids = self.model.inference(x) 163 | 164 | phones = [self.phoneme_vocab.itoa(phone) for phone in phone_ids] 165 | 166 | # ignore empty 167 | if len(phones) == 0: 168 | continue 169 | 170 | # if it is a mapped language, we need to map the inference_phone to the correct language inventory 171 | if lang_id not in self.lang2inv: 172 | inv = read_inventory(lang_id) 173 | self.lang2inv[lang_id] = inv 174 | 175 | inv = self.lang2inv[lang_id] 176 | phones = inv.remap(phones) 177 | 178 | if verbose: 179 | print(target_lang_id, ' ', phones) 180 | 181 | phones_lst.append(phones) 182 | 183 | 184 | if len(phones_lst) == 0: 185 | phones = [] 186 | else: 187 | phones = ensemble(phones_lst) 188 | 189 | return phones 190 | 191 | def inference_word_batch(self, word, lang_id='eng', num_lang=10, verbose=False, force_approximate=False): 192 | 193 | target_langs = self.get_target_langs(lang_id, num_lang, verbose, force_approximate) 194 | 195 | phones_lst = [] 196 | 197 | grapheme_input = [] 198 | 199 | for target_lang_id in target_langs: 200 | lang_tag = '<' + target_lang_id + '>' 201 | 202 | graphemes = [lang_tag]+[w.lower() for w in list(word)] 203 | 204 | grapheme_ids = [] 205 | normalized_graphemes = [] 206 | 207 | for grapheme in graphemes: 208 | if grapheme not in self.grapheme_vocab: 209 | 210 | # romanize chars not available in training languages 211 | romans = list(unidecode.unidecode(grapheme)) 212 | 213 | if verbose: 214 | print("WARNING: not found grapheme ", grapheme, " in vocab. 
use ", romans, " instead") 215 | 216 | for roman in romans: 217 | 218 | # discard special chars such as $ 219 | if roman in self.grapheme_vocab: 220 | grapheme_ids.append(self.grapheme_vocab.atoi(roman)) 221 | normalized_graphemes.append(roman) 222 | continue 223 | 224 | normalized_graphemes.append(grapheme) 225 | grapheme_ids.append(self.grapheme_vocab.atoi(grapheme)) 226 | 227 | grapheme_input.append(grapheme_ids) 228 | 229 | if verbose: 230 | print(f"normalized: {word} -> {normalized_graphemes}") 231 | 232 | x = torch.LongTensor(grapheme_input).to(TransphoneConfig.device) 233 | 234 | phone_output = self.model.inference_batch(x) 235 | 236 | for target_lang_id, phone_ids in zip(target_langs, phone_output): 237 | phones = [self.phoneme_vocab.itoa(phone) for phone in phone_ids] 238 | 239 | # ignore empty 240 | if len(phones) == 0: 241 | continue 242 | 243 | # if it is a mapped language, we need to map the inference_phone to the correct language inventory 244 | if lang_id not in self.lang2inv: 245 | inv = read_inventory(lang_id) 246 | self.lang2inv[lang_id] = inv 247 | 248 | inv = self.lang2inv[lang_id] 249 | phones = inv.remap(phones) 250 | 251 | if verbose: 252 | print(target_lang_id, ' ', phones) 253 | 254 | phones_lst.append(phones) 255 | 256 | if len(phones_lst) == 0: 257 | phones = [] 258 | else: 259 | phones = ensemble(phones_lst) 260 | 261 | return phones 262 | 263 | def inference(self, text, lang_id='eng', num_lang=10, verbose=False, force_approximate=False): 264 | lang_id = normalize_lang_id(lang_id) 265 | 266 | phones_lst = [] 267 | 268 | words = text.split() 269 | 270 | for word in words: 271 | phones = self.inference_word(word, lang_id, num_lang, verbose, force_approximate) 272 | phones = [x[0] for x in groupby(phones)] 273 | phones_lst.extend(phones) 274 | 275 | return phones_lst 276 | 277 | def inference_batch(self, text, lang_id='eng', num_lang=10, verbose=False, force_approximate=False): 278 | lang_id = normalize_lang_id(lang_id) 279 | 280 | phones_lst = [] 281 | 282 | words = text.split() 283 | 284 | for word in words: 285 | phones = self.inference_word_batch(word, lang_id, num_lang, verbose, force_approximate) 286 | phones = [x[0] for x in groupby(phones)] 287 | phones_lst.extend(phones) 288 | 289 | return phones_lst -------------------------------------------------------------------------------- /transphone/lang/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/lang/__init__.py -------------------------------------------------------------------------------- /transphone/lang/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from phonepiece.inventory import read_inventory 2 | from transphone.g2p import read_g2p 3 | from transphone.config import TransphoneConfig 4 | 5 | class BaseTokenizer: 6 | 7 | def __init__(self, lang_id, g2p_model='latest', device=None): 8 | self.lang_id = lang_id 9 | self.inventory = read_inventory(lang_id) 10 | 11 | if g2p_model is None: 12 | self.g2p = None 13 | else: 14 | self.g2p = read_g2p(g2p_model, device) 15 | 16 | # cache for g2p 17 | self.cache = {} 18 | 19 | # this will temporarily store new caches, which will be flushed to disk 20 | self.cache_log = {} 21 | 22 | if self.g2p is not None and self.g2p.cache_path is not None: 23 | lang_cache_path = self.g2p.cache_path / f"{lang_id}.txt" 24 | if lang_cache_path.exists(): 25 | for line in
open(lang_cache_path, 'r'): 26 | fields = line.strip().split() 27 | self.cache[fields[0]] = fields[1:] 28 | 29 | self.punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~0123456789' 30 | self.logger = TransphoneConfig.logger 31 | 32 | def add_cache(self, word, phonemes): 33 | 34 | self.cache[word] = phonemes 35 | 36 | if self.g2p is None or self.g2p.cache_path is None: 37 | return 38 | 39 | # handle new cache 40 | self.cache_log[word] = phonemes 41 | 42 | # flush them to disk if the cache is large enough 43 | if len(self.cache_log) >= 100 and self.g2p.cache_path is not None and self.g2p.cache_path.exists(): 44 | w = open(self.g2p.cache_path / f"{self.lang_id}.txt", 'a') 45 | for word, phonemes in self.cache_log.items(): 46 | w.write(f"{word}\t{' '.join(phonemes)}\n") 47 | w.close() 48 | self.cache_log = {} 49 | 50 | def tokenize(self, text: str): 51 | raise NotImplementedError 52 | 53 | def tokenize_words(self, text:str): 54 | text = text.translate(str.maketrans('', '', self.punctuation)).lower() 55 | 56 | words = text.split() 57 | cleaned_words = [word for word in words if len(word) > 0] 58 | 59 | return cleaned_words 60 | 61 | def convert_tokens_to_ids(self, lst): 62 | lst = list(filter(lambda s: s!='', lst)) 63 | 64 | return self.inventory.phoneme.atoi(lst) 65 | 66 | def convert_ids_to_tokens(self, lst): 67 | return self.inventory.phoneme.itoa(lst) 68 | -------------------------------------------------------------------------------- /transphone/lang/cmn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/lang/cmn/__init__.py -------------------------------------------------------------------------------- /transphone/lang/cmn/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transphone.utils import import_with_auto_install 2 | from transphone.lang.base_tokenizer import BaseTokenizer 3 | from phonepiece.pinyin import PinyinConverter 4 | from transphone.lang.cmn.normalizer import CMNNormalizer 5 | 6 | 7 | def read_cmn_tokenizer(lang_id='cmn', g2p_model='latest', device=None, use_lexicon=True): 8 | return CMNTokenizer(lang_id, g2p_model, device) 9 | 10 | 11 | class CMNTokenizer(BaseTokenizer): 12 | 13 | def __init__(self, lang_id='cmn', g2p_model='latest', device=None): 14 | 15 | super().__init__(lang_id, g2p_model, device) 16 | 17 | # import jieba and pypinyin for segmentation 18 | self.jieba = import_with_auto_install('jieba', 'jieba') 19 | pypinyin = import_with_auto_install('pypinyin', 'pypinyin') 20 | 21 | self.pinyin = pypinyin.lazy_pinyin 22 | self.converter = PinyinConverter() 23 | self.normalizer = CMNNormalizer() 24 | 25 | def tokenize_words(self, text): 26 | 27 | text = self.normalizer(text) 28 | words = list(self.jieba.cut(text, use_paddle=False)) 29 | return words 30 | 31 | 32 | def tokenize(self, text, use_g2p=True, use_space=False, verbose=False): 33 | 34 | norm_text = self.normalizer(text) 35 | 36 | log = f"normalization: {text} -> {norm_text}" 37 | self.logger.info(log) 38 | if verbose: 39 | print(log) 40 | 41 | text = norm_text 42 | 43 | words = list(self.jieba.cut(text, use_paddle=False)) 44 | 45 | ipa_lst = [] 46 | 47 | for word in words: 48 | pinyins = self.pinyin(word) 49 | 50 | self.logger.info(f"pinyin: {word} -> {pinyins}") 51 | if verbose: 52 | print(f"pinyin: {word} -> {pinyins}") 53 | 54 | for pinyin in pinyins: 55 | ipa_lst.extend(self.converter.convert(pinyin)) 56 | if 
use_space: 57 | ipa_lst.append('') 58 | return ipa_lst -------------------------------------------------------------------------------- /transphone/lang/eng/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/lang/eng/__init__.py -------------------------------------------------------------------------------- /transphone/lang/eng/normalizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is from: https://github.com/tomaarsen/TTSTextNormalization/blob/master/converters/Digit.py 3 | 4 | """ 5 | 6 | import re 7 | 8 | 9 | # Function to convert numbers to words, adapted from GPT generation 10 | def number_to_words(num): 11 | 12 | # Create a dictionary for mapping digits to words 13 | digits = { 14 | 0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four', 15 | 5: 'five', 6: 'six', 7: 'seven', 8: 'eight', 9: 'nine' 16 | } 17 | 18 | # Create a dictionary for mapping two-digit numbers to words 19 | two_digits = { 20 | 10: 'ten', 11: 'eleven', 12: 'twelve', 13: 'thirteen', 14: 'fourteen', 21 | 15: 'fifteen', 16: 'sixteen', 17: 'seventeen', 18: 'eighteen', 19: 'nineteen', 22 | 20: 'twenty', 30: 'thirty', 40: 'forty', 50: 'fifty', 60: 'sixty', 23 | 70: 'seventy', 80: 'eighty', 90: 'ninety' 24 | } 25 | 26 | # Create a dictionary for mapping powers of 10 to words 27 | powers_of_10 = { 28 | 100: 'hundred', 1000: 'thousand', 1000000: 'million', 29 | 1000000000: 'billion', 1000000000000: 'trillion' 30 | } 31 | 32 | # Handle negative numbers 33 | if num < 0: 34 | return "minus " + number_to_words(abs(num)) 35 | 36 | # Handle numbers from 0 to 9 37 | if num < 10: 38 | return digits[num] 39 | 40 | # Handle numbers from 10 to 99 41 | if num < 100: 42 | if num in two_digits: 43 | return two_digits[num] 44 | else: 45 | return two_digits[num // 10 * 10] + " " + digits[num % 10] 46 | 47 | # Handle numbers from 100 to 999 48 | if num < 1000: 49 | hundreds = num // 100 50 | remainder = num % 100 51 | if remainder == 0: 52 | return digits[hundreds] + " " + powers_of_10[100] 53 | else: 54 | return digits[hundreds] + " " + powers_of_10[100] + " and " + number_to_words(remainder) 55 | 56 | # Handle numbers greater than or equal to 1000 57 | for power in sorted(powers_of_10.keys(), reverse=True): 58 | if num >= power: 59 | quotient = num // power 60 | remainder = num % power 61 | if remainder == 0: 62 | return number_to_words(quotient) + " " + powers_of_10[power] 63 | else: 64 | return number_to_words(quotient) + " " + powers_of_10[power] + " " + number_to_words(remainder) 65 | 66 | return "infinity" 67 | 68 | class ENGNormalizer: 69 | 70 | def __init__(self): 71 | super().__init__() 72 | # Regex used to filter out non digits 73 | self.filter_regex = re.compile("[^0-9]") 74 | # Translation dict to convert digits to text 75 | self.trans_dict = { 76 | "0": "zero", 77 | "1": "one", 78 | "2": "two", 79 | "3": "three", 80 | "4": "four", 81 | "5": "five", 82 | "6": "six", 83 | "7": "seven", 84 | "8": "eight", 85 | "9": "nine" 86 | } 87 | 88 | self.punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~' 89 | 90 | 91 | def __call__(self, text: str) -> str: 92 | 93 | text = text.translate(str.maketrans('', '', self.punctuation)).lower() 94 | 95 | tokens = text.split() 96 | 97 | res = [] 98 | for token in tokens: 99 | if str.isdigit(token): 100 | # 1 Filter out anything that isn't a digit 101 | token = self.filter_regex.sub("", token) 
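# e.g. (a sketch of the intended behavior, given the rules above):
#   ENGNormalizer()("I have 23 cats!") -> "i have twenty three cats"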
102 | # 2 Check for special case 103 | if token == "007": 104 | token = "double o seven" 105 | else: 106 | token = number_to_words(int(token)) 107 | 108 | res.append(token) 109 | 110 | return " ".join(res) 111 | -------------------------------------------------------------------------------- /transphone/lang/eng/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transphone.utils import import_with_auto_install 2 | from transphone.lang.base_tokenizer import BaseTokenizer 3 | from phonepiece.arpa import ArpaConverter 4 | from transphone.lang.eng.normalizer import ENGNormalizer 5 | 6 | def read_eng_tokenizer(lang_id='eng', g2p_model='latest', device=None, use_lexicon=True): 7 | return ENGTokenizer(lang_id, g2p_model, device) 8 | 9 | 10 | class ENGTokenizer(BaseTokenizer): 11 | 12 | def __init__(self, lang_id='eng', g2p_model='latest', device=None): 13 | 14 | super().__init__(lang_id, g2p_model, device) 15 | 16 | # import cmudict for the English lexicon 17 | cmudict_module = import_with_auto_install('cmudict', 'cmudict') 18 | self.cmudict = cmudict_module.dict() 19 | self.converter = ArpaConverter() 20 | self.normalizer = ENGNormalizer() 21 | 22 | def tokenize_words(self, text:str): 23 | text = self.normalizer(text) 24 | words = text.split() 25 | return words 26 | 27 | def tokenize(self, text, use_g2p=True, use_space=False, verbose=False): 28 | 29 | norm_text = self.normalizer(text) 30 | 31 | log = f"normalization: {text} -> {norm_text}" 32 | self.logger.info(log) 33 | if verbose: 34 | print(log) 35 | 36 | ipa_lst = [] 37 | text = norm_text 38 | 39 | for word in text.split(): 40 | if len(word) >= 1: 41 | if word in self.cache: 42 | ipas = self.cache[word] 43 | ipa_lst.extend(ipas) 44 | 45 | elif word in self.cmudict: 46 | arpa = self.cmudict[word][0] 47 | ipas = self.converter.convert(arpa) 48 | ipa_lst.extend(ipas) 49 | 50 | log = f"CMUdict: {word} -> {arpa} -> {ipas}" 51 | self.logger.info(log) 52 | if verbose: 53 | print(log) 54 | 55 | self.cache[word] = ipas 56 | 57 | elif use_g2p: 58 | phonemes = self.g2p.inference_batch(word, self.lang_id, verbose=verbose) 59 | remapped_phonemes = self.inventory.remap(phonemes) 60 | 61 | log = f"g2p {word} -> {remapped_phonemes}" 62 | self.logger.info(log) 63 | if verbose: 64 | print(log) 65 | 66 | self.add_cache(word, remapped_phonemes) 67 | ipa_lst.extend(remapped_phonemes) 68 | if use_space: 69 | ipa_lst.append('') 70 | 71 | return ipa_lst -------------------------------------------------------------------------------- /transphone/lang/epitran_tokenizer.py: -------------------------------------------------------------------------------- 1 | """Basic Epitran class for G2P in most languages.""" 2 | import logging 3 | import os.path 4 | import sys 5 | import csv 6 | import unicodedata 7 | from collections import defaultdict 8 | from typing import DefaultDict, Callable # pylint: disable=unused-import 9 | from phonepiece.epitran import read_epitran_g2p 10 | from phonepiece.lexicon import read_lexicon 11 | from phonepiece.inventory import read_inventory 12 | import re 13 | from transphone.lang.base_tokenizer import BaseTokenizer 14 | from epitran.ppprocessor import PrePostProcessor 15 | from phonepiece.ipa import read_ipa 16 | import epitran 17 | 18 | logger = logging.getLogger('epitran') 19 | 20 | customized_epitran_dict = { 21 | 'spa': 'spa-Latn', 22 | 'deu': 'deu-Latn', 23 | 'ita': 'ita-Latn', 24 | 'rus': 'rus-Cyrl', 25 | 'fra': 'fra-Latn', 26 | } 27 | 28 | 
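# raw_epitran_dict maps an ISO 639-3 language id to an Epitran model id, i.e. a language
# code plus an ISO 15924 script tag (e.g. 'tur' -> 'tur-Latn'); read_epitran_tokenizer
# below falls back to this table when a language has no customized tokenizer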
raw_epitran_dict = { 29 | 'tur': 'tur-Latn', 30 | 'vie': 'vie-Latn', 31 | 'aar': 'aar-Latn', 32 | 'got': 'got-Latn', 33 | 'lsm': 'lsm-Latn', 34 | 'swa': 'swa-Latn', 35 | 'aii': 'aii-Syrc', 36 | 'hak': 'hak-Latn', 37 | 'ltc': 'ltc-Latn-bax', 38 | 'swe': 'swe-Latn', 39 | 'amh': 'amh-Ethi', 40 | 'hat': 'hat-Latn-bab', 41 | 'tam': 'tam-Taml', 42 | 'hau': 'hau-Latn', 43 | 'mal': 'mal-Mlym', 44 | 'hin': 'hin-Deva', 45 | 'mar': 'mar-Deva', 46 | 'tel': 'tel-Telu', 47 | 'ara': 'ara-Arab', 48 | 'hmn': 'hmn-Latn', 49 | 'mlt': 'mlt-Latn', 50 | 'tgk': 'tgk-Cyrl', 51 | 'ava': 'ava-Cyrl', 52 | 'hrv': 'hrv-Latn', 53 | 'mon': 'mon-Cyrl-bab', 54 | 'tgl': 'tgl-Latn', 55 | 'aze': 'aze-Latn', 56 | 'hsn': 'hsn-Latn', 57 | 'mri': 'mri-Latn', 58 | 'hun': 'hun-Latn', 59 | 'msa': 'msa-Latn', 60 | 'tha': 'tha-Thai', 61 | 'ilo': 'ilo-Latn', 62 | 'mya': 'mya-Mymr', 63 | 'tir': 'tir-Ethi', 64 | 'ind': 'ind-Latn', 65 | 'nan': 'nan-Latn', 66 | 'ita': 'ita-Latn', 67 | 'bxk': 'bxk-Latn', 68 | 'jam': 'jam-Latn', 69 | 'nld': 'nld-Latn', 70 | 'tpi': 'tpi-Latn', 71 | 'cat': 'cat-Latn', 72 | 'jav': 'jav-Latn', 73 | 'nya': 'nya-Latn', 74 | 'tuk': 'tuk-Latn', 75 | 'ceb': 'ceb-Latn', 76 | 'jpn': 'jpn-Ktkn-red', 77 | 'ood': 'ood-Latn-sax', 78 | 'ces': 'ces-Latn', 79 | 'cjy': 'cjy-Latn', 80 | 'ori': 'ori-Orya', 81 | 'ckb': 'ckb-Arab', 82 | 'orm': 'orm-Latn', 83 | 'cmn': 'cmn-Latn', 84 | 'kat': 'kat-Geor', 85 | 'pan': 'pan-Guru', 86 | 'uew': 'uew', 87 | 'kaz': 'kaz-Latn', 88 | 'pii': 'pii-latn_Wiktionary', 89 | 'uig': 'uig-Arab', 90 | 'csb': 'csb-Latn', 91 | 'ukr': 'ukr-Cyrl', 92 | 'deu': 'deu-Latn', 93 | 'pol': 'pol-Latn', 94 | 'urd': 'urd-Arab', 95 | 'kbd': 'kbd-Cyrl', 96 | 'por': 'por-Latn', 97 | 'uzb': 'uzb-Latn', 98 | 'khm': 'khm-Khmr', 99 | 'ron': 'ron-Latn', 100 | 'fas': 'fas-Arab', 101 | 'kin': 'kin-Latn', 102 | 'run': 'run-Latn', 103 | 'kir': 'kir-Latn', 104 | 'rus': 'rus-Cyrl', 105 | 'wuu': 'wuu-Latn', 106 | 'sag': 'sag-Latn', 107 | 'xho': 'xho-Latn', 108 | 'sin': 'sin-Sinh', 109 | 'yor': 'yor-Latn', 110 | 'kmr': 'kmr-Latn', 111 | 'sna': 'sna-Latn', 112 | 'yue': 'yue-Latn', 113 | 'som': 'som-Latn', 114 | 'zha': 'zha-Latn', 115 | 'ful': 'ful-Latn', 116 | 'lao': 'lao-Laoo-prereform', 117 | 'spa': 'spa-Latn-eu', 118 | 'zul': 'zul-Latn', 119 | 'gan': 'gan-Latn', 120 | 'sqi': 'sqi-Latn', 121 | 'lij': 'lij-Latn' 122 | } 123 | 124 | 125 | def read_epitran_tokenizer(lang_id, g2p_model=None, device=None, use_lexicon=True): 126 | if lang_id in customized_epitran_dict: 127 | return read_customized_epitran_tokenizer(customized_epitran_dict[lang_id], use_lexicon) 128 | elif lang_id in raw_epitran_dict: 129 | return read_raw_epitran_tokenizer(raw_epitran_dict[lang_id], use_lexicon) 130 | else: 131 | raise ValueError('Unknown epitran id: {}'.format(lang_id)) 132 | 133 | 134 | def read_raw_epitran_tokenizer(lang_id_or_epi_id, use_lexicon=True): 135 | 136 | if '-' in lang_id_or_epi_id: 137 | lang_id, writing_system = lang_id_or_epi_id.split('-', 1) 138 | else: 139 | lang_id = lang_id_or_epi_id 140 | writing_system = 'Latn' 141 | 142 | if use_lexicon: 143 | lexicon = read_lexicon(lang_id) 144 | else: 145 | lexicon = {} 146 | 147 | return RawEpitranTokenizer(lang_id, writing_system, lexicon) 148 | 149 | 150 | def read_customized_epitran_tokenizer(lang_id_or_epi_id, use_lexicon=True): 151 | 152 | if '-' in lang_id_or_epi_id: 153 | lang_id, writing_system = lang_id_or_epi_id.split('-', 1) 154 | else: 155 | lang_id = lang_id_or_epi_id 156 | writing_system = 'Latn' 157 | 158 | if use_lexicon: 159 | lexicon = read_lexicon(lang_id) 160 |
else: 161 | lexicon = {} 162 | 163 | return CustomizedEpitranTokenizer(lang_id, writing_system, lexicon) 164 | 165 | 166 | class CustomizedEpitranTokenizer(BaseTokenizer): 167 | 168 | def __init__(self, lang_id, writing_system=None, lexicon=None): 169 | super().__init__(lang_id, None) 170 | 171 | #self.lexicon = read_lexicon(lang_id) 172 | 173 | """Constructor""" 174 | self.lang_id = lang_id 175 | self.writing_system = writing_system 176 | 177 | if writing_system: 178 | lang_id = lang_id + '-' + writing_system 179 | 180 | if not lexicon: 181 | lexicon = {} 182 | 183 | self.lexicon = lexicon 184 | 185 | self.g2p = read_epitran_g2p(lang_id) 186 | self.inv = read_inventory(self.lang_id) 187 | self.ipa = read_ipa() 188 | 189 | self.regexp = self._construct_regex(self.g2p.keys()) 190 | self.nils = defaultdict(int) 191 | 192 | self.cache = defaultdict(list) 193 | 194 | self.preprocessor = PrePostProcessor(lang_id, 'pre', False) 195 | #self.postprocessor = PrePostProcessor(lang_id, 'post', False) 196 | 197 | def _construct_regex(self, g2p_keys): 198 | """Build a regular expression that will greedily match segments from 199 | the mapping table. 200 | """ 201 | graphemes = sorted(g2p_keys, key=len, reverse=True) 202 | return re.compile(f"({r'|'.join(graphemes)})", re.I) 203 | 204 | def match_word(self, text, verbose=False): 205 | ipa_lst = [] 206 | while text: 207 | logger.debug('text= %s', repr(list(text))) 208 | if verbose: 209 | print('text=', repr(list(text))) 210 | 211 | m = self.regexp.match(text) 212 | if m: 213 | source = m.group(0) 214 | try: 215 | targets = self.g2p[source] 216 | if verbose: 217 | print(source, ' -> ', targets) 218 | except KeyError: 219 | logger.debug("source = '%s'", source) 220 | logger.debug("self.g2p[source] = %s", self.g2p[source]) 221 | targets = [] 222 | except IndexError: 223 | logger.debug("self.g2p[source] = %s", self.g2p[source]) 224 | targets = [] 225 | 226 | ipa_lst.extend(targets) 227 | text = text[len(source):] 228 | else: 229 | self.nils[text[0]] += 1 230 | text = text[1:] 231 | ipa_lst = self.inv.remap(ipa_lst) 232 | return ipa_lst 233 | 234 | 235 | def tokenize(self, text: str, verbose: bool=False): 236 | text = text.lower() 237 | 238 | ipa_lst = [] 239 | 240 | for word in text.split(): 241 | 242 | if word in self.cache: 243 | ipa_lst.extend(self.cache[word]) 244 | 245 | elif word in self.lexicon: 246 | phonemes = self.lexicon[word] 247 | ipa_lst.extend(phonemes) 248 | self.cache[word] = phonemes 249 | log = f"lexicon {word} -> {phonemes}" 250 | self.logger.info(log) 251 | if verbose: 252 | print(log) 253 | else: 254 | norm_word = unicodedata.normalize('NFC', word) 255 | 256 | norm_word = self.preprocessor.process(norm_word) 257 | 258 | word_ipa_lst = self.match_word(norm_word, verbose) 259 | 260 | log = f"rule raw: {word} -> norm: {norm_word} -> {word_ipa_lst}" 261 | self.logger.info(log) 262 | if verbose: 263 | print(log) 264 | 265 | #word_ipa_lst = self.ipa.tokenize(self.postprocessor.process(''.join(word_ipa_lst))) 266 | 267 | self.cache[word] = word_ipa_lst 268 | ipa_lst.extend(self.cache[word]) 269 | 270 | return ipa_lst 271 | 272 | 273 | class RawEpitranTokenizer(BaseTokenizer): 274 | 275 | def __init__(self, lang_id, writing_system='Latin', lexicon=None): 276 | super().__init__(lang_id, None) 277 | 278 | #self.lexicon = read_lexicon(lang_id) 279 | 280 | """Constructor""" 281 | self.lang_id = lang_id 282 | self.writing_system = writing_system 283 | 284 | if writing_system: 285 | lang_id = lang_id + '-' + writing_system 286 | 287 | if not lexicon:
288 | lexicon = {} 289 | 290 | self.lexicon = lexicon 291 | 292 | self.g2p = epitran.Epitran(lang_id) 293 | self.inv = read_inventory(self.lang_id) 294 | self.ipa = read_ipa() 295 | 296 | self.cache = defaultdict(list) 297 | 298 | 299 | def tokenize(self, text: str, verbose: bool=False): 300 | text = text.lower() 301 | 302 | ipa_lst = [] 303 | 304 | for word in text.split(): 305 | 306 | if word in self.cache: 307 | ipa_lst.extend(self.cache[word]) 308 | 309 | elif word in self.lexicon: 310 | phonemes = self.lexicon[word] 311 | ipa_lst.extend(phonemes) 312 | self.cache[word] = phonemes 313 | log = f"lexicon {word} -> {phonemes}" 314 | self.logger.info(log) 315 | if verbose: 316 | print(log) 317 | else: 318 | 319 | word_ipa_lst = self.g2p.trans_list(word) 320 | word_ipa_lst = self.inv.remap(word_ipa_lst, broad=True) 321 | 322 | log = f"rule raw: {word} -> remap {word_ipa_lst}" 323 | self.logger.info(log) 324 | if verbose: 325 | print(log) 326 | 327 | #word_ipa_lst = self.ipa.tokenize(self.postprocessor.process(''.join(word_ipa_lst))) 328 | 329 | self.cache[word] = word_ipa_lst 330 | ipa_lst.extend(self.cache[word]) 331 | 332 | return ipa_lst -------------------------------------------------------------------------------- /transphone/lang/g2p_tokenizer.py: -------------------------------------------------------------------------------- 1 | from phonepiece.lang import normalize_lang_id 2 | from phonepiece.lexicon import read_lexicon 3 | from transphone.lang.base_tokenizer import BaseTokenizer 4 | 5 | 6 | def read_g2p_tokenizer(lang_id, g2p_model='latest', device=None): 7 | lang_id = normalize_lang_id(lang_id) 8 | return G2PTokenizer(lang_id, g2p_model, device) 9 | 10 | class G2PTokenizer(BaseTokenizer): 11 | 12 | def __init__(self, lang_id, g2p_model='latest', device=None): 13 | super().__init__(lang_id, g2p_model, device) 14 | 15 | self.lexicon = read_lexicon(lang_id) 16 | 17 | 18 | 19 | def tokenize(self, text, use_g2p=True, use_space=False, verbose=False): 20 | 21 | norm_text = text.translate(str.maketrans('', '', self.punctuation)).lower() 22 | log = f"normalization: {text} -> {norm_text}" 23 | self.logger.info(log) 24 | 25 | if verbose: 26 | print(log) 27 | 28 | text = norm_text 29 | 30 | result = [] 31 | 32 | for word in text.split(): 33 | if word in self.cache: 34 | result.extend(self.cache[word]) 35 | elif word in self.lexicon: 36 | phonemes = self.lexicon[word] 37 | result.extend(phonemes) 38 | self.cache[word] = phonemes 39 | log = f"lexicon {word} -> {phonemes}" 40 | self.logger.info(log) 41 | if verbose: 42 | print(log) 43 | else: 44 | phonemes = self.g2p.inference_batch(word, self.lang_id, verbose=verbose) 45 | remapped_phonemes = self.inventory.remap(phonemes) 46 | 47 | log = f"g2p batch mode: {word} -> {remapped_phonemes}" 48 | self.logger.info(log) 49 | if verbose: 50 | print(log) 51 | self.add_cache(word, remapped_phonemes) 52 | result.extend(remapped_phonemes) 53 | if use_space: 54 | result.append('') 55 | 56 | return result 57 | -------------------------------------------------------------------------------- /transphone/lang/jpn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/lang/jpn/__init__.py -------------------------------------------------------------------------------- /transphone/lang/jpn/conv_table.py: -------------------------------------------------------------------------------- 1 | HIRAGANA = 
list('ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすず' 2 | 'せぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴ' 3 | 'ふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろわ' 4 | 'をんーゎゐゑゕゖゔゝゞ・「」。、') 5 | HALF_ASCII = list('!"#$%&\'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ' 6 | '[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~ ') 7 | HALF_DIGIT = list('0123456789') 8 | HALF_KANA_SEION = list('ァアィイゥウェエォオカキクケコサシスセソタチッツテトナニヌネノハヒフヘホマミムメモャヤュユョヨ' 9 | 'ラリルレロワヲンーヮヰヱヵヶヽヾ・「」。、') 10 | HALF_KANA = ['ァ', 'ア', 'ィ', 'イ', 'ゥ', 'ウ', 'ェ', 'エ', 'ォ', 'オ', 11 | 'カ', 'ガ', 'キ', 'ギ', 'ク', 'グ', 'ケ', 'ゲ', 'コ', 12 | 'ゴ', 'サ', 'ザ', 'シ', 'ジ', 'ス', 'ズ', 'セ', 'ゼ', 13 | 'ソ', 'ゾ', 'タ', 'ダ', 'チ', 'ヂ', 'ッ', 'ツ', 'ヅ', 14 | 'テ', 'デ', 'ト', 'ド', 'ナ', 'ニ', 'ヌ', 'ネ', 'ノ', 'ハ', 15 | 'バ', 'パ', 'ヒ', 'ビ', 'ピ', 'フ', 'ブ', 'プ', 'ヘ', 16 | 'ベ', 'ペ', 'ホ', 'ボ', 'ポ', 'マ', 'ミ', 'ム', 'メ', 17 | 'モ', 'ャ', 'ヤ', 'ュ', 'ユ', 'ョ', 'ヨ', 'ラ', 'リ', 'ル', 18 | 'レ', 'ロ', 'ワ', 'ヲ', 'ン', 'ー', 19 | 'ヮ', 'ヰ', 'ヱ', 'ヵ', 'ヶ', 'ヴ', 'ヽ', 'ヾ', '・', 20 | '「', '」', '。', '、'] 21 | FULL_ASCII = list('!"#$%&'()*+,-./:;<=>?@' 22 | 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 23 | '[\]^_`abcdefghijklmnopqrst' 24 | 'uvwxyz{|}~ ') 25 | FULL_DIGIT = list('0123456789') 26 | FULL_KANA = list('ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソ' 27 | 'ゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペ' 28 | 'ホボポマミムメモャヤュユョヨラリルレロワヲンーヮヰヱヵヶヴ' 29 | 'ヽヾ・「」。、') 30 | FULL_KANA_SEION = list('ァアィイゥウェエォオカキクケコサシスセソタチッツテト' 31 | 'ナニヌネノハヒフヘホマミムメモャヤュユョヨラリルレロ' 32 | 'ワヲンーヮヰヱヵヶヽヾ・「」。、') 33 | HEPBURN = list('aiueoaiueon') 34 | HEPBURN_KANA = list('ぁぃぅぇぉあいうえおん') 35 | 36 | 37 | def _to_ord_list(chars): 38 | return list(map(ord, chars)) 39 | 40 | HIRAGANA_ORD = _to_ord_list(HIRAGANA) 41 | FULL_KANA_ORD = _to_ord_list(FULL_KANA) 42 | HALF_ASCII_ORD = _to_ord_list(HALF_ASCII) 43 | FULL_ASCII_ORD = _to_ord_list(FULL_ASCII) 44 | HALF_DIGIT_ORD = _to_ord_list(HALF_DIGIT) 45 | FULL_DIGIT_ORD = _to_ord_list(FULL_DIGIT) 46 | HALF_KANA_SEION_ORD = _to_ord_list(HALF_KANA_SEION) 47 | FULL_KANA_SEION_ORD = _to_ord_list(FULL_KANA_SEION) 48 | 49 | 50 | def _to_dict(_from, _to): 51 | return dict(zip(_from, _to)) 52 | 53 | 54 | H2K_TABLE = _to_dict(HIRAGANA_ORD, FULL_KANA) 55 | H2HK_TABLE = _to_dict(HIRAGANA_ORD, HALF_KANA) 56 | K2H_TABLE = _to_dict(FULL_KANA_ORD, HIRAGANA) 57 | 58 | H2Z_A = _to_dict(HALF_ASCII_ORD, FULL_ASCII) 59 | H2Z_AD = _to_dict(HALF_ASCII_ORD+HALF_DIGIT_ORD, FULL_ASCII+FULL_DIGIT) 60 | H2Z_AK = _to_dict(HALF_ASCII_ORD+HALF_KANA_SEION_ORD, 61 | FULL_ASCII+FULL_KANA_SEION) 62 | H2Z_D = _to_dict(HALF_DIGIT_ORD, FULL_DIGIT) 63 | H2Z_K = _to_dict(HALF_KANA_SEION_ORD, FULL_KANA_SEION) 64 | H2Z_DK = _to_dict(HALF_DIGIT_ORD+HALF_KANA_SEION_ORD, 65 | FULL_DIGIT+FULL_KANA_SEION) 66 | H2Z_ALL = _to_dict(HALF_ASCII_ORD+HALF_DIGIT_ORD+HALF_KANA_SEION_ORD, 67 | FULL_ASCII+FULL_DIGIT+FULL_KANA_SEION) 68 | 69 | Z2H_A = _to_dict(FULL_ASCII_ORD, HALF_ASCII) 70 | Z2H_AD = _to_dict(FULL_ASCII_ORD+FULL_DIGIT_ORD, HALF_ASCII+HALF_DIGIT) 71 | Z2H_AK = _to_dict(FULL_ASCII_ORD+FULL_KANA_ORD, HALF_ASCII+HALF_KANA) 72 | Z2H_D = _to_dict(FULL_DIGIT_ORD, HALF_DIGIT) 73 | Z2H_K = _to_dict(FULL_KANA_ORD, HALF_KANA) 74 | Z2H_DK = _to_dict(FULL_DIGIT_ORD+FULL_KANA_ORD, HALF_DIGIT+HALF_KANA) 75 | Z2H_ALL = _to_dict(FULL_ASCII_ORD+FULL_DIGIT_ORD+FULL_KANA_ORD, 76 | HALF_ASCII+HALF_DIGIT+HALF_KANA) 77 | KANA2HEP = _to_dict(_to_ord_list(HEPBURN_KANA), HEPBURN) 78 | HEP2KANA = _to_dict(_to_ord_list(HEPBURN), HEPBURN_KANA) 79 | 80 | del _to_ord_list 81 | del _to_dict 82 | del HIRAGANA_ORD 83 | del HIRAGANA 84 | del HALF_KANA 85 | del FULL_KANA_ORD 86 | del FULL_KANA 87 | del HALF_ASCII_ORD 88 | del HALF_ASCII 89 | del FULL_ASCII_ORD 90 | del FULL_ASCII 91 | del HALF_DIGIT_ORD 
92 | del HALF_DIGIT 93 | del FULL_DIGIT_ORD 94 | del FULL_DIGIT 95 | del HALF_KANA_SEION_ORD 96 | del HALF_KANA_SEION 97 | del FULL_KANA_SEION_ORD 98 | del FULL_KANA_SEION 99 | del HEPBURN 100 | del HEPBURN_KANA 101 | HIRAGANA = list('ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすず' 102 | 'せぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴ' 103 | 'ふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろわ' 104 | 'をんーゎゐゑゕゖゔゝゞ・「」。、') 105 | HALF_ASCII = list('!"#$%&\'()*+,-./:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ' 106 | '[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~ ') 107 | HALF_DIGIT = list('0123456789') 108 | HALF_KANA_SEION = list('ァアィイゥウェエォオカキクケコサシスセソタチッツテトナニヌネノハヒフヘホマミムメモャヤュユョヨ' 109 | 'ラリルレロワヲンーヮヰヱヵヶヽヾ・「」。、') 110 | HALF_KANA = ['ァ', 'ア', 'ィ', 'イ', 'ゥ', 'ウ', 'ェ', 'エ', 'ォ', 'オ', 111 | 'カ', 'ガ', 'キ', 'ギ', 'ク', 'グ', 'ケ', 'ゲ', 'コ', 112 | 'ゴ', 'サ', 'ザ', 'シ', 'ジ', 'ス', 'ズ', 'セ', 'ゼ', 113 | 'ソ', 'ゾ', 'タ', 'ダ', 'チ', 'ヂ', 'ッ', 'ツ', 'ヅ', 114 | 'テ', 'デ', 'ト', 'ド', 'ナ', 'ニ', 'ヌ', 'ネ', 'ノ', 'ハ', 115 | 'バ', 'パ', 'ヒ', 'ビ', 'ピ', 'フ', 'ブ', 'プ', 'ヘ', 116 | 'ベ', 'ペ', 'ホ', 'ボ', 'ポ', 'マ', 'ミ', 'ム', 'メ', 117 | 'モ', 'ャ', 'ヤ', 'ュ', 'ユ', 'ョ', 'ヨ', 'ラ', 'リ', 'ル', 118 | 'レ', 'ロ', 'ワ', 'ヲ', 'ン', 'ー', 119 | 'ヮ', 'ヰ', 'ヱ', 'ヵ', 'ヶ', 'ヴ', 'ヽ', 'ヾ', '・', 120 | '「', '」', '。', '、'] 121 | FULL_ASCII = list('!"#$%&'()*+,-./:;<=>?@' 122 | 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 123 | '[\]^_`abcdefghijklmnopqrst' 124 | 'uvwxyz{|}~ ') 125 | FULL_DIGIT = list('0123456789') 126 | FULL_KANA = list('ァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソ' 127 | 'ゾタダチヂッツヅテデトドナニヌネノハバパヒビピフブプヘベペ' 128 | 'ホボポマミムメモャヤュユョヨラリルレロワヲンーヮヰヱヵヶヴ' 129 | 'ヽヾ・「」。、') 130 | FULL_KANA_SEION = list('ァアィイゥウェエォオカキクケコサシスセソタチッツテト' 131 | 'ナニヌネノハヒフヘホマミムメモャヤュユョヨラリルレロ' 132 | 'ワヲンーヮヰヱヵヶヽヾ・「」。、') 133 | HEPBURN = list('aiueoaiueon') 134 | HEPBURN_KANA = list('ぁぃぅぇぉあいうえおん') 135 | 136 | 137 | def _to_ord_list(chars): 138 | return list(map(ord, chars)) 139 | 140 | HIRAGANA_ORD = _to_ord_list(HIRAGANA) 141 | FULL_KANA_ORD = _to_ord_list(FULL_KANA) 142 | HALF_ASCII_ORD = _to_ord_list(HALF_ASCII) 143 | FULL_ASCII_ORD = _to_ord_list(FULL_ASCII) 144 | HALF_DIGIT_ORD = _to_ord_list(HALF_DIGIT) 145 | FULL_DIGIT_ORD = _to_ord_list(FULL_DIGIT) 146 | HALF_KANA_SEION_ORD = _to_ord_list(HALF_KANA_SEION) 147 | FULL_KANA_SEION_ORD = _to_ord_list(FULL_KANA_SEION) 148 | 149 | 150 | def _to_dict(_from, _to): 151 | return dict(zip(_from, _to)) 152 | 153 | 154 | H2K_TABLE = _to_dict(HIRAGANA_ORD, FULL_KANA) 155 | H2HK_TABLE = _to_dict(HIRAGANA_ORD, HALF_KANA) 156 | K2H_TABLE = _to_dict(FULL_KANA_ORD, HIRAGANA) 157 | 158 | H2Z_A = _to_dict(HALF_ASCII_ORD, FULL_ASCII) 159 | H2Z_AD = _to_dict(HALF_ASCII_ORD+HALF_DIGIT_ORD, FULL_ASCII+FULL_DIGIT) 160 | H2Z_AK = _to_dict(HALF_ASCII_ORD+HALF_KANA_SEION_ORD, 161 | FULL_ASCII+FULL_KANA_SEION) 162 | H2Z_D = _to_dict(HALF_DIGIT_ORD, FULL_DIGIT) 163 | H2Z_K = _to_dict(HALF_KANA_SEION_ORD, FULL_KANA_SEION) 164 | H2Z_DK = _to_dict(HALF_DIGIT_ORD+HALF_KANA_SEION_ORD, 165 | FULL_DIGIT+FULL_KANA_SEION) 166 | H2Z_ALL = _to_dict(HALF_ASCII_ORD+HALF_DIGIT_ORD+HALF_KANA_SEION_ORD, 167 | FULL_ASCII+FULL_DIGIT+FULL_KANA_SEION) 168 | 169 | Z2H_A = _to_dict(FULL_ASCII_ORD, HALF_ASCII) 170 | Z2H_AD = _to_dict(FULL_ASCII_ORD+FULL_DIGIT_ORD, HALF_ASCII+HALF_DIGIT) 171 | Z2H_AK = _to_dict(FULL_ASCII_ORD+FULL_KANA_ORD, HALF_ASCII+HALF_KANA) 172 | Z2H_D = _to_dict(FULL_DIGIT_ORD, HALF_DIGIT) 173 | Z2H_K = _to_dict(FULL_KANA_ORD, HALF_KANA) 174 | Z2H_DK = _to_dict(FULL_DIGIT_ORD+FULL_KANA_ORD, HALF_DIGIT+HALF_KANA) 175 | Z2H_ALL = _to_dict(FULL_ASCII_ORD+FULL_DIGIT_ORD+FULL_KANA_ORD, 176 | HALF_ASCII+HALF_DIGIT+HALF_KANA) 177 | KANA2HEP = _to_dict(_to_ord_list(HEPBURN_KANA), HEPBURN) 178 | 
-------------------------------------------------------------------------------- /transphone/lang/jpn/jaconv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import unicodedata 3 | from .conv_table import (H2K_TABLE, H2HK_TABLE, K2H_TABLE, H2Z_A, H2Z_AD, 4 | H2Z_AK, H2Z_D, H2Z_K, H2Z_DK, H2Z_ALL, Z2H_A, Z2H_AD, 5 | Z2H_AK, Z2H_D, Z2H_K, Z2H_DK, Z2H_ALL, KANA2HEP, HEP2KANA) 6 | 7 | consonants = frozenset('sdfghjklqwrtypzxcvbnm') 8 | 9 | def _exclude_ignorechar(ignore, conv_map): 10 | for character in map(ord, ignore): 11 | conv_map[character] = character 12 | return conv_map 13 | 14 | 15 | def _convert(text, conv_map): 16 | return text.translate(conv_map) 17 | 18 | 19 | def hira2kata(text, ignore=''): 20 | """Convert Hiragana to Full-width (Zenkaku) Katakana. 21 | Parameters 22 | ---------- 23 | text : str 24 | Hiragana string. 25 | ignore : str 26 | Characters to be ignored in converting. 27 | Return 28 | ------ 29 | str 30 | Katakana string. 31 | Examples 32 | -------- 33 | >>> print(jaconv.hira2kata('ともえまみ')) 34 | トモエマミ 35 | >>> print(jaconv.hira2kata('まどまぎ', ignore='ど')) 36 | マどマギ 37 | """ 38 | if ignore: 39 | h2k_map = _exclude_ignorechar(ignore, H2K_TABLE.copy()) 40 | return _convert(text, h2k_map) 41 | return _convert(text, H2K_TABLE) 42 | 43 | 44 | def hira2hkata(text, ignore=''): 45 | """Convert Hiragana to Half-width (Hankaku) Katakana 46 | Parameters 47 | ---------- 48 | text : str 49 | Hiragana string. 50 | ignore : str 51 | Characters to be ignored in converting. 52 | Return 53 | ------ 54 | str 55 | Half-width Katakana string. 56 | Examples 57 | -------- 58 | >>> print(jaconv.hira2hkata('ともえまみ')) 59 | トモエマミ 60 | >>> print(jaconv.hira2hkata('ともえまみ', ignore='み')) 61 | トモエマみ 62 | """ 63 | if ignore: 64 | h2hk_map = _exclude_ignorechar(ignore, H2HK_TABLE.copy()) 65 | return _convert(text, h2hk_map) 66 | return _convert(text, H2HK_TABLE) 67 | 68 | 69 | def kata2hira(text, ignore=''): 70 | """Convert Full-width Katakana to Hiragana 71 | Parameters 72 | ---------- 73 | text : str 74 | Full-width Katakana string. 75 | ignore : str 76 | Characters to be ignored in converting. 77 | Return 78 | ------ 79 | str 80 | Hiragana string. 81 | Examples 82 | -------- 83 | >>> print(jaconv.kata2hira('巴マミ')) 84 | 巴まみ 85 | >>> print(jaconv.kata2hira('マミサン', ignore='ン')) 86 | まみさン 87 | """ 88 | if ignore: 89 | k2h_map = _exclude_ignorechar(ignore, K2H_TABLE.copy()) 90 | return _convert(text, k2h_map) 91 | return _convert(text, K2H_TABLE) 92 | 93 | 94 | def h2z(text, ignore='', kana=True, ascii=False, digit=False): 95 | """Convert Half-width (Hankaku) Katakana to Full-width (Zenkaku) Katakana 96 | Parameters 97 | ---------- 98 | text : str 99 | Half-width Katakana string. 100 | ignore : str 101 | Characters to be ignored in converting. 102 | kana : bool 103 | Either converting Kana or not. 104 | ascii : bool 105 | Either converting ascii or not. 106 | digit : bool 107 | Either converting digit or not.
108 | Return 109 | ------ 110 | str 111 | Full-width Katakana string. 112 | Examples 113 | -------- 114 | >>> print(jaconv.h2z('ティロフィナーレ')) 115 | ティロフィナーレ 116 | >>> print(jaconv.h2z('ティロフィナーレ', ignore='ィ')) 117 | ティロフィナーレ 118 | >>> print(jaconv.h2z('abcd', ascii=True)) 119 | ABCD 120 | >>> print(jaconv.h2z('1234', digit=True)) 121 | 1234 122 | """ 123 | def _conv_dakuten(text): 124 | """Convert Hankaku Dakuten Kana to Zenkaku Dakuten Kana 125 | """ 126 | text = text.replace("ガ", "ガ").replace("ギ", "ギ") 127 | text = text.replace("グ", "グ").replace("ゲ", "ゲ") 128 | text = text.replace("ゴ", "ゴ").replace("ザ", "ザ") 129 | text = text.replace("ジ", "ジ").replace("ズ", "ズ") 130 | text = text.replace("ゼ", "ゼ").replace("ゾ", "ゾ") 131 | text = text.replace("ダ", "ダ").replace("ヂ", "ヂ") 132 | text = text.replace("ヅ", "ヅ").replace("デ", "デ") 133 | text = text.replace("ド", "ド").replace("バ", "バ") 134 | text = text.replace("ビ", "ビ").replace("ブ", "ブ") 135 | text = text.replace("ベ", "ベ").replace("ボ", "ボ") 136 | text = text.replace("パ", "パ").replace("ピ", "ピ") 137 | text = text.replace("プ", "プ").replace("ペ", "ペ") 138 | return text.replace("ポ", "ポ").replace("ヴ", "ヴ") 139 | 140 | if ascii: 141 | if digit: 142 | if kana: 143 | h2z_map = H2Z_ALL 144 | else: 145 | h2z_map = H2Z_AD 146 | elif kana: 147 | h2z_map = H2Z_AK 148 | else: 149 | h2z_map = H2Z_A 150 | elif digit: 151 | if kana: 152 | h2z_map = H2Z_DK 153 | else: 154 | h2z_map = H2Z_D 155 | else: 156 | h2z_map = H2Z_K 157 | if kana: 158 | text = _conv_dakuten(text) 159 | if ignore: 160 | h2z_map = _exclude_ignorechar(ignore, h2z_map.copy()) 161 | return _convert(text, h2z_map) 162 | 163 | 164 | def z2h(text, ignore='', kana=True, ascii=False, digit=False): 165 | """Convert Full-width (Zenkaku) Katakana to Half-width (Hankaku) Katakana 166 | Parameters 167 | ---------- 168 | text : str 169 | Full-width Katakana string. 170 | ignore : str 171 | Characters to be ignored in converting. 172 | kana : bool 173 | Either converting Kana or not. 174 | ascii : bool 175 | Either converting ascii or not. 176 | digit : bool 177 | Either converting digit or not. 178 | Return 179 | ------ 180 | str 181 | Half-width Katakana string. 182 | Examples 183 | -------- 184 | >>> print(jaconv.z2h('ティロフィナーレ')) 185 | ティロフィナーレ 186 | >>> print(jaconv.z2h('ティロフィナーレ', ignore='ィ')) 187 | ティロフィナーレ 188 | >>> print(jaconv.z2h('ABCD', ascii=True)) 189 | abcd 190 | >>> print(jaconv.z2h('1234', digit=True)) 191 | 1234 192 | """ 193 | if ascii: 194 | if digit: 195 | if kana: 196 | z2h_map = Z2H_ALL 197 | else: 198 | z2h_map = Z2H_AD 199 | elif kana: 200 | z2h_map = Z2H_AK 201 | else: 202 | z2h_map = Z2H_A 203 | elif digit: 204 | if kana: 205 | z2h_map = Z2H_DK 206 | else: 207 | z2h_map = Z2H_D 208 | else: 209 | z2h_map = Z2H_K 210 | if ignore: 211 | z2h_map = _exclude_ignorechar(ignore, z2h_map.copy()) 212 | return _convert(text, z2h_map) 213 | 214 | 215 | def normalize(text, mode='NFKC'): 216 | """Convert Half-width (Hankaku) Katakana to Full-width (Zenkaku) Katakana, 217 | Full-width (Zenkaku) ASCII and DIGIT to Half-width (Hankaku) ASCII 218 | and DIGIT. 219 | Additionally, Full-width wave dash (〜) etc. are normalized 220 | Parameters 221 | ---------- 222 | text : str 223 | Source string. 224 | mode : str 225 | Unicode normalization mode. 226 | Return 227 | ------ 228 | str 229 | Normalized string. 
230 | Examples 231 | -------- 232 | >>> print(jaconv.normalize('ティロ・フィナ〜レ', 'NFKC')) 233 | ティロ・フィナーレ 234 | """ 235 | text = text.replace('〜', 'ー').replace('~', 'ー') 236 | text = text.replace("’", "'").replace('”', '"').replace('“', '``') 237 | text = text.replace('―', '-').replace('‐', '-').replace('˗', '-').replace('֊', '-') 238 | text = text.replace('‐', '-').replace('‑', '-').replace('‒', '-').replace('–', '-') 239 | text = text.replace('⁃', '-').replace('⁻', '-').replace('₋', '-').replace('−', '-') 240 | text = text.replace('﹣', 'ー').replace('-', 'ー').replace('—', 'ー').replace('―', 'ー') 241 | text = text.replace('━', 'ー').replace('─', 'ー') 242 | return unicodedata.normalize(mode, text) 243 | 244 | 245 | def kana2alphabet(text): 246 | """Convert Hiragana to hepburn-style alphabets 247 | Parameters 248 | ---------- 249 | text : str 250 | Hiragana string. 251 | Return 252 | ------ 253 | str 254 | Hepburn-style alphabets string. 255 | Examples 256 | -------- 257 | >>> print(jaconv.kana2alphabet('まみさん')) 258 | mamisan 259 | """ 260 | text = text.replace('きゃ', 'kya').replace('きゅ', 'kyu').replace('きょ', 'kyo') 261 | text = text.replace('ぎゃ', 'gya').replace('ぎゅ', 'gyu').replace('ぎょ', 'gyo') 262 | text = text.replace('しゃ', 'sha').replace('しゅ', 'shu').replace('しょ', 'sho') 263 | text = text.replace('じゃ', 'ja').replace('じゅ', 'ju').replace('じょ', 'jo') 264 | text = text.replace('ちゃ', 'cha').replace('ちゅ', 'chu').replace('ちょ', 'cho') 265 | text = text.replace('にゃ', 'nya').replace('にゅ', 'nyu').replace('にょ', 'nyo') 266 | text = text.replace('ふぁ', 'fa').replace('ふぃ', 'fi').replace('ふぇ', 'fe') 267 | text = text.replace('ふぉ', 'fo') 268 | text = text.replace('ひゃ', 'hya').replace('ひゅ', 'hyu').replace('ひょ', 'hyo') 269 | text = text.replace('みゃ', 'mya').replace('みゅ', 'myu').replace('みょ', 'myo') 270 | text = text.replace('りゃ', 'rya').replace('りゅ', 'ryu').replace('りょ', 'ryo') 271 | text = text.replace('びゃ', 'bya').replace('びゅ', 'byu').replace('びょ', 'byo') 272 | text = text.replace('ぴゃ', 'pya').replace('ぴゅ', 'pyu').replace('ぴょ', 'pyo') 273 | text = text.replace('が', 'ga').replace('ぎ', 'gi').replace('ぐ', 'gu') 274 | text = text.replace('げ', 'ge').replace('ご', 'go').replace('ざ', 'za') 275 | text = text.replace('じ', 'ji').replace('ず', 'zu').replace('ぜ', 'ze') 276 | text = text.replace('ぞ', 'zo').replace('だ', 'da').replace('ぢ', 'ji') 277 | text = text.replace('づ', 'zu').replace('で', 'de').replace('ど', 'do') 278 | text = text.replace('ば', 'ba').replace('び', 'bi').replace('ぶ', 'bu') 279 | text = text.replace('べ', 'be').replace('ぼ', 'bo').replace('ぱ', 'pa') 280 | text = text.replace('ぴ', 'pi').replace('ぷ', 'pu').replace('ぺ', 'pe') 281 | text = text.replace('ぽ', 'po') 282 | text = text.replace('か', 'ka').replace('き', 'ki').replace('く', 'ku') 283 | text = text.replace('け', 'ke').replace('こ', 'ko').replace('さ', 'sa') 284 | text = text.replace('し', 'shi').replace('す', 'su').replace('せ', 'se') 285 | text = text.replace('そ', 'so').replace('た', 'ta').replace('ち', 'chi') 286 | text = text.replace('つ', 'tsu').replace('て', 'te').replace('と', 'to') 287 | text = text.replace('な', 'na').replace('に', 'ni').replace('ぬ', 'nu') 288 | text = text.replace('ね', 'ne').replace('の', 'no').replace('は', 'ha') 289 | text = text.replace('ひ', 'hi').replace('ふ', 'fu').replace('へ', 'he') 290 | text = text.replace('ほ', 'ho').replace('ま', 'ma').replace('み', 'mi') 291 | text = text.replace('む', 'mu').replace('め', 'me').replace('も', 'mo') 292 | text = text.replace('ら', 'ra').replace('り', 'ri').replace('る', 'ru') 293 | text = text.replace('れ', 
're').replace('ろ', 'ro') 294 | text = text.replace('や', 'ya').replace('ゆ', 'yu').replace('よ', 'yo') 295 | text = text.replace('わ', 'wa').replace('ゐ', 'wi').replace('を', 'wo') 296 | text = text.replace('ゑ', 'we') 297 | text = _convert(text, KANA2HEP) 298 | while 'っ' in text: 299 | text = list(text) 300 | tsu_pos = text.index('っ') 301 | if len(text) <= tsu_pos + 1: 302 | return ''.join(text[:-1]) + 'xtsu' 303 | if tsu_pos == 0: 304 | text[tsu_pos] = 'xtsu' 305 | else: 306 | text[tsu_pos] = text[tsu_pos + 1] 307 | text = ''.join(text) 308 | return text 309 | 310 | 311 | def alphabet2kana(text): 312 | """Convert alphabets to Hiragana 313 | Parameters 314 | ---------- 315 | text : str 316 | Alphabets string. 317 | Return 318 | ------ 319 | str 320 | Hiragana string. 321 | Examples 322 | -------- 323 | >>> print(jaconv.alphabet2kana('mamisan')) 324 | まみさん 325 | """ 326 | text = text.replace('kya', 'きゃ').replace('kyu', 'きゅ').replace('kyo', 'きょ') 327 | text = text.replace('gya', 'ぎゃ').replace('gyu', 'ぎゅ').replace('gyo', 'ぎょ') 328 | text = text.replace('sha', 'しゃ').replace('shu', 'しゅ').replace('sho', 'しょ') 329 | text = text.replace('zya', 'じゃ').replace('zyu', 'じゅ').replace('zyo', 'じょ') 330 | text = text.replace('zyi', 'じぃ').replace('zye', 'じぇ') 331 | text = text.replace('ja', 'じゃ').replace('ju', 'じゅ').replace('jo', 'じょ') 332 | text = text.replace('jya', 'じゃ').replace('jyu', 'じゅ').replace('jyo', 'じょ') 333 | text = text.replace('cha', 'ちゃ').replace('chu', 'ちゅ').replace('cho', 'ちょ') 334 | text = text.replace('tya', 'ちゃ').replace('tyu', 'ちゅ').replace('tyo', 'ちょ') 335 | text = text.replace('nya', 'にゃ').replace('nyu', 'にゅ').replace('nyo', 'にょ') 336 | text = text.replace('hya', 'ひゃ').replace('hyu', 'ひゅ').replace('hyo', 'ひょ') 337 | text = text.replace('mya', 'みゃ').replace('myu', 'みゅ').replace('myo', 'みょ') 338 | text = text.replace('rya', 'りゃ').replace('ryu', 'りゅ').replace('ryo', 'りょ') 339 | text = text.replace('bya', 'びゃ').replace('byu', 'びゅ').replace('byo', 'びょ') 340 | text = text.replace('pya', 'ぴゃ').replace('pyu', 'ぴゅ').replace('pyo', 'ぴょ') 341 | text = text.replace('oh', 'おお') 342 | text = text.replace('ga', 'が').replace('gi', 'ぎ').replace('gu', 'ぐ') 343 | text = text.replace('ge', 'げ').replace('go', 'ご').replace('za', 'ざ') 344 | text = text.replace('ji', 'じ').replace('zu', 'ず').replace('ze', 'ぜ') 345 | text = text.replace('zo', 'ぞ').replace('da', 'だ').replace('ji', 'ぢ').replace('di', 'ぢ') 346 | text = text.replace('va', 'ゔぁ').replace('vi', 'ゔぃ').replace('vu', 'ゔ') 347 | text = text.replace('ve', 'ゔぇ').replace('vo', 'ゔぉ').replace('vya', 'ゔゃ') 348 | text = text.replace('vyi', 'ゔぃ').replace('vyu', 'ゔゅ').replace('vye', 'ゔぇ') 349 | text = text.replace('vyo', 'ゔょ') 350 | text = text.replace('zu', 'づ').replace('de', 'で').replace('do', 'ど') 351 | text = text.replace('ba', 'ば').replace('bi', 'び').replace('bu', 'ぶ') 352 | text = text.replace('be', 'べ').replace('bo', 'ぼ').replace('pa', 'ぱ') 353 | text = text.replace('pi', 'ぴ').replace('pu', 'ぷ').replace('pe', 'ぺ') 354 | text = text.replace('po', 'ぽ').replace('dha', 'でゃ').replace('dhi', 'でぃ') 355 | text = text.replace('dhu', 'でゅ').replace('dhe', 'でぇ').replace('dho', 'でょ') 356 | text = text.replace('ka', 'か').replace('ki', 'き').replace('ku', 'く') 357 | text = text.replace('ke', 'け').replace('ko', 'こ').replace('sa', 'さ') 358 | text = text.replace('shi', 'し').replace('su', 'す').replace('se', 'せ') 359 | text = text.replace('so', 'そ').replace('ta', 'た').replace('chi', 'ち') 360 | text = text.replace('tsu', 'つ').replace('te', 'て').replace('to', 'と') 361 | text = 
text.replace('na', 'な').replace('ni', 'に').replace('nu', 'ぬ') 362 | text = text.replace('ne', 'ね').replace('no', 'の').replace('ha', 'は') 363 | text = text.replace('hi', 'ひ').replace('fu', 'ふ').replace('he', 'へ') 364 | text = text.replace('ho', 'ほ').replace('ma', 'ま').replace('mi', 'み') 365 | text = text.replace('mu', 'む').replace('me', 'め').replace('mo', 'も') 366 | text = text.replace('ra', 'ら').replace('ri', 'り').replace('ru', 'る') 367 | text = text.replace('re', 'れ').replace('ro', 'ろ') 368 | text = text.replace('ya', 'や').replace('yu', 'ゆ').replace('yo', 'よ') 369 | text = text.replace('wa', 'わ').replace('wi', 'ゐ').replace('we', 'ゑ') 370 | text = text.replace('wo', 'を') 371 | text = text.replace('nn', 'ん').replace('tu', 'つ').replace('hu', 'ふ') 372 | text = text.replace('fa', 'ふぁ').replace('fi', 'ふぃ').replace('fe', 'ふぇ') 373 | text = text.replace('fo', 'ふぉ').replace('-', 'ー') 374 | text = _convert(text, HEP2KANA) 375 | ret = [] 376 | for (i, char) in enumerate(text): 377 | if char in consonants: 378 | char = 'っ' 379 | ret.append(char) 380 | return ''.join(ret) 381 | -------------------------------------------------------------------------------- /transphone/lang/jpn/kana2phoneme.py: -------------------------------------------------------------------------------- 1 | from .conv_table import FULL_KANA 2 | from phonepiece.inventory import read_inventory 3 | from transphone.config import TransphoneConfig 4 | 5 | _kana2phonemes = { 6 | 'ア': 'a', 7 | 'イ': 'i', 8 | 'ウ': 'u', 9 | 'エ': 'e', 10 | 'オ': 'o', 11 | 'カ': 'k a', 12 | 'キ': 'k i', 13 | 'ク': 'k u', 14 | 'ケ': 'k e', 15 | 'コ': 'k o', 16 | 'ガ': 'g a', 17 | 'ギ': 'g i', 18 | 'グ': 'g u', 19 | 'ゲ': 'g e', 20 | 'ゴ': 'g o', 21 | 'サ': 's a', 22 | 'シ': 'sh i', 23 | 'ス': 's u', 24 | 'セ': 's e', 25 | 'ソ': 's o', 26 | 'ザ': 'z a', 27 | 'ジ': 'j i', 28 | 'ズ': 'z u', 29 | 'ゼ': 'z e', 30 | 'ゾ': 'z o', 31 | 'タ': 't a', 32 | 'チ': 'ch i', 33 | 'ツ': 't͡s u', 34 | 'テ': 't e', 35 | 'ト': 't o', 36 | 'ダ': 'd a', 37 | 'ヂ': 'j i', 38 | 'ヅ': 'z u', 39 | 'デ': 'd e', 40 | 'ド': 'd o', 41 | 'ナ': 'n a', 42 | 'ニ': 'n i', 43 | 'ヌ': 'n u', 44 | 'ネ': 'n e', 45 | 'ノ': 'n o', 46 | 'ハ': 'h a', 47 | 'ヒ': 'h i', 48 | 'フ': 'f u', 49 | 'ヘ': 'h e', 50 | 'ホ': 'h o', 51 | 'バ': 'b a', 52 | 'ビ': 'b i', 53 | 'ブ': 'b u', 54 | 'ベ': 'b e', 55 | 'ボ': 'b o', 56 | 'パ': 'p a', 57 | 'ピ': 'p i', 58 | 'プ': 'p u', 59 | 'ペ': 'p e', 60 | 'ポ': 'p o', 61 | 'マ': 'm a', 62 | 'ミ': 'm i', 63 | 'ム': 'm u', 64 | 'メ': 'm e', 65 | 'モ': 'm o', 66 | 'ラ': 'r a', 67 | 'リ': 'r i', 68 | 'ル': 'r u', 69 | 'レ': 'r e', 70 | 'ロ': 'r o', 71 | 'ワ': 'w a', 72 | 'ヲ': 'o', 73 | 'ヤ': 'y a', 74 | 'ユ': 'y u', 75 | 'ヨ': 'y o', 76 | 'キャ': 'ky a', 77 | 'キュ': 'ky u', 78 | 'キョ': 'ky o', 79 | 'ギャ': 'gy a', 80 | 'ギュ': 'gy u', 81 | 'ギョ': 'gy o', 82 | 'シャ': 'sh a', 83 | 'シュ': 'sh u', 84 | 'ショ': 'sh o', 85 | 'ジャ': 'j a', 86 | 'ジュ': 'j u', 87 | 'ジョ': 'j o', 88 | 'チャ': 'ch a', 89 | 'チュ': 'ch u', 90 | 'チョ': 'ch o', 91 | 'ニャ': 'ny a', 92 | 'ニュ': 'ny u', 93 | 'ニョ': 'ny o', 94 | 'ヒャ': 'hy a', 95 | 'ヒュ': 'hy u', 96 | 'ヒョ': 'hy o', 97 | 'ビャ': 'by a', 98 | 'ビュ': 'by u', 99 | 'ビョ': 'by o', 100 | 'ピャ': 'py a', 101 | 'ピュ': 'py u', 102 | 'ピョ': 'py o', 103 | 'ミャ': 'my a', 104 | 'ミュ': 'my u', 105 | 'ミョ': 'my o', 106 | 'リャ': 'ry a', 107 | 'リュ': 'ry u', 108 | 'リョ': 'ry o', 109 | 'イェ': 'i e', 110 | 'シェ': 'sh e', 111 | 'ジェ': 'j e', 112 | 'ティ': 't i', 113 | 'トゥ': 't u', 114 | 'チェ': 'ch e', 115 | 'ツァ': 't͡s a', 116 | 'ツィ': 't͡s i', 117 | 'ツェ': 't͡s e', 118 | 'ツォ': 't͡s o', 119 | 'ディ': 'd i', 120 | 'ドゥ': 'd u', 121 | 'デュ': 'd u', 122 | 'ニェ': 'n i e', 123 | 'ヒェ': 'h 
e', 124 | 'ファ': 'f a', 125 | 'フィ': 'f i', 126 | 'フェ': 'f e', 127 | 'フォ': 'f o', 128 | 'フュ': 'hy u', 129 | 'ブィ': 'b i', 130 | 'ミェ': 'm e', 131 | 'ウィ': 'w i', 132 | 'ウェ': 'w e', 133 | 'ウォ': 'w o', 134 | 'クヮ': 'k a', 135 | 'グヮ': 'g a', 136 | 'スィ': 's u i', 137 | 'ズィ': 'j i', 138 | 'テュ': 't e y u', 139 | 'ヴァ': 'b a', 140 | 'ヴィ': 'b i', 141 | 'ヴ': 'b u', 142 | 'ヴェ': 'b e', 143 | 'ヴォ': 'b o', 144 | 'ン': 'N', 145 | 'ッ': 'q', 146 | 'ー': 'ː' 147 | } 148 | 149 | import re 150 | 151 | class Kana2Phoneme: 152 | def __init__(self): 153 | 154 | self.phoneme_set = set(read_inventory('jpn').phoneme.elems) 155 | 156 | self._dict1 = { 157 | 'キャ': 'ky a ', 158 | 'キュ': 'ky u ', 159 | 'キョ': 'ky o ', 160 | 'ギャ': 'gy a ', 161 | 'ギュ': 'gy u ', 162 | 'ギョ': 'gy o ', 163 | 'シャ': 'sh a ', 164 | 'シュ': 'sh u ', 165 | 'ショ': 'sh o ', 166 | 'ジャ': 'j a ', 167 | 'ジュ': 'j u ', 168 | 'ジョ': 'j o ', 169 | 'チャ': 'ch a ', 170 | 'チュ': 'ch u ', 171 | 'チョ': 'ch o ', 172 | 'ニャ': 'ny a ', 173 | 'ニュ': 'ny u ', 174 | 'ニョ': 'ny o ', 175 | 'ヒャ': 'hy a ', 176 | 'ヒュ': 'hy u ', 177 | 'ヒョ': 'hy o ', 178 | 'ビャ': 'by a ', 179 | 'ビュ': 'by u ', 180 | 'ビョ': 'by o ', 181 | 'ピャ': 'py a ', 182 | 'ピュ': 'py u ', 183 | 'ピョ': 'py o ', 184 | 'ミャ': 'my a ', 185 | 'ミュ': 'my u ', 186 | 'ミョ': 'my o ', 187 | 'リャ': 'ry a ', 188 | 'リュ': 'ry u ', 189 | 'リョ': 'ry o ', 190 | 'イェ': 'i e ', 191 | 'シェ': 'sh e ', 192 | 'ジェ': 'j e ', 193 | 'ティ': 't i ', 194 | 'トゥ': 't u ', 195 | 'チェ': 'ch e ', 196 | 'ツァ': 't͡s a ', 197 | 'ツィ': 't͡s i ', 198 | 'ツェ': 't͡s e ', 199 | 'ツォ': 't͡s o ', 200 | 'ディ': 'd i ', 201 | 'ドゥ': 'd u ', 202 | 'デュ': 'd u ', 203 | 'ニェ': 'n i e ', 204 | 'ヒェ': 'h e ', 205 | 'ファ': 'f a ', 206 | 'フィ': 'f i ', 207 | 'フェ': 'f e ', 208 | 'フォ': 'f o ', 209 | 'フュ': 'hy u ', 210 | 'ブィ': 'b i ', 211 | 'ミェ': 'm e ', 212 | 'ウィ': 'w i ', 213 | 'ウェ': 'w e ', 214 | 'ウォ': 'w o ', 215 | 'クヮ': 'k a ', 216 | 'グヮ': 'g a ', 217 | 'スィ': 's u i ', 218 | 'ズィ': 'j i ', 219 | 'テュ': 't e y u ', 220 | 'ヴァ': 'b a ', 221 | 'ヴィ': 'b i ', 222 | 'ヴ': 'b u ', 223 | 'ヴェ': 'b e ', 224 | 'ヴォ': 'b o ', 225 | } 226 | self._dict2 = { 227 | 'ア': 'a ', 228 | 'イ': 'i ', 229 | 'ウ': 'u ', 230 | 'エ': 'e ', 231 | 'オ': 'o ', 232 | 'カ': 'k a ', 233 | 'キ': 'k i ', 234 | 'ク': 'k u ', 235 | 'ケ': 'k e ', 236 | 'コ': 'k o ', 237 | 'ガ': 'g a ', 238 | 'ギ': 'g i ', 239 | 'グ': 'g u ', 240 | 'ゲ': 'g e ', 241 | 'ゴ': 'g o ', 242 | 'サ': 's a ', 243 | 'シ': 'sh i ', 244 | 'ス': 's u ', 245 | 'セ': 's e ', 246 | 'ソ': 's o ', 247 | 'ザ': 'z a ', 248 | 'ジ': 'j i ', 249 | 'ズ': 'z u ', 250 | 'ゼ': 'z e ', 251 | 'ゾ': 'z o ', 252 | 'タ': 't a ', 253 | 'チ': 'ch i ', 254 | 'ツ': 't͡s u ', 255 | 'テ': 't e ', 256 | 'ト': 't o ', 257 | 'ダ': 'd a ', 258 | 'ヂ': 'j i ', 259 | 'ヅ': 'z u ', 260 | 'デ': 'd e ', 261 | 'ド': 'd o ', 262 | 'ナ': 'n a ', 263 | 'ニ': 'n i ', 264 | 'ヌ': 'n u ', 265 | 'ネ': 'n e ', 266 | 'ノ': 'n o ', 267 | 'ハ': 'h a ', 268 | 'ヒ': 'h i ', 269 | 'フ': 'f u ', 270 | 'ヘ': 'h e ', 271 | 'ホ': 'h o ', 272 | 'バ': 'b a ', 273 | 'ビ': 'b i ', 274 | 'ブ': 'b u ', 275 | 'ベ': 'b e ', 276 | 'ボ': 'b o ', 277 | 'パ': 'p a ', 278 | 'ピ': 'p i ', 279 | 'プ': 'p u ', 280 | 'ペ': 'p e ', 281 | 'ポ': 'p o ', 282 | 'マ': 'm a ', 283 | 'ミ': 'm i ', 284 | 'ム': 'm u ', 285 | 'メ': 'm e ', 286 | 'モ': 'm o ', 287 | 'ラ': 'r a ', 288 | 'リ': 'r i ', 289 | 'ル': 'r u ', 290 | 'レ': 'r e ', 291 | 'ロ': 'r o ', 292 | 'ワ': 'w a ', 293 | 'ヲ': 'o ', 294 | 'ヤ': 'y a ', 295 | 'ユ': 'y u ', 296 | 'ヨ': 'y o ', 297 | 'ン': 'N ', 298 | 'ッ': 'q ', 299 | 'ー': 'ː ', 300 | 'ァ': 'a ', 301 | 'ィ': 'i ', 302 | 'ゥ': 'u ', 303 | 'ェ': 'e ', 304 | 'ォ': 'o ', 305 | 'ャ': 'y a ', 306 | 'ュ': 'y 
u ', 307 | 'ョ': 'y o ', 308 | 'ヮ': 'w a ', 309 | 'ヵ': 'k a ', 310 | 'ヶ': 'k e ', 311 | 'ヰ': 'i ', 312 | 'ヱ': 'e ', 313 | 'ヴ': 'b u ', 314 | 'ヽ': '', # this should not reach here though 315 | 'ヾ': '', # this should not reach here though 316 | 'ヷ': 'w a ', 317 | 'ヸ': 'i ', 318 | 'ヹ': 'e ', 319 | 'ヺ': 'o ', 320 | 'ヿ': '', # this should not reach here though 321 | '゛': '', # this should not reach here though 322 | '゜': '', # this should not reach here though 323 | 'ヽ': '', # this should not reach here though 324 | 'ヾ': '', # this should not reach here though 325 | 'ゝ': '', # this should not reach here though 326 | 'ゞ': '', # this should not reach here though 327 | '〆': '', # this should not reach here though 328 | '々': '', # this should not reach here though 329 | } 330 | self._regex1 = re.compile(u"(%s)" % u"|".join(map(re.escape, self._dict1.keys()))) 331 | self._regex2 = re.compile(u"(%s)" % u"|".join(map(re.escape, self._dict2.keys()))) 332 | 333 | def validate(self, text): 334 | for word in text.strip(): 335 | if word not in FULL_KANA: 336 | return False 337 | return True 338 | 339 | def convert(self, origin_text): 340 | 341 | if isinstance(origin_text, list): 342 | text = ' '.join(origin_text) 343 | else: 344 | text = origin_text 345 | 346 | ret = text 347 | ret = self._regex1.sub(lambda m: self._dict1[m.string[m.start():m.end()]], ret) 348 | ret = self._regex2.sub(lambda m: self._dict2[m.string[m.start():m.end()]], ret) 349 | 350 | temp_phonemes = ret.split() 351 | 352 | phonemes = [] 353 | for temp_phoneme in temp_phonemes: 354 | if temp_phoneme == 'ː': 355 | if len(phonemes) > 0 and phonemes[-1] in ['a', 'i', 'u', 'e', 'o']: 356 | phonemes[-1] = phonemes[-1]+'ː' 357 | continue 358 | 359 | if temp_phoneme in self.phoneme_set: 360 | phonemes.append(temp_phoneme) 361 | else: 362 | TransphoneConfig.logger.error("Unknown phoneme: %s" % temp_phoneme) 363 | 364 | return phonemes 365 | -------------------------------------------------------------------------------- /transphone/lang/jpn/normalizer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf8 2 | """ 3 | from: https://github.com/neologd/mecab-ipadic-neologd/wiki/Regexp.ja 4 | """ 5 | from __future__ import unicode_literals 6 | import re 7 | import unicodedata 8 | 9 | def unicode_normalize(cls, s): 10 | pt = re.compile('([{}]+)'.format(cls)) 11 | 12 | def norm(c): 13 | return unicodedata.normalize('NFKC', c) if pt.match(c) else c 14 | 15 | s = ''.join(norm(x) for x in re.split(pt, s)) 16 | s = re.sub('-', '-', s) 17 | return s 18 | 19 | def remove_extra_spaces(s): 20 | s = re.sub('[  ]+', ' ', s) 21 | blocks = ''.join(('\u4E00-\u9FFF', # CJK UNIFIED IDEOGRAPHS 22 | '\u3040-\u309F', # HIRAGANA 23 | '\u30A0-\u30FF', # KATAKANA 24 | '\u3000-\u303F', # CJK SYMBOLS AND PUNCTUATION 25 | '\uFF00-\uFFEF' # HALFWIDTH AND FULLWIDTH FORMS 26 | )) 27 | basic_latin = '\u0000-\u007F' 28 | 29 | def remove_space_between(cls1, cls2, s): 30 | p = re.compile('([{}]) ([{}])'.format(cls1, cls2)) 31 | while p.search(s): 32 | s = p.sub(r'\1\2', s) 33 | return s 34 | 35 | s = remove_space_between(blocks, blocks, s) 36 | s = remove_space_between(blocks, basic_latin, s) 37 | s = remove_space_between(basic_latin, blocks, s) 38 | return s 39 | 40 | def normalize_neologd(s): 41 | s = s.strip() 42 | s = unicode_normalize('0-9A-Za-z。-゚', s) 43 | 44 | def maketrans(f, t): 45 | return {ord(x): ord(y) for x, y in zip(f, t)} 46 | 47 | s = re.sub('[˗֊‐‑‒–⁃⁻₋−]+', '-', s) # normalize hyphens 48 | s = re.sub('[﹣-ー—―─━ー]+', 
'ー', s) # normalize choonpus 49 | s = re.sub('[~∼∾〜〰~]', '', s) # remove tildes 50 | s = s.translate( 51 | maketrans('!"#$%&\'()*+,-./:;<=>?@[¥]^_`{|}~。、・「」', 52 | '!”#$%&’()*+,-./:;<=>?@[¥]^_`{|}〜。、・「」')) 53 | 54 | s = remove_extra_spaces(s) 55 | s = unicode_normalize('!”#$%&’()*+,-./:;<>?@[¥]^_`{|}〜', s) # keep =,・,「,」 56 | s = re.sub('[’]', '\'', s) 57 | s = re.sub('[”]', '"', s) 58 | return s 59 | 60 | 61 | def parse_small_jpn_number(num): 62 | 63 | if num == 0: 64 | return "ぜろ" 65 | if num == 1: 66 | return "いち" 67 | 68 | digit = ["", "いち", "に", "さん", "よん", "ご", "ろく", "なな", "はち", "きゅう"] 69 | unit = ["", "じゅう", "ひゃく", "せん"] 70 | dakuon_unit = ["", "じゅう", "びゃく", "ぜん"] 71 | sokuon_unit = ["", "じゅう", "ぴゃく", "せん"] 72 | 73 | num_str = str(num) 74 | n = len(num_str) 75 | read_list = [] 76 | for i in range(n): 77 | d = int(num_str[i]) 78 | dakuon = False 79 | sokuon = False 80 | if d != 0: 81 | # skip 1 unless it is the last char 82 | if i != n-1 and d == 1: 83 | read_list.append("") 84 | # 濁音 85 | elif d == 3: 86 | dakuon = True 87 | read_list.append(digit[d]) 88 | elif d == 6 and n-i == 3: 89 | read_list.append("ろっ") 90 | sokuon = True 91 | elif d == 8 and n-i == 3: 92 | read_list.append("はっ") 93 | sokuon = True 94 | else: 95 | read_list.append(digit[d]) 96 | 97 | if dakuon: 98 | read_list.append(dakuon_unit[n-i-1]) 99 | elif sokuon: 100 | read_list.append(sokuon_unit[n-i-1]) 101 | else: 102 | read_list.append(unit[n - i - 1]) 103 | read_str = "".join(read_list) 104 | return read_str 105 | 106 | 107 | def parse_jpn_alphabet(text): 108 | 109 | katakana_dict = { 110 | 'A': 'エー', 111 | 'B': 'ビー', 112 | 'C': 'シー', 113 | 'D': 'ディー', 114 | 'E': 'イー', 115 | 'F': 'エフ', 116 | 'G': 'ジー', 117 | 'H': 'エイチ', 118 | 'I': 'アイ', 119 | 'J': 'ジェー', 120 | 'K': 'ケー', 121 | 'L': 'エル', 122 | 'M': 'エム', 123 | 'N': 'エヌ', 124 | 'O': 'オー', 125 | 'P': 'ピー', 126 | 'Q': 'キュー', 127 | 'R': 'アール', 128 | 'S': 'エス', 129 | 'T': 'ティー', 130 | 'U': 'ユー', 131 | 'V': 'ブイ', 132 | 'W': 'ダブリュー', 133 | 'X': 'エックス', 134 | 'Y': 'ワイ', 135 | 'Z': 'ゼット', 136 | } 137 | 138 | katakana_text = '' 139 | for char in text.upper(): 140 | if char in katakana_dict: 141 | katakana_text += katakana_dict[char] 142 | return katakana_text 143 | 144 | def parse_jpn_number(num): 145 | 146 | num_size = len(str(num)) 147 | if num_size <= 4: 148 | return parse_small_jpn_number(num) 149 | elif num_size <= 8: 150 | low_digit = num%10000 151 | high_digit = num//10000 152 | 153 | high_read = parse_small_jpn_number(high_digit) + 'まん' 154 | low_read = "" 155 | if low_digit != 0: 156 | low_read = parse_small_jpn_number(low_digit) 157 | return high_read + low_read 158 | else: 159 | 160 | low_digit = num % 10000 161 | mid_digit = (num//10000)%10000 162 | high_digit = num//100000000 163 | 164 | high_read = parse_small_jpn_number(high_digit) + 'おく' 165 | mid_read = "" 166 | if mid_digit != 0: 167 | mid_read = parse_small_jpn_number(mid_digit) + 'まん' 168 | low_read = "" 169 | if low_digit != 0: 170 | low_read = parse_small_jpn_number(low_digit) 171 | 172 | return high_read + mid_read + low_read -------------------------------------------------------------------------------- /transphone/lang/jpn/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transphone.utils import import_with_auto_install 2 | from transphone.lang.jpn.kana2phoneme import Kana2Phoneme 3 | from transphone.g2p import read_g2p 4 | from phonepiece.inventory import read_inventory 5 | from transphone.lang.jpn import jaconv 6 | from 
transphone.lang.base_tokenizer import BaseTokenizer 7 | from transphone.lang.jpn.normalizer import normalize_neologd, parse_jpn_number, parse_jpn_alphabet 8 | 9 | 10 | def read_jpn_tokenizer(lang_id, g2p_model='latest', device=None, use_lexicon=True): 11 | return JPNTokenizer(lang_id, g2p_model, device) # use_lexicon is accepted for API consistency but unused here 12 | 13 | class JPNTokenizer(BaseTokenizer): 14 | 15 | def __init__(self, lang_id, g2p_model='latest', device=None): 16 | 17 | super().__init__(lang_id, g2p_model, device) 18 | 19 | # import mecab and its dict 20 | MeCab = import_with_auto_install('MeCab', 'mecab-python3') 21 | import_with_auto_install('unidic_lite', 'unidic-lite') 22 | 23 | self.tagger = MeCab.Tagger() 24 | self.kana2phoneme = Kana2Phoneme() 25 | 26 | def tokenize_words(self, text: str): 27 | text = normalize_neologd(text) 28 | 29 | raw_words = self.tagger.parse(text).split('\n') 30 | 31 | result = [] 32 | 33 | # drop the trailing 'EOS' marker and the empty string left by the final newline 34 | for word in raw_words[:-2]: 35 | raw = word.split('\t')[0] 36 | result.append(raw) 37 | 38 | return result 39 | 40 | def tokenize(self, text, use_g2p=True, use_space=False, verbose=False): 41 | 42 | text = normalize_neologd(text) 43 | 44 | raw_words = self.tagger.parse(text).split('\n') 45 | 46 | result = [] 47 | 48 | # drop the trailing 'EOS' marker and the empty string left by the final newline 49 | for word in raw_words[:-2]: 50 | 51 | kana = word.split('\t')[1] 52 | raw = word.split('\t')[0] 53 | 54 | if str.isdigit(raw): 55 | hankaku_num = jaconv.z2h(kana) 56 | kana = jaconv.hira2kata(parse_jpn_number(int(hankaku_num))) # parse_jpn_number expects an int 57 | 58 | # interestingly, isalpha also returns True for kana 59 | if str.isalpha(raw) and str.isascii(raw): 60 | kana = parse_jpn_alphabet(raw) 61 | 62 | res = self.kana2phoneme.convert(kana) 63 | 64 | if res == ['*'] and self.kana2phoneme.validate(raw): 65 | res = self.kana2phoneme.convert(raw) 66 | 67 | if verbose: 68 | print(kana, res) 69 | 70 | if res != ['*']: 71 | result.extend(res) 72 | elif use_g2p: 73 | if raw in self.cache: 74 | result.extend(self.cache[raw]) 75 | else: 76 | phonemes = self.g2p.inference(raw) 77 | remapped_phonemes = self.inventory.remap(phonemes) 78 | if verbose: 79 | print(f"g2p {raw} -> {remapped_phonemes}") 80 | self.cache[raw] = remapped_phonemes 81 | result.extend(remapped_phonemes) 82 | 83 | if use_space: 84 | result.append('') 85 | 86 | return result -------------------------------------------------------------------------------- /transphone/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinjli/transphone/7769dec083b90dab072fc7e4c592744592a35736/transphone/model/__init__.py -------------------------------------------------------------------------------- /transphone/model/checkpoint_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
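# Note on the naming assumed by find_topk_models() below: checkpoints are
# expected as model_<metric>.pt (e.g. model_0.125.pt); the <metric> field is
# parsed as a float and sorted ascending, so lower values (error rates) rank first.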
5 | 6 | from collections import OrderedDict 7 | from typing import Union 8 | import collections 9 | import logging 10 | import os 11 | import re 12 | import traceback 13 | import shutil 14 | 15 | import torch 16 | from torch.serialization import default_restore_location 17 | from pathlib import Path 18 | from transphone.config import TransphoneConfig 19 | 20 | 21 | def find_topk_models(exp_name, topk=1): 22 | 23 | if isinstance(exp_name, str): 24 | exp_dir = TransphoneConfig.data_path / 'model' / exp_name 25 | else: 26 | assert isinstance(exp_name, Path) 27 | exp_dir = exp_name 28 | 29 | model_lst = [] 30 | for model_path in exp_dir.glob('model_*.pt'): 31 | perf = model_path.stem.split('_')[1] 32 | model_lst.append((float(perf), model_path)) 33 | 34 | model_lst.sort() 35 | topk_models = [model[1] for model in model_lst[:topk]] 36 | return topk_models 37 | 38 | 39 | def torch_save(model, path): 40 | """Save torch model states. 41 | 42 | Args: 43 | path (str): Model path to be saved. 44 | model (torch.nn.Module): Torch model. 45 | 46 | """ 47 | path = str(path) 48 | if hasattr(model, 'module'): 49 | torch.save(model.module.state_dict(), path) 50 | else: 51 | torch.save(model.state_dict(), path) 52 | 53 | 54 | def torch_load(model, path): 55 | """Load torch model states. 56 | 57 | Args: 58 | path (str): Model path or snapshot file path to be loaded. 59 | model (torch.nn.Module): Torch model. 60 | 61 | """ 62 | model_state_dict = torch.load(str(path), map_location=torch.device('cpu')) 63 | 64 | new_state_dict = OrderedDict() 65 | for k, v in model_state_dict.items(): 66 | 67 | if k.startswith('module.'): 68 | name = k[7:] # remove `module.` 69 | else: 70 | name = k 71 | 72 | new_state_dict[name] = v 73 | 74 | if hasattr(model, 'module'): 75 | model.module.load_state_dict(new_state_dict) 76 | else: 77 | model.load_state_dict(new_state_dict) 78 | 79 | del model_state_dict, new_state_dict 80 | 81 | 82 | def save_checkpoint(args, trainer, epoch_itr, val_loss): 83 | from pyspeech.ml.torch import distributed_utils, meters 84 | 85 | prev_best = getattr(save_checkpoint, 'best', val_loss) 86 | if val_loss is not None: 87 | best_function = max if args.maximize_best_checkpoint_metric else min 88 | save_checkpoint.best = best_function(val_loss, prev_best) 89 | 90 | if args.no_save or not distributed_utils.is_master(args): 91 | return 92 | 93 | def is_better(a, b): 94 | return a >= b if args.maximize_best_checkpoint_metric else a <= b 95 | 96 | write_timer = meters.StopwatchMeter() 97 | write_timer.start() 98 | 99 | epoch = epoch_itr.epoch 100 | end_of_epoch = epoch_itr.end_of_epoch() 101 | updates = trainer.get_num_updates() 102 | 103 | checkpoint_conds = collections.OrderedDict() 104 | checkpoint_conds['checkpoint{}.pt'.format(epoch)] = ( 105 | end_of_epoch and not args.no_epoch_checkpoints and 106 | epoch % args.save_interval == 0 107 | ) 108 | checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = ( 109 | not end_of_epoch and args.save_interval_updates > 0 and 110 | updates % args.save_interval_updates == 0 111 | ) 112 | checkpoint_conds['checkpoint_best.pt'] = ( 113 | val_loss is not None and 114 | (not hasattr(save_checkpoint, 'best') or is_better(val_loss, save_checkpoint.best)) 115 | ) 116 | checkpoint_conds['checkpoint_last.pt'] = not args.no_last_checkpoints 117 | 118 | extra_state = { 119 | 'train_iterator': epoch_itr.state_dict(), 120 | 'val_loss': val_loss, 121 | } 122 | if hasattr(save_checkpoint, 'best'): 123 | extra_state.update({'best': save_checkpoint.best}) 124 | 125 | 
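# checkpoint_conds above maps each candidate filename to a boolean; every name
# whose condition holds is materialized below: the first one is written by the
# trainer and the remaining names are plain copies of that file.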
checkpoints = [os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond] 126 | if len(checkpoints) > 0: 127 | trainer.save_checkpoint(checkpoints[0], extra_state) 128 | for cp in checkpoints[1:]: 129 | shutil.copyfile(checkpoints[0], cp) 130 | 131 | write_timer.stop() 132 | print('| saved checkpoint {} (epoch {} @ {} updates) (writing took {} seconds)'.format( 133 | checkpoints[0], epoch, updates, write_timer.sum)) 134 | 135 | if not end_of_epoch and args.keep_interval_updates > 0: 136 | # remove old checkpoints; checkpoints are sorted in descending order 137 | checkpoints = checkpoint_paths( 138 | args.save_dir, pattern=r'checkpoint_\d+_(\d+)\.pt', 139 | ) 140 | for old_chk in checkpoints[args.keep_interval_updates:]: 141 | if os.path.lexists(old_chk): 142 | os.remove(old_chk) 143 | 144 | if args.keep_last_epochs > 0: 145 | # remove old epoch checkpoints; checkpoints are sorted in descending order 146 | checkpoints = checkpoint_paths( 147 | args.save_dir, pattern=r'checkpoint(\d+)\.pt', 148 | ) 149 | for old_chk in checkpoints[args.keep_last_epochs:]: 150 | if os.path.lexists(old_chk): 151 | os.remove(old_chk) 152 | 153 | 154 | def load_checkpoint(args, trainer, data_selector=None): 155 | """Load a checkpoint and restore the training iterator.""" 156 | # only one worker should attempt to create the required dir 157 | if args.distributed_rank == 0: 158 | os.makedirs(args.save_dir, exist_ok=True) 159 | 160 | if args.restore_file == 'checkpoint_last.pt': 161 | checkpoint_path = os.path.join(args.save_dir, 'checkpoint_last.pt') 162 | else: 163 | checkpoint_path = args.restore_file 164 | 165 | extra_state = trainer.load_checkpoint( 166 | checkpoint_path, 167 | args.reset_optimizer, 168 | args.reset_lr_scheduler, 169 | eval(args.optimizer_overrides), 170 | reset_meters=args.reset_meters, 171 | ) 172 | 173 | if ( 174 | extra_state is not None 175 | and 'best' in extra_state 176 | and not args.reset_optimizer 177 | and not args.reset_meters 178 | ): 179 | save_checkpoint.best = extra_state['best'] 180 | 181 | if extra_state is not None and not args.reset_dataloader: 182 | # restore iterator from checkpoint 183 | itr_state = extra_state['train_iterator'] 184 | epoch_itr = trainer.get_train_iterator(epoch=itr_state['epoch'], load_dataset=True, data_selector=data_selector) 185 | epoch_itr.load_state_dict(itr_state) 186 | else: 187 | epoch_itr = trainer.get_train_iterator(epoch=0, load_dataset=True, data_selector=data_selector) 188 | 189 | trainer.lr_step(epoch_itr.epoch) 190 | 191 | return extra_state, epoch_itr 192 | 193 | 194 | def load_checkpoint_to_cpu(path, arg_overrides=None): 195 | """Loads a checkpoint to CPU (with upgrading for backward compatibility).""" 196 | # if path manager not found, continue with local file. 197 | state = torch.load(path, map_location=lambda s, l: default_restore_location(s, 'cpu'),) 198 | 199 | args = state['args'] 200 | if arg_overrides is not None: 201 | for arg_name, arg_val in arg_overrides.items(): 202 | setattr(args, arg_name, arg_val) 203 | state = _upgrade_state_dict(state) 204 | return state 205 | 206 | 207 | def load_model_ensemble(filenames, arg_overrides=None, task=None): 208 | """Loads an ensemble of models. 
209 | 210 | Args: 211 | filenames (List[str]): checkpoint files to load 212 | arg_overrides (Dict[str,Any], optional): override model args that 213 | were used during model training 214 | task (fairseq.tasks.FairseqTask, optional): task to use for loading 215 | """ 216 | ensemble, args, _task = load_model_ensemble_and_task(filenames, arg_overrides, task) 217 | return ensemble, args 218 | 219 | 220 | def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None): 221 | from fairseq import tasks 222 | 223 | ensemble = [] 224 | for filename in filenames: 225 | if not os.path.exists(filename): 226 | raise IOError('Model file not found: {}'.format(filename)) 227 | state = load_checkpoint_to_cpu(filename, arg_overrides) 228 | 229 | args = state['args'] 230 | if task is None: 231 | task = tasks.setup_task(args) 232 | 233 | # build model for ensemble 234 | model = task.build_model(args) 235 | model.load_state_dict(state['model'], strict=True) 236 | ensemble.append(model) 237 | return ensemble, args, task 238 | 239 | 240 | def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): 241 | """Retrieves all checkpoints found in `path` directory. 242 | 243 | Checkpoints are identified by matching filename to the specified pattern. If 244 | the pattern contains groups, the result will be sorted by the first group in 245 | descending order. 246 | """ 247 | pt_regexp = re.compile(pattern) 248 | files = os.listdir(path) 249 | 250 | entries = [] 251 | for i, f in enumerate(files): 252 | m = pt_regexp.fullmatch(f) 253 | if m is not None: 254 | idx = int(m.group(1)) if len(m.groups()) > 0 else i 255 | entries.append((idx, m.group(0))) 256 | return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] 257 | 258 | 259 | def torch_persistent_save(*args, **kwargs): 260 | for i in range(3): 261 | try: 262 | return torch.save(*args, **kwargs) 263 | except Exception: 264 | if i == 2: 265 | logging.error(traceback.format_exc()) 266 | 267 | 268 | def convert_state_dict_type(state_dict, ttype=torch.FloatTensor): 269 | if isinstance(state_dict, dict): 270 | cpu_dict = OrderedDict() 271 | for k, v in state_dict.items(): 272 | cpu_dict[k] = convert_state_dict_type(v) 273 | return cpu_dict 274 | elif isinstance(state_dict, list): 275 | return [convert_state_dict_type(v) for v in state_dict] 276 | elif torch.is_tensor(state_dict): 277 | return state_dict.type(ttype) 278 | else: 279 | return state_dict 280 | 281 | 282 | def save_state( 283 | filename, args, model_state_dict, criterion, optimizer, lr_scheduler, 284 | num_updates, optim_history=None, extra_state=None, 285 | ): 286 | from fairseq import utils 287 | if optim_history is None: 288 | optim_history = [] 289 | if extra_state is None: 290 | extra_state = {} 291 | state_dict = { 292 | 'args': args, 293 | 'model': model_state_dict if model_state_dict else {}, 294 | 'optimizer_history': optim_history + [ 295 | { 296 | 'criterion_name': criterion.__class__.__name__, 297 | 'optimizer_name': optimizer.__class__.__name__, 298 | 'lr_scheduler_state': lr_scheduler.state_dict(), 299 | 'num_updates': num_updates, 300 | } 301 | ], 302 | 'extra_state': extra_state, 303 | } 304 | if utils.has_parameters(criterion): 305 | state_dict['criterion'] = criterion.state_dict() 306 | if not args.no_save_optimizer_state: 307 | state_dict['last_optimizer_state'] = convert_state_dict_type(optimizer.state_dict()) 308 | 309 | torch_persistent_save(state_dict, filename) 310 | 311 | 312 | def _upgrade_state_dict(state): 313 | """Helper for upgrading old model 
checkpoints.""" 314 | from fairseq import models, registry, tasks 315 | 316 | # add optimizer_history 317 | if 'optimizer_history' not in state: 318 | state['optimizer_history'] = [ 319 | { 320 | 'criterion_name': 'CrossEntropyCriterion', 321 | 'best_loss': state['best_loss'], 322 | }, 323 | ] 324 | state['last_optimizer_state'] = state['optimizer'] 325 | del state['optimizer'] 326 | del state['best_loss'] 327 | # move extra_state into sub-dictionary 328 | if 'epoch' in state and 'extra_state' not in state: 329 | state['extra_state'] = { 330 | 'epoch': state['epoch'], 331 | 'batch_offset': state['batch_offset'], 332 | 'val_loss': state['val_loss'], 333 | } 334 | del state['epoch'] 335 | del state['batch_offset'] 336 | del state['val_loss'] 337 | # reduce optimizer history's memory usage (only keep the last state) 338 | if 'optimizer' in state['optimizer_history'][-1]: 339 | state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer'] 340 | for optim_hist in state['optimizer_history']: 341 | del optim_hist['optimizer'] 342 | # record the optimizer class name 343 | if 'optimizer_name' not in state['optimizer_history'][-1]: 344 | state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG' 345 | # move best_loss into lr_scheduler_state 346 | if 'lr_scheduler_state' not in state['optimizer_history'][-1]: 347 | state['optimizer_history'][-1]['lr_scheduler_state'] = { 348 | 'best': state['optimizer_history'][-1]['best_loss'], 349 | } 350 | del state['optimizer_history'][-1]['best_loss'] 351 | # keep track of number of updates 352 | if 'num_updates' not in state['optimizer_history'][-1]: 353 | state['optimizer_history'][-1]['num_updates'] = 0 354 | # old model checkpoints may not have separate source/target positions 355 | if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'): 356 | state['args'].max_source_positions = state['args'].max_positions 357 | state['args'].max_target_positions = state['args'].max_positions 358 | # use stateful training data iterator 359 | if 'train_iterator' not in state['extra_state']: 360 | state['extra_state']['train_iterator'] = { 361 | 'epoch': state['extra_state']['epoch'], 362 | 'iterations_in_epoch': state['extra_state'].get('batch_offset', 0), 363 | } 364 | # default to translation task 365 | if not hasattr(state['args'], 'task'): 366 | state['args'].task = 'translation' 367 | 368 | # set any missing default values in the task, model or other registries 369 | registry.set_defaults(state['args'], tasks.TASK_REGISTRY[state['args'].task]) 370 | registry.set_defaults(state['args'], models.ARCH_MODEL_REGISTRY[state['args'].arch]) 371 | for registry_name, REGISTRY in registry.REGISTRIES.items(): 372 | choice = getattr(state['args'], registry_name, None) 373 | if choice is not None: 374 | cls = REGISTRY['registry'][choice] 375 | registry.set_defaults(state['args'], cls) 376 | 377 | return state 378 | 379 | 380 | def verify_checkpoint_directory(model_path: Path) -> None: 381 | 382 | if not model_path.exists(): 383 | model_path.mkdir(exist_ok=True, parents=True) 384 | 385 | temp_file_path = model_path / 'dummy' 386 | try: 387 | with open(str(temp_file_path), 'w'): 388 | pass 389 | except OSError as e: 390 | (f'| Unable to access checkpoint save directory: {model_path}') 391 | raise e 392 | else: 393 | os.remove(temp_file_path) 394 | -------------------------------------------------------------------------------- /transphone/model/dataset.py: 
-------------------------------------------------------------------------------- 1 | from phonepiece.lexicon import read_lexicon 2 | from phonepiece.lang import read_all_langs 3 | from transphone.model.vocab import Vocab 4 | from transphone.config import TransphoneConfig 5 | from tqdm import tqdm 6 | import torch 7 | 8 | 9 | def read_dataset(lang_id='eng'): 10 | 11 | lexicon = read_lexicon(lang_id) 12 | 13 | phoneme_lst = [] 14 | grapheme_lst = [] 15 | 16 | phoneme_set = set() 17 | grapheme_set = set() 18 | 19 | for grapheme_str, phonemes in lexicon.word2phoneme.items(): 20 | graphemes = list(grapheme_str) 21 | phoneme_lst.append(phonemes) 22 | grapheme_lst.append(graphemes) 23 | phoneme_set.update(phonemes) 24 | grapheme_set.update(graphemes) 25 | 26 | phoneme_vocab = Vocab(phoneme_set) 27 | grapheme_vocab = Vocab(grapheme_set) 28 | 29 | return Dataset(phoneme_lst, grapheme_lst, phoneme_vocab, grapheme_vocab) 30 | 31 | 32 | def read_multilingual_dataset(): 33 | 34 | phoneme_lst = [] 35 | grapheme_lst = [] 36 | 37 | test_phoneme_lst = [] 38 | test_grapheme_lst = [] 39 | 40 | phoneme_set = set() 41 | grapheme_set = set() 42 | 43 | print("loading dataset...") 44 | 45 | for lang_id in tqdm(read_all_langs()): 46 | 47 | try: 48 | lexicon = read_lexicon(lang_id) 49 | except Exception: 50 | print("skip ", lang_id) 51 | continue 52 | 53 | if len(lexicon.word2phoneme) < 50: 54 | continue 55 | 56 | lang_tag = '<'+lang_id+'>' 57 | 58 | grapheme_set.add(lang_tag) 59 | 60 | word2phoneme_lst = list(lexicon.word2phoneme.items()) 61 | 62 | # train set 63 | for grapheme_str, phonemes in word2phoneme_lst[:-50]: 64 | graphemes = [lang_tag] + list(grapheme_str) 65 | phoneme_lst.append(phonemes) 66 | grapheme_lst.append(graphemes) 67 | phoneme_set.update(phonemes) 68 | grapheme_set.update(graphemes) 69 | 70 | # dev set 71 | for grapheme_str, phonemes in word2phoneme_lst[-50:-25]: 72 | graphemes = [lang_tag] + list(grapheme_str) 73 | test_phoneme_lst.append(phonemes) 74 | test_grapheme_lst.append(graphemes) 75 | phoneme_set.update(phonemes) 76 | grapheme_set.update(graphemes) 77 | 78 | phoneme_vocab = Vocab(phoneme_set) 79 | grapheme_vocab = Vocab(grapheme_set) 80 | 81 | train_dataset = Dataset(phoneme_lst, grapheme_lst, phoneme_vocab, grapheme_vocab) 82 | dev_dataset = Dataset(test_phoneme_lst, test_grapheme_lst, phoneme_vocab, grapheme_vocab) 83 | 84 | return train_dataset, dev_dataset 85 | 86 | 87 | def read_test_dataset(model_name): 88 | 89 | model_path = TransphoneConfig.data_path / 'model' / model_name 90 | grapheme_vocab = Vocab.read(model_path / 'grapheme.vocab') 91 | phoneme_vocab = Vocab.read(model_path / 'phoneme.vocab') 92 | 93 | test_phoneme_lst = [] 94 | test_grapheme_lst = [] 95 | lang_lst = [] 96 | 97 | for lang_id in tqdm(read_all_langs()): 98 | 99 | try: 100 | lexicon = read_lexicon(lang_id) 101 | except Exception: 102 | print("skip ", lang_id) 103 | continue 104 | 105 | if len(lexicon.word2phoneme) <= 50: 106 | continue 107 | 108 | word2phoneme_lst = list(lexicon.word2phoneme.items()) 109 | lang_phoneme_lst = [] 110 | lang_grapheme_lst = [] 111 | 112 | # last 25 for testing 113 | for grapheme_str, phonemes in word2phoneme_lst[-25:]: 114 | 115 | graphemes = list(grapheme_str) 116 | 117 | skip = False 118 | 119 | for phoneme in phonemes: 120 | if phoneme not in phoneme_vocab: 121 | skip = True 122 | break 123 | 124 | for grapheme in graphemes: 125 | if grapheme not in grapheme_vocab: 126 | skip = True 127 | break 128 | 129 | if skip: 130 | continue 131 | 132 | lang_phoneme_lst.append(phonemes) 133 | lang_grapheme_lst.append(grapheme_str)
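# Entries containing any phoneme or grapheme outside the trained vocabularies
# were skipped above, so the effective test set depends on the model checkpoint
# being evaluated.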
134 | if len(lang_phoneme_lst) > 0: 135 | test_phoneme_lst.append(lang_phoneme_lst) 136 | test_grapheme_lst.append(lang_grapheme_lst) 137 | lang_lst.append(lang_id) 138 | 139 | return test_grapheme_lst, test_phoneme_lst, lang_lst 140 | 141 | 142 | def read_zsl_dataset(model_name): 143 | 144 | model_path = TransphoneConfig.data_path / 'model' / model_name 145 | grapheme_vocab = Vocab.read(model_path / 'grapheme.vocab') 146 | phoneme_vocab = Vocab.read(model_path / 'phoneme.vocab') 147 | 148 | test_phoneme_lst = [] 149 | test_grapheme_lst = [] 150 | lang_lst = [] 151 | 152 | for lang_id in tqdm(read_all_langs()): 153 | 154 | try: 155 | lexicon = read_lexicon(lang_id) 156 | except Exception: 157 | print("skip ", lang_id) 158 | continue 159 | 160 | if len(lexicon.word2phoneme) >= 50 or len(lexicon.word2phoneme) == 0: 161 | continue 162 | 163 | word2phoneme_lst = list(lexicon.word2phoneme.items()) 164 | 165 | lang_phoneme_lst = [] 166 | lang_grapheme_lst = [] 167 | 168 | # last 25 for testing 169 | for grapheme_str, phonemes in word2phoneme_lst[-25:]: 170 | graphemes = list(grapheme_str) 171 | 172 | skip = False 173 | 174 | for phoneme in phonemes: 175 | if phoneme not in phoneme_vocab: 176 | skip = True 177 | break 178 | 179 | for grapheme in graphemes: 180 | if grapheme not in grapheme_vocab: 181 | skip = True 182 | break 183 | 184 | if skip: 185 | continue 186 | 187 | lang_phoneme_lst.append(phonemes) 188 | lang_grapheme_lst.append(grapheme_str) 189 | 190 | test_phoneme_lst.append(lang_phoneme_lst) 191 | test_grapheme_lst.append(lang_grapheme_lst) 192 | lang_lst.append(lang_id) 193 | 194 | return test_grapheme_lst, test_phoneme_lst, lang_lst 195 | 196 | 197 | class Dataset: 198 | 199 | def __init__(self, phoneme_lst, grapheme_lst, phoneme_vocab, grapheme_vocab): 200 | 201 | self.phoneme_lst = phoneme_lst 202 | self.grapheme_lst = grapheme_lst 203 | self.phoneme_vocab = phoneme_vocab 204 | self.grapheme_vocab = grapheme_vocab 205 | 206 | def __getitem__(self, item): 207 | 208 | phones = self.phoneme_lst[item] 209 | graphemes = self.grapheme_lst[item] 210 | 211 | phone_ids = [self.phoneme_vocab.atoi(phone) for phone in phones] 212 | grapheme_ids = [self.grapheme_vocab.atoi(grapheme) for grapheme in graphemes] 213 | 214 | return (torch.LongTensor(grapheme_ids), torch.LongTensor(phone_ids)) 215 | 216 | def __len__(self): 217 | return len(self.phoneme_lst) 218 | 219 | 220 | class P2GDataset: 221 | 222 | def __init__(self, phoneme_lst, grapheme_lst, phoneme_vocab, grapheme_vocab): 223 | 224 | self.phoneme_lst = phoneme_lst 225 | self.grapheme_lst = grapheme_lst 226 | self.phoneme_vocab = phoneme_vocab 227 | self.grapheme_vocab = grapheme_vocab 228 | 229 | def __getitem__(self, item): 230 | 231 | phones = self.phoneme_lst[item] 232 | graphemes = self.grapheme_lst[item] 233 | 234 | phone_ids = [self.phoneme_vocab.atoi(phone) for phone in phones] 235 | grapheme_ids = [self.grapheme_vocab.atoi(grapheme) for grapheme in graphemes] 236 | 237 | return (torch.LongTensor(phone_ids), torch.LongTensor(grapheme_ids)) 238 | 239 | def __len__(self): 240 | return len(self.phoneme_lst) -------------------------------------------------------------------------------- /transphone/model/ensemble.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | 4 | def test(): 5 | 6 | a = Lattice("abcd") 7 | b = Lattice("bzde") 8 | c = combine(a,b) 9 | d = Lattice("bcdef") 10 | e = combine(c, d) 11 | 12 | print(e) 13 | 14 | def ensemble(pred_lst): 15 | 16 |
assert len(pred_lst) >= 1 17 | 18 | lattice_base = None 19 | 20 | for pred in pred_lst: 21 | if not isinstance(pred, list): 22 | pred = pred.split() 23 | 24 | if lattice_base is None: 25 | lattice_base = Lattice(pred) 26 | 27 | else: 28 | lattice_base = combine(lattice_base, Lattice(pred)) 29 | 30 | return lattice_base.compute() 31 | 32 | 33 | def combine(lattice_a, lattice_b, verbose=True, out=None): 34 | 35 | # output 36 | lattice = Lattice() 37 | 38 | cs_lst = [] 39 | 40 | # length of each string 41 | len_a = len(lattice_a) 42 | len_b = len(lattice_b) 43 | 44 | # dp table 45 | dp = [[0 for x in range(len_a+1)] for y in range(len_b+1)] 46 | path = [[(0, 0) for x in range(len_a+1)] for y in range(len_b+1)] 47 | 48 | # initialize first row and first column 49 | for i in range(1, len_a+1): 50 | dp[0][i] = i 51 | path[0][i] = (0, i-1) 52 | 53 | for i in range(1, len_b+1): 54 | dp[i][0] = i 55 | path[i][0] = (i-1, 0) 56 | 57 | # dp update 58 | for i in range(1, len_b+1): 59 | for j in range(1, len_a+1): 60 | index_a = j-1 61 | index_b = i-1 62 | 63 | cs_a = lattice_a[index_a] 64 | cs_b = lattice_b[index_b] 65 | 66 | sub_cost = cs_a.substitute_cost(cs_b) 67 | del_cost = cs_a.delete_cost() 68 | add_cost = cs_b.delete_cost() 69 | 70 | dp[i][j] = dp[i-1][j-1]+sub_cost 71 | path[i][j] = (i-1,j-1) 72 | 73 | if dp[i][j] > dp[i-1][j]+del_cost: 74 | dp[i][j] = dp[i-1][j]+del_cost 75 | path[i][j] = (i-1,j) 76 | 77 | if dp[i][j] > dp[i][j-1]+add_cost: 78 | dp[i][j] = dp[i][j-1]+add_cost 79 | path[i][j] = (i, j-1) 80 | 81 | cur_node = (len_b, len_a) 82 | 83 | while(cur_node != (0,0)): 84 | prev_node = path[cur_node[0]][cur_node[1]] 85 | 86 | cs_a = lattice_a[cur_node[1]-1] 87 | cs_b = lattice_b[cur_node[0]-1] 88 | 89 | # substitution or match 90 | if prev_node[0]+1 == cur_node[0] and prev_node[1]+1 == cur_node[1]: 91 | 92 | cs_a.merge(cs_b) 93 | cs_lst.append(cs_a) 94 | 95 | 96 | # addition 97 | if prev_node[0] + 1 == cur_node[0] and prev_node[1] == cur_node[1]: 98 | 99 | cs_e = cs_a.create_empty_set() 100 | cs_e.merge(cs_b) 101 | 102 | cs_lst.append(cs_e) 103 | 104 | # deletion 105 | if prev_node[0] == cur_node[0] and prev_node[1]+1 == cur_node[1]: 106 | 107 | cs_e = cs_b.create_empty_set() 108 | cs_e.merge(cs_a) 109 | cs_lst.append(cs_e) 110 | 111 | cur_node = prev_node 112 | 113 | cs_lst.reverse() 114 | lattice.cs_lst = cs_lst 115 | 116 | return lattice 117 | 118 | class CorrespondenceSet: 119 | 120 | def __init__(self, units=None, scores=None): 121 | 122 | if units is None: 123 | units = [] 124 | 125 | if scores is None: 126 | scores = [1.0]*len(units) 127 | 128 | self.units = units 129 | self.unit_set = set(units) 130 | self.scores = scores 131 | 132 | 133 | def __repr__(self): 134 | return '{'+','.join(self.units)+'}' 135 | 136 | def __str__(self): 137 | return self.__repr__() 138 | 139 | def __iter__(self): 140 | for unit in self.units: 141 | yield unit 142 | 143 | def __len__(self): 144 | return len(self.units) 145 | 146 | def __getitem__(self, idx): 147 | return self.units[idx] 148 | 149 | def compute(self): 150 | unit2score = defaultdict(float) 151 | 152 | for i in range(len(self.units)): 153 | unit = self.units[i] 154 | score = self.scores[i] 155 | unit2score[unit] += score 156 | 157 | return sorted(unit2score.items(), key=lambda x:-x[1])[0][0] 158 | 159 | def has_epsilon(self): 160 | return '' in self.unit_set 161 | 162 | def contains(self, unit): 163 | return unit in self.unit_set 164 | 165 | def delete_cost(self): 166 | if self.has_epsilon(): 167 | return 0 168 | else: 169 | return 1 170 | 
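# Cost model used by combine()'s DP above: substituting two correspondence
# sets is free whenever they share at least one unit (substitute_cost below),
# deleting a set is free when it already contains the empty unit
# (delete_cost above), and every other edit costs 1.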
171 | def substitute_cost(self, other): 172 | overlap = False 173 | for unit in self.units: 174 | if other.contains(unit): 175 | overlap = True 176 | 177 | if overlap: 178 | return 0 179 | else: 180 | return 1 181 | 182 | def create_empty_set(self): 183 | ave_score = np.mean(self.scores) 184 | cs = CorrespondenceSet([''], [ave_score]) 185 | return cs 186 | 187 | def merge(self, other): 188 | 189 | for unit, score in zip(other.units, other.scores): 190 | self.units.append(unit) 191 | self.scores.append(score) 192 | self.unit_set.add(unit) 193 | 194 | 195 | class Lattice: 196 | 197 | def __init__(self, units=None, scores=None): 198 | 199 | if units is None: 200 | units = [] 201 | else: 202 | units = list(units) 203 | 204 | if scores is None: 205 | scores = [1.0]*len(units) 206 | else: 207 | scores = list(scores) 208 | 209 | assert len(units) == len(scores) 210 | 211 | self.cs_lst = [] 212 | 213 | for unit, score in zip(units, scores): 214 | self.cs_lst.append(CorrespondenceSet([unit], [score])) 215 | 216 | def __repr__(self): 217 | return '\n'.join([str(cs) for cs in self.cs_lst]) 218 | 219 | def __str__(self): 220 | return self.__repr__() 221 | 222 | def __len__(self): 223 | return len(self.cs_lst) 224 | 225 | def __getitem__(self, idx): 226 | return self.cs_lst[idx] 227 | 228 | def __iter__(self): 229 | for cs in self.cs_lst: 230 | yield cs 231 | 232 | def compute(self): 233 | cs_items = [cs.compute() for cs in self.cs_lst] 234 | 235 | return [cs_item for cs_item in cs_items if cs_item != "" and not cs_item.isspace()] -------------------------------------------------------------------------------- /transphone/model/grapheme.py: -------------------------------------------------------------------------------- 1 | from phonepiece.unit import Unit 2 | from pathlib import Path 3 | from phonepiece.config import phonepiece_config 4 | from unidecode import unidecode 5 | import editdistance 6 | 7 | def read_grapheme(lang_id): 8 | 9 | unit_to_id = dict() 10 | 11 | unit_to_id['<blk>'] = 0 # special-token names assumed; the angle-bracketed tokens were stripped in the source 12 | 13 | idx = 0 14 | 15 | unit_path = Path(phonepiece_config.data_path / 'phonetisaurus' / lang_id / 'char.txt') 16 | 17 | for line in open(str(unit_path), 'r', encoding='utf-8'): 18 | fields = line.strip().split() 19 | 20 | assert len(fields) < 3 21 | 22 | if len(fields) == 1: 23 | unit = fields[0] 24 | idx += 1 25 | else: 26 | unit = fields[0] 27 | idx = int(fields[1]) 28 | 29 | unit_to_id[unit] = idx 30 | 31 | unit_to_id['<blk>'] = 0 32 | unit_to_id['<eos>'] = idx+1 33 | 34 | return Grapheme(unit_to_id) 35 | 36 | 37 | class Grapheme(Unit): 38 | 39 | def __init__(self, unit_to_id): 40 | super().__init__(unit_to_id) 41 | 42 | self.latins = [] 43 | self.nearest_mapping = None 44 | 45 | for i, elem in enumerate(self.elems): 46 | if i == 0: 47 | self.latins.append('') 48 | else: 49 | self.latins.append(unidecode(elem)) 50 | 51 | def get_nearest_unit(self, unit): 52 | 53 | if self.nearest_mapping is None: 54 | self.nearest_mapping = dict() 55 | 56 | if unit in self.nearest_mapping: 57 | return self.nearest_mapping[unit] 58 | 59 | if unit in self.unit_to_id: 60 | self.nearest_mapping[unit] = unit 61 | return unit 62 | 63 | target_latin = unidecode(unit) 64 | 65 | edit_score = dict() 66 | for i, latin in enumerate(self.latins): 67 | cand = self.id_to_unit[i] # do not shadow the queried unit 68 | edit_score[cand] = editdistance.eval(latin, target_latin) 69 | 70 | target_unit = min(edit_score, key=edit_score.get) 71 | self.nearest_mapping[unit] = target_unit 72 | 73 | return target_unit
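# A minimal usage sketch of the nearest-unit fallback above (assumes the
# phonepiece chart data for the illustrative language id 'eng' is installed
# locally; the exact answer depends on that char.txt file):
#
#     >>> g = read_grapheme('eng')
#     >>> g.get_nearest_unit('é')   # unseen unit: candidates are compared via
#     'e'                           # unidecode(), so 'e' is the likely match
--------------------------------------------------------------------------------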
/transphone/model/loader.py:
--------------------------------------------------------------------------------
1 | from transphone.model.utils import pad_list
2 | import torch
3 | from torch.utils.data import DataLoader
4 | 
5 | 
6 | def collate(xy_lst):
7 |     x_lst = [xy[0] for xy in xy_lst]
8 |     y_lst = [xy[1] for xy in xy_lst]
9 | 
10 |     x = pad_list(x_lst)
11 |     y = pad_list(y_lst)
12 | 
13 |     return x, y
14 | 
15 | def read_loader(dataset, batch_size=32):
16 | 
17 |     loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, collate_fn=collate)
18 |     return loader
--------------------------------------------------------------------------------
/transphone/model/lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transphone.model.utils import pad_sos_eos
4 | import torch.nn.functional as F
5 | import math
6 | import numpy as np
7 | 
8 | class Encoder(nn.Module):
9 | 
10 |     def __init__(self, vocab_size, hidden_size, layer_size):
11 |         super().__init__()
12 | 
13 |         self.embed = nn.Embedding(vocab_size, hidden_size)
14 |         self.rnn = nn.LSTM(hidden_size, hidden_size, layer_size, batch_first=True)
15 |         #self.linear = nn.Linear(2*hidden_size, hidden_size)
16 | 
17 |     def forward(self, input_tensor):
18 |         embed = self.embed(input_tensor)
19 |         output, (hidden, _) = self.rnn(embed)
20 |         #output = self.linear(output)
21 | 
22 |         return output, hidden
23 | 
24 | 
25 | class Decoder(nn.Module):
26 | 
27 |     def __init__(self, vocab_size, hidden_size):
28 |         super().__init__()
29 |         self.embed = nn.Embedding(vocab_size, hidden_size)
30 |         self.rnn = nn.LSTM(hidden_size, hidden_size, 1, batch_first=True)
31 |         self.linear = nn.Linear(hidden_size, vocab_size)
32 |         self.logsoftmax = nn.LogSoftmax(dim=2)
33 | 
34 |     def forward(self, input_tensor, prev_hidden, prev_cell):
35 |         # input: [B, 1]
36 | 
37 |         embed = self.embed(input_tensor)
38 |         output, (hidden, cell) = self.rnn(embed, (prev_hidden, prev_cell))
39 |         #print('after rnn', output)
40 |         output = self.linear(output)
41 |         #print('after linear:', output)
42 |         output = self.logsoftmax(output)
43 |         #print('after softmax', output)
44 | 
45 |         return output, hidden, cell
46 | 
47 | 
48 | class AttentionDecoder(nn.Module):
49 | 
50 |     def __init__(self, vocab_size, hidden_size):
51 |         super().__init__()
52 |         self.embed = nn.Embedding(vocab_size, hidden_size)
53 |         self.rnn = nn.LSTM(2*hidden_size, hidden_size, 1, batch_first=True)
54 |         self.linear = nn.Linear(hidden_size, vocab_size)
55 |         self.linearQ = nn.Linear(2*hidden_size, hidden_size)
56 |         self.logsoftmax = nn.LogSoftmax(dim=2)
57 |         self.hidden_size = hidden_size
58 | 
59 |     def forward(self, input_tensor, encoder_vector, encoder_mask, prev_hidden, prev_cell):
60 |         # input: [B, 1]
61 |         # encoder: [B,T,H]
62 |         # encoder mask: (B,T)
63 |         # prev_hidden: (1,B,H)
64 |         # prev_cell: (1,B,H)
65 | 
66 |         # [B,1] -> [B,1,H] -> [B,H]
67 |         embed = self.embed(input_tensor).squeeze(1)
68 | 
69 |         # [B,2H] -> [B,H,1]
70 |         Q = self.linearQ(torch.cat([embed, prev_hidden.squeeze(0)], dim=1)).unsqueeze(-1)
71 |         K = encoder_vector
72 | 
73 |         # [B,T,H] x [B,H,1] -> [B,T,1] -> [B,T]
74 |         unnormed_weights = torch.bmm(K, Q).squeeze(2) / math.sqrt(self.hidden_size)
75 | 
76 |         # masking (B,T)
77 |         masked_weights = unnormed_weights.masked_fill(~encoder_mask, -np.inf)
78 | 
79 |         # normalize over source positions: (B,T)
80 |         attn_weights = F.softmax(masked_weights, dim=1)
81 | 
82 |         # [B,H,T] x [B,T,1] -> [B,H,1] -> [B,H]
83 |         attn_applied = torch.bmm(encoder_vector.transpose(1,2), attn_weights.unsqueeze(-1)).squeeze(2)
84 | 
85 |         # (B,2H) -> (B,1,2H)
86 |         lstm_input = torch.cat([attn_applied, embed], dim=1).unsqueeze(1)
87 | 
88 |         # [B,1,2H]
89 |         output, (hidden, cell) = self.rnn(lstm_input, (prev_hidden, prev_cell))
90 | 
91 |         #print('after rnn', output)
92 |         output = self.linear(output)
93 |         #print('after linear:', output)
94 |         output = self.logsoftmax(output)
95 |         #print('after softmax', output)
96 | 
97 |         return output, hidden, cell, attn_weights
98 | 
99 | 
100 | class G2P(nn.Module):
101 | 
102 |     def __init__(self):
103 |         super().__init__()
104 | 
105 |         self.hidden_size = 256
106 |         self.vocab_size = 200
107 |         self.encoder = Encoder(self.vocab_size, self.hidden_size, 1)
108 |         self.decoder = Decoder(self.vocab_size, self.hidden_size)
109 | 
110 |         self.criterion = nn.NLLLoss(ignore_index=0)
111 | 
112 |     def train_step(self, x, y):
113 | 
114 |         self.train()
115 |         batch_size = x.shape[0]
116 | 
117 |         output, prev_hidden = self.encoder(x)
118 |         prev_cell = x.new_zeros(1, batch_size, self.hidden_size, dtype=torch.float)
119 | 
120 |         ys_in, ys_out = pad_sos_eos(y, 1, 1)
121 |         ys_in = ys_in.transpose(1,0)
122 |         ys_out = ys_out.transpose(1,0)
123 | 
124 |         loss = 0
125 | 
126 |         for i in range(len(ys_in)):
127 |             y_in = ys_in[i].view(batch_size, 1)
128 |             y_out = ys_out[i]
129 | 
130 |             output, prev_hidden, prev_cell = self.decoder(y_in, prev_hidden, prev_cell)
131 | 
132 |             output = output.squeeze()
133 |             #print('----')
134 |             #print(y_out.shape)
135 |             #print(output.shape)
136 |             #print('y_out', y_out)
137 |             #print('output', output)
138 |             loss += self.criterion(output, y_out)
139 | 
140 |         return loss
141 | 
142 |     def inference(self, x):
143 | 
144 |         self.eval()
145 | 
146 |         #x = torch.LongTensor([[52, 74, 57, 72, 63, 0, 0, 0, 0, 0, 0, 0]])
147 | 
148 |         batch_size = 1
149 | 
150 |         output, prev_hidden = self.encoder(x)
151 |         prev_cell = prev_hidden.new_zeros((1, batch_size, self.hidden_size), dtype=torch.float)
152 | 
153 |         #print(prev_hidden)
154 | 
155 |         #y_in = torch.LongTensor([[1]])
156 |         y_out = []
157 | 
158 |         w = 1
159 | 
160 |         while True:
161 |             y_in = x.new([[w]])
162 |             output, prev_hidden, prev_cell = self.decoder(y_in, prev_hidden, prev_cell)
163 |             output = output.squeeze()
164 |             #print(output)
165 |             w = output.data.topk(1)[1].item()
166 |             y_out.append(w)
167 | 
168 |             if w == 1 or len(y_out) > 16:  # stop at <eos> (index 1) or a length cap of 16
169 |                 break
170 | 
171 |         return y_out
172 | 
173 | 
174 | class AttnG2P(nn.Module):
175 | 
176 |     def __init__(self):
177 |         super().__init__()
178 | 
179 |         self.hidden_size = 512
180 |         self.vocab_size = 200
181 |         self.encoder = Encoder(self.vocab_size, self.hidden_size, 2)
182 |         self.decoder = AttentionDecoder(self.vocab_size, self.hidden_size)
183 | 
184 |         self.criterion = nn.NLLLoss(ignore_index=0)
185 | 
186 |     def train_step(self, x, y):
187 | 
188 |         self.train()
189 |         batch_size = x.shape[0]
190 |         encoder_mask = (x != 0)
191 | 
192 |         encoder_output, _ = self.encoder(x)
193 |         prev_cell = x.new_zeros(1, batch_size, self.hidden_size, dtype=torch.float)
194 |         prev_hidden = x.new_zeros(1, batch_size, self.hidden_size, dtype=torch.float)
195 | 
196 |         ys_in, ys_out = pad_sos_eos(y, 1, 1)
197 |         ys_in = ys_in.transpose(1,0)
198 |         ys_out = ys_out.transpose(1,0)
199 | 
200 |         loss = 0
201 | 
202 |         for i in range(len(ys_in)):
203 |             y_in = ys_in[i].view(batch_size, 1)
204 |             y_out = ys_out[i]
205 | 
206 |             output, prev_hidden, prev_cell, _ = self.decoder(y_in, encoder_output, encoder_mask, prev_hidden, prev_cell)
207 | 
208 |             output = output.squeeze()
209 |             #print('----')
210 |             #print(y_out.shape)
211 |             #print(output.shape)
212 |             #print('y_out', y_out)
213 |             #print('output', output)
214 |             loss += self.criterion(output, y_out)
215 | 
216 |         return loss
217 | 
218 |     def inference_with_attention(self, x):
219 | 
220 |         self.eval()
221 |         encoder_mask = (x != 0)
222 | 
223 |         #x = torch.LongTensor([[52, 74, 57, 72, 63, 0, 0, 0, 0, 0, 0, 0]])
224 | 
225 |         batch_size = 1
226 | 
227 |         encoder_output, _ = self.encoder(x)
228 |         prev_cell = x.new_zeros((1, batch_size, self.hidden_size), dtype=torch.float)
229 |         prev_hidden = x.new_zeros(1, batch_size, self.hidden_size, dtype=torch.float)
230 | 
231 |         #print(prev_hidden)
232 | 
233 |         #y_in = torch.LongTensor([[1]])
234 |         y_out = []
235 | 
236 |         w = 1
237 | 
238 |         weights = []
239 | 
240 |         while True:
241 |             y_in = x.new([[w]])
242 |             output, prev_hidden, prev_cell, attn_weights = self.decoder(y_in, encoder_output, encoder_mask, prev_hidden, prev_cell)
243 |             output = output.squeeze()
244 |             w = output.data.topk(1)[1].item()
245 | 
246 |             weights.append(attn_weights.unsqueeze(-1))
247 |             if w == 1 or len(y_out) > 16:
248 |                 break
249 | 
250 |             y_out.append(w)
251 | 
252 |         weights = torch.cat(weights, dim=2)
253 | 
254 |         return y_out, weights
255 | 
256 | 
257 | 
258 |     def inference(self, x):
259 | 
260 |         self.eval()
261 |         encoder_mask = (x != 0)
262 | 
263 |         batch_size = 1
264 | 
265 |         encoder_output, _ = self.encoder(x)
266 |         # the encoder is a 2-layer LSTM, so its final hidden state does not match
267 |         # the 1-layer decoder; start from zero states as in train_step above
268 |         prev_cell = x.new_zeros((1, batch_size, self.hidden_size), dtype=torch.float)
269 |         prev_hidden = x.new_zeros((1, batch_size, self.hidden_size), dtype=torch.float)
270 | 
271 |         y_out = []
272 | 
273 |         w = 1
274 | 
275 | 
276 |         while True:
277 |             y_in = x.new([[w]])
278 |             output, prev_hidden, prev_cell, _ = self.decoder(y_in, encoder_output, encoder_mask, prev_hidden, prev_cell)
279 |             output = output.squeeze()
280 | 
281 |             w = output.data.topk(1)[1].item()
282 |             y_out.append(w)
283 | 
284 |             if w == 1 or len(y_out) > 16:
285 |                 break
286 | 
287 |         return y_out
--------------------------------------------------------------------------------
/transphone/model/transformer.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import Transformer
5 | import math
6 | from transphone.model.utils import pad_sos_eos
7 | from transphone.config import TransphoneConfig
8 | 
9 | UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 0, 1, 1
10 | 
11 | def generate_square_subsequent_mask(sz):
12 |     mask = (torch.triu(torch.ones((sz, sz), device=TransphoneConfig.device)) == 1).transpose(0, 1)
13 |     mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
14 |     return mask
15 | 
16 | 
17 | def create_mask(src, tgt):
18 |     src_seq_len = src.shape[0]
19 |     tgt_seq_len = tgt.shape[0]
20 | 
21 |     tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
22 |     src_mask = torch.zeros((src_seq_len, src_seq_len), device=TransphoneConfig.device).type(torch.bool)
23 | 
24 |     src_padding_mask = (src == PAD_IDX).transpose(0, 1)
25 |     tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
26 |     return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
27 | 
28 | 
29 | # helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
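# the PositionalEncoding that follows implements the standard sinusoidal scheme:
#   PE(pos, 2i)   = sin(pos / 10000^(2i / emb_size))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / emb_size))
# den below holds the 10000^(-2i / emb_size) factors, computed in log space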
30 | class PositionalEncoding(nn.Module):
31 |     def __init__(self,
32 |                  emb_size: int,
33 |                  dropout: float,
34 |                  maxlen: int = 5000):
35 |         super(PositionalEncoding, self).__init__()
36 |         den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
37 |         pos = torch.arange(0, maxlen).reshape(maxlen, 1)
38 |         pos_embedding = torch.zeros((maxlen, emb_size))
39 |         pos_embedding[:, 0::2] = torch.sin(pos * den)
40 |         pos_embedding[:, 1::2] = torch.cos(pos * den)
41 |         pos_embedding = pos_embedding.unsqueeze(-2)
42 | 
43 |         self.dropout = nn.Dropout(dropout)
44 |         self.register_buffer('pos_embedding', pos_embedding)
45 | 
46 |     def forward(self, token_embedding: Tensor):
47 |         return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])
48 | 
49 | # helper Module to convert tensor of input indices into corresponding tensor of token embeddings
50 | class TokenEmbedding(nn.Module):
51 |     def __init__(self, vocab_size: int, emb_size):
52 |         super(TokenEmbedding, self).__init__()
53 |         self.embedding = nn.Embedding(vocab_size, emb_size)
54 |         self.emb_size = emb_size
55 | 
56 |     def forward(self, tokens: Tensor):
57 |         return self.embedding(tokens.long()) * math.sqrt(self.emb_size)
58 | 
59 | 
60 | # Seq2Seq Network
61 | class TransformerG2P(nn.Module):
62 |     def __init__(self,
63 |                  num_encoder_layers: int,
64 |                  num_decoder_layers: int,
65 |                  emb_size: int,
66 |                  nhead: int,
67 |                  src_vocab_size: int,
68 |                  tgt_vocab_size: int,
69 |                  dim_feedforward: int = 512,
70 |                  dropout: float = 0.1):
71 |         super().__init__()
72 |         self.transformer = Transformer(d_model=emb_size,
73 |                                        nhead=nhead,
74 |                                        num_encoder_layers=num_encoder_layers,
75 |                                        num_decoder_layers=num_decoder_layers,
76 |                                        dim_feedforward=dim_feedforward,
77 |                                        dropout=dropout)
78 |         self.generator = nn.Linear(emb_size, tgt_vocab_size)
79 |         self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
80 |         self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
81 |         self.positional_encoding = PositionalEncoding(
82 |             emb_size, dropout=dropout)
83 | 
84 |         self.criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
85 | 
86 | 
87 |     def forward(self,
88 |                 src: Tensor,
89 |                 trg: Tensor,
90 |                 src_mask: Tensor,
91 |                 tgt_mask: Tensor,
92 |                 src_padding_mask: Tensor,
93 |                 tgt_padding_mask: Tensor,
94 |                 memory_key_padding_mask: Tensor):
95 |         src_emb = self.positional_encoding(self.src_tok_emb(src))
96 |         tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
97 |         outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
98 |                                 src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
99 |         return self.generator(outs)
100 | 
101 |     def encode(self, src: Tensor, src_mask: Tensor):
102 |         return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)
103 | 
104 |     def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
105 |         return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
106 | 
107 | 
108 |     def train_step(self, x, y):
109 | 
110 |         self.train()
111 | 
112 |         ys_in, ys_out = pad_sos_eos(y, 1, 1)
113 | 
114 |         # T,B
115 |         tgt_input = ys_in.transpose(1, 0)
116 |         tgt_out = ys_out.transpose(1, 0)
117 | 
118 |         src = x.transpose(1, 0)
119 | 
120 |         src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
121 | 
122 |         logits = self.forward(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
123 | 
124 |         loss = self.criterion(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
125 | 
126 |         return loss
127 | 
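    # decoding helpers: inference() greedily decodes a single example, encoding the
    # source once and re-running the decoder on the growing target prefix until it
    # emits EOS_IDX or reaches num_tokens + 5 steps; inference_batch() does the same
    # for a whole batch, pinning finished rows to EOS_IDX until every row is done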
128 | 
129 |     def inference(self, x):
130 | 
131 |         self.eval()
132 |         src = x.view(-1, 1)
133 |         num_tokens = src.shape[0]
134 |         src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool).to(TransphoneConfig.device)
135 |         max_len = num_tokens + 5
136 | 
137 |         memory = self.encode(src, src_mask)
138 |         ys = x.new_ones(1, 1).fill_(BOS_IDX).type(torch.long)
139 | 
140 |         for i in range(max_len-1):
141 |             memory = memory.to(TransphoneConfig.device)
142 |             tgt_mask = (generate_square_subsequent_mask(ys.size(0))
143 |                         .type(torch.bool)).to(TransphoneConfig.device)
144 |             out = self.decode(ys, memory, tgt_mask)
145 |             out = out.transpose(0, 1)
146 |             prob = self.generator(out[:, -1])
147 |             _, next_word = torch.max(prob, dim=1)
148 |             next_word = next_word.item()
149 | 
150 |             ys = torch.cat([ys,
151 |                             torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
152 |             if next_word == EOS_IDX:
153 |                 break
154 | 
155 |         out = ys.squeeze(1).tolist()[1:]
156 | 
157 |         if out[-1] == 1:  # strip the trailing <eos>
158 |             out = out[:-1]
159 | 
160 |         return out
161 | 
162 |     def inference_batch(self, x):
163 | 
164 |         self.eval()
165 | 
166 |         # (T,B)
167 |         src = x.transpose(0, 1)
168 | 
169 |         num_tokens = src.shape[0]
170 |         batch_size = src.shape[1]
171 | 
172 |         src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool).to(TransphoneConfig.device)
173 |         max_len = num_tokens + 5
174 | 
175 |         memory = self.encode(src, src_mask)
176 |         ys = x.new_ones(1, batch_size).fill_(BOS_IDX).type(torch.long)
177 | 
178 |         is_done = [False] * batch_size
179 | 
180 |         for i in range(max_len-1):
181 |             memory = memory.to(TransphoneConfig.device)
182 |             tgt_mask = (generate_square_subsequent_mask(ys.size(0))
183 |                         .type(torch.bool)).to(TransphoneConfig.device)
184 |             out = self.decode(ys, memory, tgt_mask)
185 |             out = out.transpose(0, 1)
186 |             prob = self.generator(out[:, -1])
187 | 
188 |             _, next_words = torch.max(prob, dim=1)
189 | 
190 |             for j in range(batch_size):
191 |                 if next_words[j].item() == EOS_IDX:
192 |                     is_done[j] = True
193 | 
194 |                 if is_done[j]:
195 |                     next_words[j] = EOS_IDX
196 | 
197 |             ys = torch.cat([ys, next_words.unsqueeze(0)], dim=0)
198 | 
199 |             if all(is_done):
200 |                 break
201 | 
202 |         outs = []
203 |         for y in ys.transpose(0, 1).tolist():
204 |             outs.append([i for i in y if i >= 2])  # drop pad/sos/eos indices (0 and 1)
205 | 
206 |         return outs
--------------------------------------------------------------------------------
/transphone/model/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from transphone.config import TransphoneConfig
3 | import torch
4 | import yaml
5 | 
6 | 
7 | class dotdict(dict):
8 |     """dot.notation access to dictionary attributes"""
9 |     __getattr__ = dict.get
10 |     __setattr__ = dict.__setitem__
11 |     __delattr__ = dict.__delitem__
12 | 
13 | def read_model_config(exp):
14 |     yaml_file = TransphoneConfig.data_path / 'exp' / f'{exp}.yml'
15 |     # Open the YAML file and load its contents into a Python dictionary
16 |     with open(yaml_file, "r") as f:
17 |         model_config = yaml.safe_load(f)
18 | 
19 |     return dotdict(model_config)
20 | 
21 | 
22 | def resolve_model_name(model_name='latest', alt_model_path=None):
23 |     """
24 |     resolve a model alias to a concrete model id
25 | 
26 |     :param model_name: a model alias such as 'latest' or a concrete id such as '042801_base'
27 |     :return: the resolved model id
28 |     """
29 | 
30 |     models = {
31 |         'latest': '042801_base',
32 |         '042801_base': '042801_base'
33 |     }
34 | 
35 |     assert model_name in models, f"{model_name} is not available"
36 | 
37 |     return models[model_name]
38 | 
39 | 
40 | def pad_list(tensor_lst):
41 |     max_length = max(t.size(0) for t in tensor_lst)
42 |     batch_size = len(tensor_lst)
43 | 
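    # allocate a zero-filled (batch, max_length, ...) tensor and copy each sequence
    # into its row, so shorter sequences end up padded with index 0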
44 |     padded_tensor = tensor_lst[0].new(batch_size, max_length, *tensor_lst[0].size()[1:]).fill_(0)
45 |     for i, t in enumerate(tensor_lst):
46 |         padded_tensor[i, :t.size(0)] = t
47 | 
48 |     return padded_tensor
49 | 
50 | 
51 | def pad_sos_eos(ys, sos, eos):
52 | 
53 |     batch_size = len(ys)
54 | 
55 |     sos_tensor = ys.new_zeros((batch_size, 1)).fill_(sos)
56 |     y_in = torch.cat([sos_tensor, ys], dim=1)
57 | 
58 |     y_out = []
59 | 
60 |     zero_tensor = ys.new_zeros((batch_size, 1))
61 |     extended_ys = torch.cat([ys, zero_tensor], dim=1)
62 | 
63 |     for y in extended_ys:
64 |         # the first zero (padding) position marks the end of the sequence;
65 |         # overwrite it with the eos label
66 |         eos_idx = torch.nonzero(y == 0)[0].item()
67 |         y[eos_idx] = eos
68 |         y_out.append(y)
69 | 
70 |     y_out = pad_list(y_out)
71 |     return y_in, y_out
--------------------------------------------------------------------------------
/transphone/model/vocab.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | 
4 | class Vocab:
5 | 
6 |     def __init__(self, word_set):
7 |         self.words = ['<blk>', '<eos>'] + list(sorted(word_set))  # index 0 doubles as pad/unk, index 1 as sos/eos
8 |         self.map = dict()
9 | 
10 |         for i, word in enumerate(self.words):
11 |             self.map[word] = i
12 | 
13 | 
14 |     @classmethod
15 |     def read(cls, file_path):
16 | 
17 |         vocab = cls([])
18 | 
19 |         vocab.words = []
20 |         vocab.map = dict()
21 | 
22 |         for i, line in enumerate(open(Path(file_path), encoding='utf-8')):
23 |             word = line.strip()
24 |             vocab.words.append(word)
25 |             vocab.map[word] = i
26 | 
27 |         return vocab
28 | 
29 |     def __len__(self):
30 |         return len(self.words)
31 | 
32 |     def __contains__(self, item):
33 |         return item in self.map
34 | 
35 |     def atoi(self, word):
36 | 
37 |         return self.map[word]
38 | 
39 |     def itoa(self, idx):
40 |         word = self.words[idx]
41 |         if word == '':
42 |             word = ' '
43 | 
44 |         return word
45 | 
46 |     def write(self, file_path):
47 | 
48 |         w = open(file_path, 'w', encoding='utf-8')
49 | 
50 |         for word in self.words:
51 |             w.write(word+'\n')
52 | 
53 |         w.close()
--------------------------------------------------------------------------------
/transphone/run.py:
--------------------------------------------------------------------------------
1 | from transphone.bin.download_model import download_model
2 | from transphone.model.utils import resolve_model_name
3 | from transphone.tokenizer import read_tokenizer
4 | from pathlib import Path
5 | import argparse
6 | import tqdm
7 | 
8 | 
9 | if __name__ == '__main__':
10 | 
11 |     parser = argparse.ArgumentParser('running transphone g2p model')
12 |     parser.add_argument('-m', '--model', type=str, default='latest',
13 |                         help='specify which model to use. default is to use the latest local model')
14 |     parser.add_argument('-l', '--lang', type=str, default='eng',
15 |                         help='specify the ISO 639-3 id of the input language. default is eng')
16 |     parser.add_argument('-i', '--input', type=str, required=True, help='specify your input text file')
17 |     parser.add_argument('-o', '--output', type=str, default='stdout',
18 |                         help='specify output file. the default will be stdout')
19 |     parser.add_argument('-d', '--device', help='specify device to use, if not specified, it will try using gpu when applicable')
20 |     parser.add_argument('-f', '--format', type=str, default='text', help='kaldi or text')
21 |     parser.add_argument('-c', '--combine', action='store_true',
22 |                         help='write outputs by including both grapheme inputs and phonemes in the same line, delimited by space')
23 | 
24 |     args = parser.parse_args()
25 | 
26 |     # resolve model's name
27 |     model_name = resolve_model_name(args.model)
28 | 
29 |     # format
30 |     file_format = args.format
31 | 
32 |     if args.combine:
33 |         assert file_format == 'text'
34 | 
35 |     # download specified model automatically if no model exists
36 |     download_model(model_name)
37 | 
38 |     device = None
39 |     if args.device is not None and isinstance(args.device, str) and args.device.isdigit():
40 |         device = int(args.device)
41 | 
42 |     # create model
43 |     tokenizer = read_tokenizer(args.lang, g2p_model=model_name, device=device)
44 | 
45 |     # output file descriptor
46 |     output_fd = None
47 |     if args.output != 'stdout':
48 |         output_fd = open(args.output, 'w', encoding='utf-8')
49 | 
50 |     # input file/path
51 |     input_path = Path(args.input)
52 | 
53 |     for line in tqdm.tqdm(open(input_path, 'r', encoding='utf-8').readlines(), disable=args.output == 'stdout'):
54 |         fields = line.strip().split()
55 | 
56 |         utt_id = None
57 | 
58 |         if file_format == 'text':
59 |             text = ' '.join(fields)
60 |         else:
61 |             text = ' '.join(fields[1:])
62 |             utt_id = fields[0]
63 | 
64 |         phonemes = tokenizer.tokenize(text)
65 |         line_output = ' '.join(phonemes)
66 | 
67 |         if args.combine:
68 |             line_output = text + '\t' + line_output
69 | 
70 |         if utt_id is not None:
71 |             line_output = utt_id + ' ' + line_output
72 | 
73 |         if output_fd:
74 |             output_fd.write(line_output + '\n')
75 |         else:
76 |             print(line_output)
77 | 
78 |     if output_fd:
79 |         output_fd.close()
--------------------------------------------------------------------------------
/transphone/tokenizer.py:
--------------------------------------------------------------------------------
1 | from phonepiece.lang import normalize_lang_id
2 | from phonepiece.lexicon import read_lexicon
3 | from transphone.lang.base_tokenizer import BaseTokenizer
4 | from transphone.lang.eng.tokenizer import read_eng_tokenizer
5 | from transphone.lang.cmn.tokenizer import read_cmn_tokenizer
6 | from transphone.lang.jpn.tokenizer import read_jpn_tokenizer
7 | from transphone.lang.g2p_tokenizer import read_g2p_tokenizer
8 | from transphone.model.utils import resolve_model_name
9 | from transphone.lang.epitran_tokenizer import read_epitran_tokenizer
10 | 
11 | lang2tokenizer = {
12 |     'eng': read_eng_tokenizer,
13 |     'cmn': read_cmn_tokenizer,
14 |     'jpn': read_jpn_tokenizer,
15 |     'spa': read_epitran_tokenizer,
16 |     'deu': read_epitran_tokenizer,
17 |     'ita': read_epitran_tokenizer,
18 |     'rus': read_epitran_tokenizer,
19 |     'fra': read_epitran_tokenizer,
20 |     'vie': read_epitran_tokenizer,
21 |     'tha': read_epitran_tokenizer,
22 |     'swa': read_epitran_tokenizer,
23 |     'ckb': read_epitran_tokenizer,
24 |     'cat': read_epitran_tokenizer,
25 | }
26 | 
27 | def read_tokenizer(lang_id, g2p_model='latest', device=None, use_lexicon=True):
28 | 
29 |     lang_id = normalize_lang_id(lang_id)
30 | 
31 |     if lang_id in lang2tokenizer:
32 |         return lang2tokenizer[lang_id](lang_id=lang_id, g2p_model=g2p_model, device=device, use_lexicon=use_lexicon)
33 |     else:
34 |         return read_g2p_tokenizer(lang_id=lang_id, g2p_model=g2p_model, device=device)
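# usage sketch (illustrative only, not part of this module): eng/cmn/jpn resolve to
# their dedicated tokenizers, the ids listed in lang2tokenizer resolve to Epitran,
# and any other ISO 639-3 id falls back to the neural G2P tokenizer; the exact
# phoneme output depends on the installed model.
#
#     from transphone.tokenizer import read_tokenizer
#
#     eng = read_tokenizer('eng')
#     print(eng.tokenize('hello world'))   # e.g. ['h', 'ʌ', 'l', 'oʊ', 'w', 'ɹ̩', 'l', 'd']
#
#     nld = read_tokenizer('nld')          # no dedicated tokenizer: zero-shot G2P
#     print(nld.tokenize('goedemorgen'))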
--------------------------------------------------------------------------------
/transphone/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import subprocess
3 | import sys
4 | 
5 | 
6 | def import_with_auto_install(package, package_name):
7 |     try:
8 |         return importlib.import_module(package)
9 |     except ImportError:
10 |         # pip.main was removed in pip 10+; invoke pip through the current interpreter instead
11 |         subprocess.check_call([sys.executable, '-m', 'pip', 'install', package_name])
12 |         return importlib.import_module(package)
13 | 
14 | 
15 | class Singleton(type):
16 |     _instances = {}
17 | 
18 |     def __call__(cls, *args, **kwargs):
19 |         if cls not in cls._instances:
20 |             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
21 |         return cls._instances[cls]
--------------------------------------------------------------------------------
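Two small usage sketches for the helpers in transphone/utils.py (hypothetical class and package names, shown only to illustrate the API):

from transphone.utils import import_with_auto_install, Singleton

# import editdistance, installing it with pip first if it is missing
editdistance = import_with_auto_install('editdistance', 'editdistance')

class ModelCache(metaclass=Singleton):
    def __init__(self):
        self.cache = {}

# the metaclass makes every instantiation return the same object
assert ModelCache() is ModelCache()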