├── model └── .place ├── sample.wav ├── docs ├── cover.png ├── sdk.py └── sdk.js ├── component ├── langdetect_fasttext │ ├── __init__.py │ ├── LICENSE │ ├── detect.py │ └── README.md ├── nlp_utils │ ├── __init__.py │ ├── cut.py │ └── detect.py ├── pike.py └── warp.py ├── onnx_infer ├── __init__.py ├── infer │ ├── __init__.py │ ├── commons.py │ ├── transforms.py │ └── modules.py └── utils │ ├── __init__.py │ ├── onnx_utils.py │ └── onnx_transforms.py ├── chinese_dialect_lexicons ├── cixi.ocd2 ├── fuyang.ocd2 ├── pinghu.ocd2 ├── ruao.ocd2 ├── sanmen.ocd2 ├── wuxi.ocd2 ├── xiashi.ocd2 ├── youbu.ocd2 ├── zhenru.ocd2 ├── changzhou.ocd2 ├── hangzhou.ocd2 ├── jiading.ocd2 ├── jiashan.ocd2 ├── jingjiang.ocd2 ├── jyutjyu.ocd2 ├── linping.ocd2 ├── ningbo.ocd2 ├── shaoxing.ocd2 ├── suichang.ocd2 ├── suzhou.ocd2 ├── tiantai.ocd2 ├── tongxiang.ocd2 ├── wenzhou.ocd2 ├── xiaoshan.ocd2 ├── yixing.ocd2 ├── zaonhe.ocd2 ├── wuxi.json ├── ningbo.json ├── suzhou.json ├── jyutjyu.json ├── xiashi.json ├── yixing.json ├── zaonhe.json ├── hangzhou.json ├── cixi.json ├── ruao.json ├── youbu.json ├── fuyang.json ├── pinghu.json ├── sanmen.json ├── zhenru.json ├── jiading.json ├── jiashan.json ├── linping.json ├── shaoxing.json ├── suichang.json ├── tiantai.json ├── wenzhou.json ├── xiaoshan.json ├── changzhou.json ├── jingjiang.json └── tongxiang.json ├── .env ├── CLAUSE ├── pm2.json ├── requirements.txt ├── CITATION.cff ├── Dockerfile ├── text ├── LICENSE ├── __init__.py ├── ngu_dialect.py ├── thai.py ├── sanskrit.py ├── cantonese.py ├── shanghainese.py ├── symbols.py ├── japanese.py ├── english.py ├── korean.py ├── cleaners.py └── mandarin.py ├── LICENSE_ORIGIN ├── deploy_script.sh ├── .github └── workflows │ └── docker-latest.yaml ├── LICENSE ├── .gitattributes ├── main.py ├── utils.py ├── mel_processing.py ├── README.md ├── pth2onnx.py ├── hubert_model.py ├── .gitignore ├── server.py └── test └── onnx_test.ipynb /model/.place: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LlmKira/VitsServer/HEAD/sample.wav -------------------------------------------------------------------------------- /docs/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LlmKira/VitsServer/HEAD/docs/cover.png -------------------------------------------------------------------------------- /component/langdetect_fasttext/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .detect import detect 3 | 4 | __all__ = ["detect"] 5 | -------------------------------------------------------------------------------- /onnx_infer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/7 下午11:29 3 | # @Author : sudoskys 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /onnx_infer/infer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/7 下午11:06 3 | # @Author : sudoskys 4 | # @File : __init__.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /onnx_infer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/7 下午11:28 3 | # @Author : sudoskys 4 | # @File : __init__.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/cixi.ocd2: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8113aca87c4728c66cfa6c7b5adfbb596a2930df9b7c6187c6a227ff2de87f00 3 | size 98015 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/fuyang.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:98e1fbec75e090550cf131de226a1d867c7896b51170f8d7d21f9101297f4c08 3 | size 83664 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/pinghu.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:01b0e0dad8cddb0e2cb23899d4a2f97f2c0b369d5ff369076c5cdb7bd4528e4f 3 | size 69420 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/ruao.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:259a42ad761233f7d6ca6eec39268e27a65b2ded025f2b7725501cf5e3e02d8a 3 | size 58841 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/sanmen.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:afa70a920b6805e279ed15246026b70dbeb2a8329ad585fbae8cfdf45e7489a9 3 | size 80210 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/wuxi.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:64f27ffaa75e542e4464e53c4acf94607be1526a90922ac8b28870104aaebdff 3 | size 358666 4 | 
-------------------------------------------------------------------------------- /chinese_dialect_lexicons/xiashi.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2bc638633b82e196776a3adfc621c854d0da923b7cff6e7d0c9576723cdc03cd 3 | size 70314 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/youbu.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5fb1aef6b12d9474249717ce7c5b5303aeeea4d8d26943d62d269568b2985c17 3 | size 84985 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/zhenru.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d71dd437715e1055534f929e1591b11086a265d87480694e723eb4a6a95874e8 3 | size 56967 4 | -------------------------------------------------------------------------------- /component/nlp_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/5 上午10:17 3 | # @Author : sudoskys 4 | # @File : __init__.py.py 5 | # @Software: PyCharm 6 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/changzhou.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:db4ec02be9812e804291a88f9a984f544e221ed472f682bba8da5ecbefbabd8c 3 | size 96119 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/hangzhou.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:c7a9eb5fbd3b8c91745dbb2734f2700b75a47c3821e381566afc567d7da4d9d5 3 | size 427268 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jiading.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4f3ac33214e65e7223e8c561bc12ec90a2d87db3cf8d20e87a30bbd8eb788187 3 | size 111144 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jiashan.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6310729b85976b6e6407b4f66ad13a3ad7a51a42f3c05c98e294bcbb3159456c 3 | size 71716 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jingjiang.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:154d9cac032a3284a6aa175689a5805f068f6896429009a7d94d41616694131f 3 | size 86093 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jyutjyu.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:aea11bfe51b184b3f000d20ab49757979b216219203839d2b2e3c1f990a13fa5 3 | size 2432991 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/linping.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7fcd3b53e5aa6cd64419835c14769d53cc230e229c0fbd20efb65c46e07b712b 3 | size 65351 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/ningbo.ocd2: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5876b000f611ea52bf18cda5bcbdd0cfcc55e1c09774d9a24e3b5c7d90002435 3 | size 386414 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/shaoxing.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a347aa25bf435803727b4194cf34de4de3e61f03427ee21043a711cdb0b9d940 3 | size 113108 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/suichang.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a8062749ff70db65d469d91bd92375607f8648a138b896e58cf7c28edb8f970e 3 | size 81004 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/suzhou.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a8815595a248135874329e7f34662dd243a266be3e8375e8409f95da95d6d540 3 | size 506184 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/tiantai.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:339e0ba454444dbf8fbe75de6f49769d11dfe2f2f5ba7dea74ba20fba5d6d343 3 | size 120951 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/tongxiang.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7432d85588eb8ba34e7baea9f26af8d332572037ff7d41a6730f96c02e5fd063 3 | size 137499 4 | 
-------------------------------------------------------------------------------- /chinese_dialect_lexicons/wenzhou.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ed05c0c615a38f55a139a73bcc3960897d8cd567c9482a0a06b272eb0b46aa05 3 | size 83121 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/xiaoshan.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:651bd314c5e57312eeee537037f6c6e56a12ef446216264aad70bf68bf6a283d 3 | size 77119 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/yixing.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6c56a73eb531f49f64562bdb714753d37dc015baac943b3264bccba9b2aacf9b 3 | size 155050 4 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/zaonhe.ocd2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a71b5a97eb49699f440137391565d208ea82156f0765986b7f3e16909e15672e 3 | size 4095228 4 | -------------------------------------------------------------------------------- /.env: -------------------------------------------------------------------------------- 1 | VITS_SERVER_HOST=0.0.0.0 2 | VITS_SERVER_PORT=9557 3 | VITS_SERVER_RELOAD=false 4 | 5 | 6 | 7 | # VITS_SERVER_INIT_MODEL=https://huggingface.co/Wce.pth 8 | # VITS_SERVER_INIT_CONFIG=https://huggingface.co/Win.json -------------------------------------------------------------------------------- /CLAUSE: -------------------------------------------------------------------------------- 1 | 免责声明: 2 | 3 | 
本项目的使用者应当遵守所有适用的法律和规定,并且对其使用本项目所产生的后果负有全部责任。 4 | 5 | 本项目的作者和贡献者不对使用本项目可能导致的任何损失或损害承担任何责任。使用者应自行承担使用本项目所造成的风险和后果,包括但不限于数据丢失、计算错误等。 6 | 7 | 任何恶意伪造、侵犯肖像权或其他侵权行为均与本项目及其作者和贡献者无关。如果使用者违反了适用法律或规定,则将自行承担法律责任。 8 | 9 | 此免责声明适用于所有使用本项目的个人和组织,无论是否已经取得授权或许可。在使用本项目之前,请您务必确认已经仔细阅读了此免责声明并理解了其中的内容。 -------------------------------------------------------------------------------- /pm2.json: -------------------------------------------------------------------------------- 1 | { 2 | "apps": [ 3 | { 4 | "name": "vits server", 5 | "script": "pipenv shell;python3 main.py", 6 | "instances": 1, 7 | "error_file": "main_error.log", 8 | "out_file": "main_out.log", 9 | "log_date_format": "YYYY-MM-DD HH-mm-ss", 10 | "max_memory_restart": "10240M" 11 | } 12 | ] 13 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/wuxi.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Wuxinese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "wuxi.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "wuxi.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/ningbo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Ningbonese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "ningbo.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "ningbo.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/suzhou.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": 
"Suzhounese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "suzhou.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "suzhou.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jyutjyu.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Cantonese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "jyutjyu.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "jyutjyu.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/xiashi.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Xiashi dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "xiashi.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "xiashi.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/yixing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Yixing dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "yixing.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "yixing.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/zaonhe.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "Shanghainese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "zaonhe.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "zaonhe.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/hangzhou.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Hangzhounese to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "hangzhou.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [{ 11 | "dict": { 12 | "type": "group", 13 | "dicts": [{ 14 | "type": "ocd2", 15 | "file": "hangzhou.ocd2" 16 | }] 17 | } 18 | }] 19 | } 20 | -------------------------------------------------------------------------------- /chinese_dialect_lexicons/cixi.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Cixi dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "cixi.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "cixi.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/ruao.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Ruao dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "ruao.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "ruao.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 
| } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/youbu.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Youbu dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "youbu.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "youbu.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/fuyang.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Fuyang dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "fuyang.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "fuyang.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/pinghu.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Pinghu dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "pinghu.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "pinghu.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/sanmen.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Sanmen dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "sanmen.ocd2" 8 | } 9 | }, 10 | 
"conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "sanmen.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/zhenru.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Zhenru dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "zhenru.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "zhenru.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jiading.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Jiading dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "jiading.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "jiading.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jiashan.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Jiashan dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "jiashan.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "jiashan.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/linping.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "Linping dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "linping.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "linping.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/shaoxing.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Shaoxing dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "shaoxing.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "shaoxing.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/suichang.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Suichang dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "suichang.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "suichang.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/tiantai.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tiantai dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "tiantai.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | 
"type": "ocd2", 17 | "file": "tiantai.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/wenzhou.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Wenzhou dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "wenzhou.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "wenzhou.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/xiaoshan.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Xiaoshan dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "xiaoshan.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "xiaoshan.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/changzhou.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Changzhou dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "changzhou.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "changzhou.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/jingjiang.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Jingjiang dialect to IPA", 3 | 
"segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "jingjiang.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "jingjiang.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /chinese_dialect_lexicons/tongxiang.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Tongxiang dialect to IPA", 3 | "segmentation": { 4 | "type": "mmseg", 5 | "dict": { 6 | "type": "ocd2", 7 | "file": "tongxiang.ocd2" 8 | } 9 | }, 10 | "conversion_chain": [ 11 | { 12 | "dict": { 13 | "type": "group", 14 | "dicts": [ 15 | { 16 | "type": "ocd2", 17 | "file": "tongxiang.ocd2" 18 | } 19 | ] 20 | } 21 | } 22 | ] 23 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython>=0.29.21 2 | numba 3 | librosa 4 | numpy 5 | torch 6 | unidecode 7 | pyopenjtalk 8 | jamo 9 | pypinyin 10 | pypinyin_dict 11 | python-dotenv 12 | jieba 13 | protobuf 14 | cn2an 15 | inflect 16 | eng_to_ipa 17 | ko_pron 18 | indic_transliteration 19 | num_thai 20 | opencc 21 | uvicorn 22 | fastapi 23 | pydantic 24 | loguru 25 | soundfile 26 | graiax-silkcoder[libsndfile] 27 | psutil 28 | fasttext-wheel 29 | chardet 30 | onnx 31 | onnxruntime 32 | loguru 33 | monotonic_align 34 | tqdm 35 | httpx 36 | python-multipart -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: "GitHub - LlmKira/VitsServer: 🌻 A VITS ONNX server designed for fast inference" 3 | abstract: This repository contains the source code for VitsServer, a server designed for fast inference of ONNX models in the VITS 
format. 4 | authors: 5 | - name: LlmKira 6 | type: Organization 7 | url: https://github.com/LlmKira 8 | keywords: 9 | - vits 10 | - onnx 11 | version: 1.0.0 12 | date-released: 2023-04-01 13 | url: https://github.com/LlmKira/VitsServer 14 | citation: 15 | - text: "LlmKira (2023). GitHub - LlmKira/VitsServer: 🌻 A VITS ONNX server designed for fast inference. GitHub." 16 | doi: 17 | - text: "LlmKira. (2023). VitsServer [Source code]. GitHub. https://github.com/LlmKira/VitsServer" 18 | doi: 19 | license: BSD-3-Clause 20 | repository-code: https://github.com/LlmKira/VitsServer -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Stage 1 - Build dependencies from source 2 | FROM python:3.8-slim AS builder 3 | 4 | # Keeps Python from generating .pyc files in the container 5 | ENV PYTHONDONTWRITEBYTECODE=1 6 | 7 | # Turns off buffering for easier container logging 8 | ENV PYTHONUNBUFFERED=1 9 | 10 | RUN apt-get update && \ 11 | apt-get install -y build-essential libsndfile1 vim gcc g++ cmake 12 | 13 | RUN python3 -m pip install Cython --install-option="--no-cython-compile" 14 | 15 | # These are some packages who will take a lot of time to build 16 | RUN python3 -m pip install --upgrade pip numpy numba pyopenjtalk 17 | 18 | WORKDIR /build 19 | 20 | COPY requirements.txt . 21 | 22 | RUN python3 -m pip install -r requirements.txt 23 | 24 | # Stage 2 - Runtime image 25 | FROM python:3.8-slim 26 | 27 | ENV SERVER_HOST='0.0.0.0' \ 28 | SERVER_PORT=9557 29 | 30 | EXPOSE $SERVER_PORT 31 | 32 | WORKDIR /app 33 | 34 | COPY --from=builder /usr/local/ /usr/local/ 35 | 36 | COPY . . 
37 | 38 | CMD ["python3", "main.py"] 39 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /LICENSE_ORIGIN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 CjangCjengh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | 4 | 5 | def text_to_sequence(text, symbols, cleaner_names): 6 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
7 | Args: 8 | text: string to convert to a sequence 9 | cleaner_names: names of the cleaner functions to run the text through 10 | Returns: 11 | List of integers corresponding to the symbols in the text 12 | ''' 13 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 14 | 15 | sequence = [] 16 | 17 | clean_text = _clean_text(text, cleaner_names) 18 | for symbol in clean_text: 19 | if symbol not in _symbol_to_id.keys(): 20 | continue 21 | symbol_id = _symbol_to_id[symbol] 22 | sequence += [symbol_id] 23 | return sequence 24 | 25 | 26 | def _clean_text(text, cleaner_names): 27 | for name in cleaner_names: 28 | cleaner = getattr(cleaners, name) 29 | if not cleaner: 30 | raise Exception('Unknown cleaner: %s' % name) 31 | text = cleaner(text) 32 | return text 33 | -------------------------------------------------------------------------------- /component/langdetect_fasttext/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Zafer Çavdar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /text/ngu_dialect.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import opencc 4 | 5 | dialects = {'SZ': 'suzhou', 'WX': 'wuxi', 'CZ': 'changzhou', 'HZ': 'hangzhou', 6 | 'SX': 'shaoxing', 'NB': 'ningbo', 'JJ': 'jingjiang', 'YX': 'yixing', 7 | 'JD': 'jiading', 'ZR': 'zhenru', 'PH': 'pinghu', 'TX': 'tongxiang', 8 | 'JS': 'jiashan', 'HN': 'xiashi', 'LP': 'linping', 'XS': 'xiaoshan', 9 | 'FY': 'fuyang', 'RA': 'ruao', 'CX': 'cixi', 'SM': 'sanmen', 10 | 'TT': 'tiantai', 'WZ': 'wenzhou', 'SC': 'suichang', 'YB': 'youbu'} 11 | 12 | converters = {} 13 | 14 | for dialect in dialects.values(): 15 | try: 16 | converters[dialect] = opencc.OpenCC(dialect) 17 | except: 18 | pass 19 | 20 | 21 | def ngu_dialect_to_ipa(text, dialect): 22 | dialect = dialects[dialect] 23 | text = converters[dialect].convert(text).replace('-', '').replace('$', ' ') 24 | text = re.sub(r'[、;:]', ',', text) 25 | text = re.sub(r'\s*,\s*', ', ', text) 26 | text = re.sub(r'\s*。\s*', '. ', text) 27 | text = re.sub(r'\s*?\s*', '? ', text) 28 | text = re.sub(r'\s*!\s*', '! 
', text) 29 | text = re.sub(r'\s*$', '', text) 30 | return text 31 | -------------------------------------------------------------------------------- /text/thai.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from num_thai.thainumbers import NumThai 4 | 5 | num = NumThai() 6 | 7 | # List of (Latin alphabet, Thai) pairs: 8 | _latin_to_thai = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 9 | ('a', 'เอ'), 10 | ('b', 'บี'), 11 | ('c', 'ซี'), 12 | ('d', 'ดี'), 13 | ('e', 'อี'), 14 | ('f', 'เอฟ'), 15 | ('g', 'จี'), 16 | ('h', 'เอช'), 17 | ('i', 'ไอ'), 18 | ('j', 'เจ'), 19 | ('k', 'เค'), 20 | ('l', 'แอล'), 21 | ('m', 'เอ็ม'), 22 | ('n', 'เอ็น'), 23 | ('o', 'โอ'), 24 | ('p', 'พี'), 25 | ('q', 'คิว'), 26 | ('r', 'แอร์'), 27 | ('s', 'เอส'), 28 | ('t', 'ที'), 29 | ('u', 'ยู'), 30 | ('v', 'วี'), 31 | ('w', 'ดับเบิลยู'), 32 | ('x', 'เอ็กซ์'), 33 | ('y', 'วาย'), 34 | ('z', 'ซี') 35 | ]] 36 | 37 | 38 | def num_to_thai(text): 39 | return re.sub(r'(?:\d+(?:,?\d+)?)+(?:\.\d+(?:,?\d+)?)?', 40 | lambda x: ''.join(num.NumberToTextThai(float(x.group(0).replace(',', '')))), text) 41 | 42 | 43 | def latin_to_thai(text): 44 | for regex, replacement in _latin_to_thai: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | -------------------------------------------------------------------------------- /deploy_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if the VitsServer directory already exists 4 | if [ -d "VitsServer" ]; then 5 | echo "VitsServer directory already exists, updating..." 
6 | cd VitsServer 7 | git pull --exclude=.env --exclude=model 8 | exit 9 | else 10 | # Clone the repository 11 | git clone https://github.com/LlmKira/VitsServer.git 12 | cd VitsServer 13 | fi 14 | 15 | # Update system packages 16 | sudo apt-get update && 17 | sudo apt-get install -y build-essential libsndfile1 vim gcc g++ cmake 18 | 19 | # Install Python dependencies 20 | sudo apt install python3-pip 21 | pip3 install pipenv 22 | 23 | # Install dependency packages 24 | pipenv install 25 | 26 | # Activate the virtual environment 27 | pipenv shell 28 | 29 | # Set up the configuration file 30 | touch .env 31 | echo "VITS_SERVER_HOST=0.0.0.0" > .env 32 | echo "VITS_SERVER_PORT=9557" >> .env 33 | echo "VITS_SERVER_RELOAD=false" >> .env 34 | 35 | # Start the server using PM2 36 | sudo apt install npm 37 | npm install pm2 -g 38 | pm2 start pm2.json --name vits_server --watch 39 | 40 | # Save the PM2 process list so it will be started at boot 41 | sudo pm2 save 42 | 43 | # Exit the virtual environment 44 | exit 45 | -------------------------------------------------------------------------------- /component/pike.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/7 下午10:16 3 | # @Author : sudoskys 4 | # @File : pike.py 5 | # @Software: PyCharm 6 | 7 | 8 | class OnnxReader(object): 9 | 10 | @staticmethod 11 | def get_onnx_file_path(path: str): 12 | # 读取某个文件夹,将 .onnx 后缀的文件路径映射到一个kv表 13 | import os 14 | file_path = {} 15 | for root, dirs, files in os.walk(path): 16 | for file in files: 17 | if os.path.splitext(file)[1] == '.onnx': 18 | file_path[os.path.splitext(file)[0]] = os.path.join(root, file) 19 | return file_path 20 | 21 | # 读取某个文件夹,如果有n个onnx 文件,就打包成 tar.gz 22 | @staticmethod 23 | def get_onnx_file(path: str, n: int = 5): 24 | """ 25 | :param path: 文件夹路径 26 | :param n: onnx文件个数 27 | """ 28 | import tarfile 29 | file_path = OnnxReader.get_onnx_file_path(path) 30 | if len(file_path) > n: 31 
| tar = tarfile.open(path + ".tar.gz", "w:gz") 32 | for file in file_path: 33 | tar.add(file_path[file], arcname=file) 34 | tar.close() 35 | return path + ".tar.gz" 36 | else: 37 | return None 38 | -------------------------------------------------------------------------------- /onnx_infer/utils/onnx_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import onnxruntime as ort 5 | 6 | 7 | # import torch.backends.cuda 8 | 9 | 10 | def set_random_seed(seed=0): 11 | import torch.backends.cudnn 12 | ort.set_seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.backends.cudnn.deterministic = True 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | 19 | 20 | class RunONNX(object): 21 | def __init__(self, model=None, 22 | providers=None): 23 | self.ort_session = None 24 | if model: 25 | self.load(model, providers=providers) 26 | 27 | def load(self, 28 | model, 29 | providers=None 30 | ): 31 | # 如果是 ByteIO 类,则转换为 bytes 32 | if providers is None: 33 | providers = ['CPUExecutionProvider'] 34 | if hasattr(model, "getvalue"): 35 | model = model.getvalue() 36 | # 创造运行时 37 | self.ort_session = ort.InferenceSession(model, providers=providers) 38 | 39 | def run(self, model_input): 40 | outputs = self.ort_session.run( 41 | None, 42 | input_feed=model_input 43 | ) 44 | return outputs 45 | -------------------------------------------------------------------------------- /.github/workflows/docker-latest.yaml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI 2 | 3 | on: 4 | workflow_dispatch: 5 | push: 6 | branches: 7 | - 'feat-ci' 8 | release: 9 | types: [ published ] 10 | 11 | jobs: 12 | build-and-push-image: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: read 16 | packages: write 17 | 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | - name: Set up QEMU 22 | uses: 
docker/setup-qemu-action@v2 23 | - name: Set up Docker Buildx 24 | uses: docker/setup-buildx-action@v2 25 | - name: Login to Docker Hub 26 | uses: docker/login-action@v2 27 | with: 28 | username: ${{ secrets.DOCKERHUB_USERNAME }} 29 | password: ${{ secrets.DOCKERHUB_TOKEN }} 30 | - name: Extract metadata (tags, labels) for Docker 31 | id: meta 32 | uses: docker/metadata-action@v4.1.1 33 | with: 34 | images: sudoskys/vits-server 35 | - name: Build and push 36 | uses: docker/build-push-action@v4 37 | with: 38 | context: . 39 | push: true 40 | platforms: 'linux/amd64' 41 | tags: ${{ steps.meta.outputs.tags }} 42 | labels: ${{ steps.meta.outputs.labels }} 43 | cache-from: type=gha 44 | cache-to: type=gha,mode=max 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, LLM Kira 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /text/sanskrit.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from indic_transliteration import sanscript 4 | 5 | # List of (iast, ipa) pairs: 6 | _iast_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 7 | ('a', 'ə'), 8 | ('ā', 'aː'), 9 | ('ī', 'iː'), 10 | ('ū', 'uː'), 11 | ('ṛ', 'ɹ`'), 12 | ('ṝ', 'ɹ`ː'), 13 | ('ḷ', 'l`'), 14 | ('ḹ', 'l`ː'), 15 | ('e', 'eː'), 16 | ('o', 'oː'), 17 | ('k', 'k⁼'), 18 | ('k⁼h', 'kʰ'), 19 | ('g', 'g⁼'), 20 | ('g⁼h', 'gʰ'), 21 | ('ṅ', 'ŋ'), 22 | ('c', 'ʧ⁼'), 23 | ('ʧ⁼h', 'ʧʰ'), 24 | ('j', 'ʥ⁼'), 25 | ('ʥ⁼h', 'ʥʰ'), 26 | ('ñ', 'n^'), 27 | ('ṭ', 't`⁼'), 28 | ('t`⁼h', 't`ʰ'), 29 | ('ḍ', 'd`⁼'), 30 | ('d`⁼h', 'd`ʰ'), 31 | ('ṇ', 'n`'), 32 | ('t', 't⁼'), 33 | ('t⁼h', 'tʰ'), 34 | ('d', 'd⁼'), 35 | ('d⁼h', 'dʰ'), 36 | ('p', 'p⁼'), 37 | ('p⁼h', 'pʰ'), 38 | ('b', 'b⁼'), 39 | ('b⁼h', 'bʰ'), 40 | ('y', 'j'), 41 | ('ś', 'ʃ'), 42 | ('ṣ', 's`'), 43 | ('r', 'ɾ'), 44 | ('l̤', 'l`'), 45 | ('h', 'ɦ'), 46 | ("'", ''), 47 | ('~', '^'), 48 | ('ṃ', '^') 49 | ]] 50 | 51 | 52 | def devanagari_to_ipa(text): 53 | text = text.replace('ॐ', 'ओम्') 54 | text = re.sub(r'\s*।\s*$', '.', text) 55 | text = re.sub(r'\s*।\s*', ', ', text) 56 | text = re.sub(r'\s*॥', '.', text) 57 | text = sanscript.transliterate(text, sanscript.DEVANAGARI, sanscript.IAST) 58 | for regex, replacement in 
_iast_to_ipa: 59 | text = re.sub(regex, replacement, text) 60 | text = re.sub('(.)[`ː]*ḥ', lambda x: x.group(0) 61 | [:-1] + 'h' + x.group(1) + '*', text) 62 | return text 63 | -------------------------------------------------------------------------------- /text/cantonese.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import cn2an 4 | import opencc 5 | 6 | converter = opencc.OpenCC('jyutjyu') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ei˥'), 11 | ('B', 'biː˥'), 12 | ('C', 'siː˥'), 13 | ('D', 'tiː˥'), 14 | ('E', 'iː˥'), 15 | ('F', 'e˥fuː˨˩'), 16 | ('G', 'tsiː˥'), 17 | ('H', 'ɪk̚˥tsʰyː˨˩'), 18 | ('I', 'ɐi˥'), 19 | ('J', 'tsei˥'), 20 | ('K', 'kʰei˥'), 21 | ('L', 'e˥llou˨˩'), 22 | ('M', 'ɛːm˥'), 23 | ('N', 'ɛːn˥'), 24 | ('O', 'ou˥'), 25 | ('P', 'pʰiː˥'), 26 | ('Q', 'kʰiːu˥'), 27 | ('R', 'aː˥lou˨˩'), 28 | ('S', 'ɛː˥siː˨˩'), 29 | ('T', 'tʰiː˥'), 30 | ('U', 'juː˥'), 31 | ('V', 'wiː˥'), 32 | ('W', 'tʊk̚˥piː˥juː˥'), 33 | ('X', 'ɪk̚˥siː˨˩'), 34 | ('Y', 'waːi˥'), 35 | ('Z', 'iː˨sɛːt̚˥') 36 | ]] 37 | 38 | 39 | def number_to_cantonese(text): 40 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: cn2an.an2cn(x.group()), text) 41 | 42 | 43 | def latin_to_ipa(text): 44 | for regex, replacement in _latin_to_ipa: 45 | text = re.sub(regex, replacement, text) 46 | return text 47 | 48 | 49 | def cantonese_to_ipa(text): 50 | text = number_to_cantonese(text.upper()) 51 | text = converter.convert(text).replace('-', '').replace('$', ' ') 52 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group()) + ' ', text) 53 | text = re.sub(r'[、;:]', ',', text) 54 | text = re.sub(r'\s*,\s*', ', ', text) 55 | text = re.sub(r'\s*。\s*', '. ', text) 56 | text = re.sub(r'\s*?\s*', '? ', text) 57 | text = re.sub(r'\s*!\s*', '! 
', text) 58 | text = re.sub(r'\s*$', '', text) 59 | return text 60 | -------------------------------------------------------------------------------- /text/shanghainese.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import cn2an 4 | import opencc 5 | 6 | converter = opencc.OpenCC('chinese_dialect_lexicons/zaonhe') 7 | 8 | # List of (Latin alphabet, ipa) pairs: 9 | _latin_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 10 | ('A', 'ᴇ'), 11 | ('B', 'bi'), 12 | ('C', 'si'), 13 | ('D', 'di'), 14 | ('E', 'i'), 15 | ('F', 'ᴇf'), 16 | ('G', 'dʑi'), 17 | ('H', 'ᴇtɕʰ'), 18 | ('I', 'ᴀi'), 19 | ('J', 'dʑᴇ'), 20 | ('K', 'kʰᴇ'), 21 | ('L', 'ᴇl'), 22 | ('M', 'ᴇm'), 23 | ('N', 'ᴇn'), 24 | ('O', 'o'), 25 | ('P', 'pʰi'), 26 | ('Q', 'kʰiu'), 27 | ('R', 'ᴀl'), 28 | ('S', 'ᴇs'), 29 | ('T', 'tʰi'), 30 | ('U', 'ɦiu'), 31 | ('V', 'vi'), 32 | ('W', 'dᴀbɤliu'), 33 | ('X', 'ᴇks'), 34 | ('Y', 'uᴀi'), 35 | ('Z', 'zᴇ') 36 | ]] 37 | 38 | 39 | def _number_to_shanghainese(num): 40 | num = cn2an.an2cn(num).replace('一十', '十').replace('二十', '廿').replace('二', '两') 41 | return re.sub(r'((?:^|[^三四五六七八九])十|廿)两', r'\1二', num) 42 | 43 | 44 | def number_to_shanghainese(text): 45 | return re.sub(r'\d+(?:\.?\d+)?', lambda x: _number_to_shanghainese(x.group()), text) 46 | 47 | 48 | def latin_to_ipa(text): 49 | for regex, replacement in _latin_to_ipa: 50 | text = re.sub(regex, replacement, text) 51 | return text 52 | 53 | 54 | def shanghainese_to_ipa(text): 55 | text = number_to_shanghainese(text.upper()) 56 | text = converter.convert(text).replace('-', '').replace('$', ' ') 57 | text = re.sub(r'[A-Z]', lambda x: latin_to_ipa(x.group()) + ' ', text) 58 | text = re.sub(r'[、;:]', ',', text) 59 | text = re.sub(r'\s*,\s*', ', ', text) 60 | text = re.sub(r'\s*。\s*', '. ', text) 61 | text = re.sub(r'\s*?\s*', '? ', text) 62 | text = re.sub(r'\s*!\s*', '! 
', text) 63 | text = re.sub(r'\s*$', '', text) 64 | return text 65 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 3 | ''' 4 | 5 | # japanese_cleaners 6 | _pad = '_' 7 | _punctuation = ',.!?-' 8 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | '''# japanese_cleaners2 11 | _pad = '_' 12 | _punctuation = ',.!?-~…' 13 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 14 | ''' 15 | 16 | '''# korean_cleaners 17 | _pad = '_' 18 | _punctuation = ',.!?…~' 19 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 20 | ''' 21 | 22 | '''# chinese_cleaners 23 | _pad = '_' 24 | _punctuation = ',。!?—…' 25 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 26 | ''' 27 | 28 | '''# zh_ja_mixture_cleaners 29 | _pad = '_' 30 | _punctuation = ',.!?-~…' 31 | _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 32 | ''' 33 | 34 | '''# sanskrit_cleaners 35 | _pad = '_' 36 | _punctuation = '।' 37 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 38 | ''' 39 | 40 | '''# cjks_cleaners 41 | _pad = '_' 42 | _punctuation = ',.!?-~…' 43 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 44 | ''' 45 | 46 | '''# thai_cleaners 47 | _pad = '_' 48 | _punctuation = '.!? 
' 49 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 50 | ''' 51 | 52 | '''# cjke_cleaners2 53 | _pad = '_' 54 | _punctuation = ',.!?-~…' 55 | _letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 56 | ''' 57 | 58 | '''# shanghainese_cleaners 59 | _pad = '_' 60 | _punctuation = ',.!?…' 61 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 62 | ''' 63 | 64 | '''# chinese_dialect_cleaners 65 | _pad = '_' 66 | _punctuation = ',.!?~…─' 67 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 68 | ''' 69 | 70 | # Export all symbols: 71 | symbols = [_pad] + list(_punctuation) + list(_letters) 72 | 73 | # Special symbol ids 74 | SPACE_ID = symbols.index(" ") 75 | -------------------------------------------------------------------------------- /component/nlp_utils/cut.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/5 上午10:18 3 | # @Author : sudoskys 4 | # @File : cut.py 5 | # @Software: PyCharm 6 | import re 7 | 8 | 9 | class Cut(object): 10 | @staticmethod 11 | def english_sentence_cut(text) -> list: 12 | list_ = list() 13 | for s_str in text.split('.'): 14 | if '?' in s_str: 15 | list_.extend(s_str.split('?')) 16 | elif '!' 
in s_str: 17 | list_.extend(s_str.split('!')) 18 | else: 19 | list_.append(s_str) 20 | return list_ 21 | 22 | @staticmethod 23 | def chinese_sentence_cut(text) -> list: 24 | """ 25 | 中文断句 26 | """ 27 | # 根据句子末尾的标点符号进行切分 28 | text = re.sub('([^\n])([.!?。!?]+)(?!\d|[a-zA-Z]|[^\s\u4e00-\u9fa5\u3040-\u309f\u30a0-\u30ff\uac00-\ud7a3])', 29 | r'\1\2\n', text) 30 | # 根据中文断句符号进行切分 31 | text = re.sub('([。::!?\?])([^’”])', r'\1\n\2', text) 32 | # 普通断句符号且后面没有引号 33 | text = re.sub('(\.{6})([^’”])', r'\1\n\2', text) 34 | # 英文省略号且后面没有引号 35 | text = re.sub('(\…{2})([^’”])', r'\1\n\2', text) 36 | # 中文省略号且后面没有引号 37 | text = re.sub('([.。!?\?\.{6}\…{2}][’”])([^’”])', r'\1\n\2', text) 38 | # 根据英文断句号+空格进行切分 39 | text = re.sub('(\. )([^a-zA-Z\d])', r'\1\n\2', text) 40 | # 删除多余的换行符 41 | text = re.sub('\n\n+', '\n', text) 42 | # 断句号+引号且后面没有引号 43 | return text.split("\n") 44 | 45 | def cut_chinese_sentence(self, text): 46 | """ 47 | 中文断句 48 | """ 49 | p = re.compile("“.*?”") 50 | listr = [] 51 | index = 0 52 | for i in p.finditer(text): 53 | temp = '' 54 | start = i.start() 55 | end = i.end() 56 | for j in range(index, start): 57 | temp += text[j] 58 | if temp != '': 59 | temp_list = self.chinese_sentence_cut(temp) 60 | listr += temp_list 61 | temp = '' 62 | for k in range(start, end): 63 | temp += text[k] 64 | if temp != ' ': 65 | listr.append(temp) 66 | index = end 67 | return listr 68 | -------------------------------------------------------------------------------- /docs/sdk.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | class VITS: 5 | def __init__(self, base_url): 6 | self.base_url = base_url 7 | 8 | def get_model_list(self, show_speaker=False, show_ms_config=False): 9 | url = f"{self.base_url}/model/list?show_speaker={show_speaker}&show_ms_config={show_ms_config}" 10 | res = requests.get(url) 11 | return res.json() 12 | 13 | def get_model_info(self, model_id): 14 | url = 
f"{self.base_url}/model/info?model_id={model_id}" 15 | res = requests.get(url) 16 | return res.json() 17 | 18 | def parse_text(self, text, strip=False, merge_same=False, cell_limit=140, filter_space=True): 19 | url = f"{self.base_url}/tts/parse" 20 | data = { 21 | "text": text, 22 | "strip": strip, 23 | "merge_same": merge_same, 24 | "cell_limit": cell_limit, 25 | "filter_space": filter_space 26 | } 27 | res = requests.post(url, json=data) 28 | return res.json() 29 | 30 | def generate_voice(self, model_id, text, speaker_id=0, audio_type="wav", length_scale=1.3, 31 | noise_scale=0.6, noise_scale_w=0.6, load_prefer=False, auto_parse=True): 32 | url = f"{self.base_url}/tts/generate" 33 | data = { 34 | "model_id": model_id, 35 | "text": text, 36 | "speaker_id": speaker_id, 37 | "audio_type": audio_type, 38 | "length_scale": length_scale, 39 | "noise_scale": noise_scale, 40 | "noise_scale_w": noise_scale_w, 41 | "load_prefer": load_prefer 42 | } 43 | if auto_parse: 44 | url += "?auto_parse=True" 45 | res = requests.post(url, json=data, stream=True) 46 | return res 47 | 48 | 49 | if __name__ == "__main__": 50 | client = VITS("http://127.0.0.1:9557") 51 | res = client.get_model_list(show_speaker=True, show_ms_config=True) 52 | print(res) 53 | 54 | res = client.get_model_info(model_id="model_01") 55 | print(res) 56 | 57 | res = client.parse_text(text="Hello world!") 58 | print(res) 59 | 60 | res = client.generate_voice(model_id="model_01", text="你好,世界!", speaker_id=0, audio_type="wav", 61 | length_scale=1.0, noise_scale=0.0, noise_scale_w=0.0) 62 | with open("output.wav", "wb") as f: 63 | for chunk in res.iter_content(chunk_size=1024): 64 | if chunk: 65 | f.write(chunk) 66 | -------------------------------------------------------------------------------- /component/langdetect_fasttext/detect.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import os 4 | from typing import Dict, Union 5 | 
6 | import fasttext 7 | import requests 8 | 9 | logger = logging.getLogger(__name__) 10 | models = {"low_mem": None, "high_mem": None} 11 | FTLANG_CACHE = os.getenv("FTLANG_CACHE", "/tmp/fasttext-langdetect") 12 | 13 | try: 14 | # silences warnings as the package does not properly use the python 'warnings' package 15 | # see https://github.com/facebookresearch/fastText/issues/1056 16 | fasttext.FastText.eprint = lambda *args, **kwargs: None 17 | except Exception as e: 18 | pass 19 | 20 | 21 | def check_model(name: str) -> bool: 22 | # 查看模型是否为 0 字节 23 | target_path = os.path.join(FTLANG_CACHE, name) 24 | if os.path.exists(target_path): 25 | if os.path.getsize(target_path) > 0: 26 | return True 27 | else: 28 | # 0 字节,删除 29 | os.remove(target_path) 30 | return False 31 | 32 | 33 | def download_model(name: str) -> str: 34 | target_path = os.path.join(FTLANG_CACHE, name) 35 | if not os.path.exists(target_path): 36 | logger.info(f"Downloading {name} model ...") 37 | url = f"https://dl.fbaipublicfiles.com/fasttext/supervised-models/{name}" # noqa 38 | os.makedirs(FTLANG_CACHE, exist_ok=True) 39 | with open(target_path, "wb") as fp: 40 | response = requests.get(url) 41 | fp.write(response.content) 42 | logger.info(f"Downloaded.") 43 | return target_path 44 | 45 | 46 | def get_or_load_model(low_memory=False): 47 | if low_memory: 48 | model = models.get("low_mem", None) 49 | if not model: 50 | check_model("lid.176.ftz") 51 | model_path = download_model("lid.176.ftz") 52 | model = fasttext.load_model(model_path) 53 | models["low_mem"] = model 54 | return model 55 | else: 56 | model = models.get("high_mem", None) 57 | if not model: 58 | check_model("lid.176.ftz") 59 | model_path = download_model("lid.176.bin") 60 | model = fasttext.load_model(model_path) 61 | models["high_mem"] = model 62 | return model 63 | 64 | 65 | def detect(text: str, low_memory=False) -> Dict[str, Union[str, float]]: 66 | model = get_or_load_model(low_memory) 67 | labels, scores = model.predict(text) 68 
| label = labels[0].replace("__label__", '') 69 | score = min(float(scores[0]), 1.0) 70 | return { 71 | "lang": label, 72 | "score": score, 73 | } 74 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Set default behavior to automatically normalize line endings. 3 | ############################################################################### 4 | * text=auto 5 | 6 | ############################################################################### 7 | # Set default behavior for command prompt diff. 8 | # 9 | # This is need for earlier builds of msysgit that does not have it on by 10 | # default for csharp files. 11 | # Note: This is only used by command line 12 | ############################################################################### 13 | #*.cs diff=csharp 14 | 15 | ############################################################################### 16 | # Set the merge driver for project and solution files 17 | # 18 | # Merging from the command prompt will add diff markers to the files if there 19 | # are conflicts (Merging from VS is not affected by the settings below, in VS 20 | # the diff markers are never inserted). Diff markers may cause the following 21 | # file extensions to fail to load in VS. An alternative would be to treat 22 | # these files as binary and thus will always conflict and require user 23 | # intervention with every merge. 
To do so, just uncomment the entries below 24 | ############################################################################### 25 | #*.sln merge=binary 26 | #*.csproj merge=binary 27 | #*.vbproj merge=binary 28 | #*.vcxproj merge=binary 29 | #*.vcproj merge=binary 30 | #*.dbproj merge=binary 31 | #*.fsproj merge=binary 32 | #*.lsproj merge=binary 33 | #*.wixproj merge=binary 34 | #*.modelproj merge=binary 35 | #*.sqlproj merge=binary 36 | #*.wwaproj merge=binary 37 | 38 | ############################################################################### 39 | # behavior for image files 40 | # 41 | # image files are treated as binary by default. 42 | ############################################################################### 43 | #*.jpg binary 44 | #*.png binary 45 | #*.gif binary 46 | 47 | ############################################################################### 48 | # diff behavior for common document formats 49 | # 50 | # Convert binary document formats to text before diffing them. This feature 51 | # is only available from the command line. Turn it on by uncommenting the 52 | # entries below. 
53 | ############################################################################### 54 | #*.doc diff=astextplain 55 | #*.DOC diff=astextplain 56 | #*.docx diff=astextplain 57 | #*.DOCX diff=astextplain 58 | #*.dot diff=astextplain 59 | #*.DOT diff=astextplain 60 | #*.pdf diff=astextplain 61 | #*.PDF diff=astextplain 62 | #*.rtf diff=astextplain 63 | #*.RTF diff=astextplain 64 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/8 下午4:02 3 | # @Author : sudoskys 4 | # @File : main.py.py 5 | # @Software: PyCharm 6 | import os 7 | import pathlib 8 | 9 | import requests 10 | import uvicorn 11 | from dotenv import load_dotenv 12 | from loguru import logger 13 | from pydantic import BaseModel 14 | from tqdm import tqdm 15 | 16 | 17 | def download_file(folder_path, file_name, url, max_retries=3): 18 | # 拼接 19 | file_path = os.path.join(folder_path, file_name) 20 | with requests.get(url, stream=True) as response: 21 | response.raise_for_status() 22 | total_size = int(response.headers.get('content-length', 0)) 23 | block_size = 8192 24 | progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True) 25 | with open(file_path, 'wb') as file: 26 | for chunk in response.iter_content(chunk_size=block_size): 27 | if not chunk: 28 | break 29 | file.write(chunk) 30 | progress_bar.update(len(chunk)) 31 | progress_bar.close() 32 | if os.path.getsize(file_path) == total_size: 33 | logger.success(f"初始化模型下载成功: {file_path}") 34 | return True 35 | else: 36 | os.remove(file_path) 37 | if max_retries > 0: 38 | return download_file(folder_path, file_name, url, max_retries - 1) 39 | return False 40 | 41 | 42 | # Run 43 | 44 | class FastApiConf(BaseModel): 45 | reload: bool = False 46 | host: str = "127.0.0.1" 47 | port: int = 9557 48 | workers: int = 1 49 | 50 | 51 | # Load environment variables from .env file 52 | 
load_dotenv() 53 | 54 | host = str(os.environ.get('VITS_SERVER_HOST', "0.0.0.0")) 55 | port = int(os.environ.get('VITS_SERVER_PORT', 9557)) 56 | reload = os.environ.get('VITS_SERVER_RELOAD', False) == 'true' 57 | workers = int(os.environ.get('VITS_SERVER_WORKERS', 1)) 58 | FastApi = FastApiConf(reload=reload, host=host, port=port, workers=workers) 59 | 60 | init_model = os.environ.get('VITS_SERVER_INIT_MODEL', None) 61 | init_config = os.environ.get('VITS_SERVER_INIT_CONFIG', None) 62 | 63 | # 查询是否存在init模型在路径下 64 | if not pathlib.Path("model/init.json").exists() and init_config: 65 | download_file("model", "init.json", init_config) 66 | 67 | # 获得文件链接的文件后缀 68 | if init_model: 69 | file_name = os.path.basename(init_model) 70 | file_ext = os.path.splitext(file_name)[-1] 71 | if not pathlib.Path(f"model/init{file_ext}").exists(): 72 | download_file("model", f"init{file_ext}", init_model) 73 | 74 | if FastApi.reload: 75 | logger.warning("reload 参数有内容修改自动重启服务器,启用可能导致连续重启导致 CPU 满载") 76 | 77 | if __name__ == '__main__': 78 | uvicorn.run('server:app', 79 | host=FastApi.host, 80 | port=FastApi.port, 81 | reload=FastApi.reload, 82 | log_level="debug", 83 | workers=FastApi.workers 84 | ) 85 | -------------------------------------------------------------------------------- /docs/sdk.js: -------------------------------------------------------------------------------- 1 | class VITS { 2 | constructor(base_url) { 3 | this.base_url = base_url; 4 | } 5 | 6 | async get_model_list(show_speaker = false, show_ms_config = false) { 7 | const url = `${this.base_url}/model/list?show_speaker=${show_speaker}&show_ms_config=${show_ms_config}`; 8 | const res = await fetch(url); 9 | return await res.json(); 10 | } 11 | 12 | async get_model_info(model_id) { 13 | const url = `${this.base_url}/model/info?model_id=${model_id}`; 14 | const res = await fetch(url); 15 | return await res.json(); 16 | } 17 | 18 | async parse_text(text, strip = false, merge_same = false, cell_limit = 140, filter_space = 
true) { 19 | const url = `${this.base_url}/tts/parse`; 20 | const data = { 21 | text, 22 | strip, 23 | merge_same, 24 | cell_limit, 25 | filter_space 26 | }; 27 | const res = await fetch(url, { 28 | method: 'POST', 29 | body: JSON.stringify(data), 30 | headers: { 31 | 'Content-Type': 'application/json' 32 | } 33 | }); 34 | return await res.json(); 35 | } 36 | 37 | async generate_voice(model_id, text, speaker_id = 0, audio_type = "wav", length_scale = 1.0, 38 | noise_scale = 0.5, noise_scale_w = 0.7, load_prefer = false, auto_parse = true) { 39 | let url = `${this.base_url}/tts/generate`; 40 | const data = { 41 | model_id, 42 | text, 43 | speaker_id, 44 | audio_type, 45 | length_scale, 46 | noise_scale, 47 | noise_scale_w, 48 | load_prefer 49 | }; 50 | if (auto_parse) { 51 | url += "?auto_parse=True"; 52 | } 53 | const res = await fetch(url, { 54 | method: 'POST', 55 | body: JSON.stringify(data), 56 | headers: { 57 | 'Content-Type': 'application/json' 58 | }, 59 | responseType: 'blob' 60 | }); 61 | return res.data; 62 | } 63 | } 64 | 65 | /* 66 | const client = new VITS("http://0.0.0.0:9557"); 67 | const memory = await client.get_memory(); 68 | console.log(memory); 69 | 70 | const modelList = await client.get_model_list(show_speaker=true, show_ms_config=true); 71 | console.log(modelList); 72 | 73 | const modelInfo = await client.get_model_info(model_id="model_01"); 74 | console.log(modelInfo); 75 | 76 | const parsedText = await client.parse_text(text="Hello world!"); 77 | console.log(parsedText); 78 | 79 | const voiceBlob = await client.generate_voice(model_id="model_01", text="[ZH]你好,世界!", speaker_id=0, audio_type="wav", 80 | length_scale=1.0, noise_scale=0.5, noise_scale_w=0.6); 81 | const voiceUrl = URL.createObjectURL(voiceBlob); 82 | const audio = new Audio(); 83 | audio.src = voiceUrl; 84 | audio.play(); 85 | * */ -------------------------------------------------------------------------------- /onnx_infer/infer/commons.py: 
def intersperse(lst, item):
    """Return a new list with *item* placed before, between and after every
    element of *lst*, e.g. [a, b] -> [item, a, item, b, item]."""
    interleaved = [item]
    for element in lst:
        interleaved.append(element)
        interleaved.append(item)
    return interleaved
def sequence_mask(length, max_length=None):
    """Build a boolean padding mask of shape (batch, max_length).

    Entry (b, t) is True when t < length[b]. When *max_length* is omitted it
    defaults to the longest sequence in the batch.
    """
    limit = length.max() if max_length is None else max_length
    positions = torch.arange(limit, dtype=length.dtype, device=length.device)
    # Broadcast (1, T) against (B, 1) to get the (B, T) mask.
    return positions[None, :] < length[:, None]
def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Restore *model* (and optionally *optimizer*) from a checkpoint file.

    Weights missing from the checkpoint keep their current values, so a
    partially compatible checkpoint can still be loaded.

    :param checkpoint_path: path to a torch checkpoint containing the keys
        'model', 'iteration', 'learning_rate' and optionally 'optimizer'
    :param model: module (possibly wrapped, exposing ``.module``)
    :param optimizer: optimizer to restore, or None to skip
    :return: (model, optimizer, learning_rate, iteration)
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = load(checkpoint_path, map_location='cpu')
    iteration = checkpoint_dict['iteration']
    learning_rate = checkpoint_dict['learning_rate']
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])
    saved_state_dict = checkpoint_dict['model']
    # DataParallel-style wrappers keep the real parameters under .module.
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            # Keep the current weight when the checkpoint lacks this key.
            # (Was a bare `except:` that also hid unrelated errors.)
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, 'module'):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info("Loaded checkpoint '{}' (iteration {})".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from loguru import logger 3 | 4 | 5 | class DetectSentence(object): 6 | """ 7 | 检测句子 8 | """ 9 | 10 | @staticmethod 11 | def detect_language(sentence: str): 12 | """ 13 | Detect language 14 | :param sentence: sentence 15 | :return: 两位大写语言代码 (EN, ZH, JA, KO, FR, DE, ES, ....) 16 | """ 17 | # 如果全部是空格 18 | if sentence.isspace() or not sentence: 19 | return "" 20 | 21 | # 如果全部是标点 22 | try: 23 | from .. import langdetect_fasttext 24 | lang_type = langdetect_fasttext.detect(text=sentence.replace("\n", "").replace("\r", ""), 25 | low_memory=True).get("lang").upper() 26 | 27 | def is_japanese(string): 28 | for ch in string: 29 | if 0x3040 < ord(ch) < 0x30FF: 30 | return True 31 | return False 32 | 33 | if lang_type == "JA" and not is_japanese(sentence): 34 | lang_type = "ZH" 35 | except Exception as e: 36 | # handle error 37 | logger.error(e) 38 | raise e 39 | return lang_type 40 | 41 | @staticmethod 42 | def detect_help(sentence: str) -> bool: 43 | """ 44 | 检测是否是包含帮助要求,如果是,返回True,否则返回False 45 | """ 46 | _check = ['怎么做', 'How', 'how', 'what', 'What', 'Why', 'why', '复述', '复读', '要求你', '原样', '例子', 47 | '解释', 'exp', '推荐', '说出', '写出', '如何实现', '代码', '写', 'give', 'Give', 48 | '请把', '请给', '请写', 'help', 'Help', '写一', 'code', '如何做', '帮我', '帮助我', '请给我', '什么', 49 | '为何', '给建议', '给我', '给我一些', '请教', '建议', '怎样', '如何', '怎么样', 50 | '为什么', '帮朋友', '怎么', '需要什么', '注意什么', '怎么办', '助け', '何を', 'なぜ', '教えて', '提案', 51 | '何が', '何に', 52 | '何をす', '怎麼做', '複述', '復讀', '原樣', '解釋', '推薦', '說出', '寫出', '如何實現', '代碼', '寫', 53 | '請把', '請給', '請寫', '寫一', '幫我', '幫助我', '請給我', '什麼', '為何', '給建議', '給我', 54 | '給我一些', '請教', '建議', '步驟', '怎樣', '怎麼樣', '為什麼', '幫朋友', '怎麼', '需要什麼', 55 | '註意什麼', '怎麼辦'] 56 | for item in _check: 57 | if item in sentence: 58 | return True 59 | return False 60 | 61 | @staticmethod 62 | def detect_code(sentence) -> bool: 63 | """ 64 | Detect code,if code return True,else return False 65 | 
:param sentence: sentence 66 | :return: bool 67 | """ 68 | code = False 69 | _reco = [ 70 | '("', 71 | '")', 72 | ").", 73 | "()", 74 | "!=", 75 | "==", 76 | ] 77 | _t = len(_reco) 78 | _r = 0 79 | for i in _reco: 80 | if i in sentence: 81 | _r += 1 82 | if _r > _t / 2: 83 | code = True 84 | rms = [ 85 | "```", 86 | "import " 87 | "print_r(", 88 | "var_dump(", 89 | 'NSLog( @', 90 | 'println(', 91 | '.log(', 92 | 'print(', 93 | 'printf(', 94 | 'WriteLine(', 95 | '.Println(', 96 | '.Write(', 97 | 'alert(', 98 | 'echo(', 99 | ] 100 | for i in rms: 101 | if i in sentence: 102 | code = True 103 | return code 104 | -------------------------------------------------------------------------------- /mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | 6 | MAX_WAV_VALUE = 32768.0 7 | 8 | 9 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 10 | """ 11 | PARAMS 12 | ------ 13 | C: compression factor 14 | """ 15 | return torch.log(torch.clamp(x, min=clip_val) * C) 16 | 17 | 18 | def dynamic_range_decompression_torch(x, C=1): 19 | """ 20 | PARAMS 21 | ------ 22 | C: compression factor used to compress 23 | """ 24 | return torch.exp(x) / C 25 | 26 | 27 | def spectral_normalize_torch(magnitudes): 28 | output = dynamic_range_compression_torch(magnitudes) 29 | return output 30 | 31 | 32 | def spectral_de_normalize_torch(magnitudes): 33 | output = dynamic_range_decompression_torch(magnitudes) 34 | return output 35 | 36 | 37 | mel_basis = {} 38 | hann_window = {} 39 | 40 | 41 | def spectrogram_torch(y, 42 | n_fft, 43 | sampling_rate, 44 | hop_size, 45 | win_size, 46 | center=False): 47 | if torch.min(y) < -1.: 48 | print('min value is ', torch.min(y)) 49 | if torch.max(y) > 1.: 50 | print('max value is ', torch.max(y)) 51 | 52 | global hann_window 53 | dtype_device = str(y.dtype) + '_' + 
def mel_spectrogram_torch(y,
                          n_fft,
                          num_mels,
                          sampling_rate,
                          hop_size,
                          win_size,
                          fmin,
                          fmax,
                          center=False):
    """Compute a log-mel spectrogram from a batch of waveforms in one pass.

    Combines spectrogram_torch and spec_to_mel_torch: STFT magnitude followed
    by mel projection and dynamic-range compression. Caches the Hann window
    and mel filterbank per (param, dtype, device) combination in the
    module-level ``hann_window`` / ``mel_basis`` dicts so they are built once.

    Args:
        y: waveform tensor; values are expected in [-1, 1] (out-of-range
            values are only warned about, not clipped).
        n_fft: FFT size.
        num_mels: number of mel bands.
        sampling_rate: audio sample rate passed to the mel filterbank.
        hop_size: STFT hop length.
        win_size: STFT window length.
        fmin / fmax: mel filterbank frequency bounds.
        center: passed through to torch.stft.

    Returns:
        Log-compressed mel spectrogram tensor.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # Cache keys include dtype and device so CPU/GPU or fp16/fp32 runs do not
    # clobber each other's cached tensors.
    dtype_device = str(y.dtype) + '_' + str(y.device)
    fmax_dtype_device = str(fmax) + '_' + dtype_device
    wnsize_dtype_device = str(win_size) + '_' + dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
            dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device)

    # Manual reflect-padding so each frame is centered (center=False below).
    y = F.pad(y.unsqueeze(1),
              (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
              mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y,
                      n_fft,
                      hop_length=hop_size,
                      win_length=win_size,
                      window=hann_window[wnsize_dtype_device],
                      center=center,
                      pad_mode='reflect',
                      normalized=False,
                      onesided=True)

    # Magnitude from the (real, imag) last dimension; epsilon avoids sqrt(0).
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)

    # Project linear spectrogram onto mel bands, then log-compress.
    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec
def symbols_to_japanese(text):
    """Replace symbol characters (e.g. '%') with their spoken Japanese form."""
    result = text
    for pattern, spoken in _symbols_to_japanese:
        result = pattern.sub(spoken, result)
    return result
def japanese_to_ipa3(text):
    """Convert *text* to a narrower IPA transcription on top of japanese_to_ipa2."""
    result = japanese_to_ipa2(text)
    # Swap in the narrow-transcription symbols.
    for old, new in (('n^', 'ȵ'), ('ʃ', 'ɕ'), ('*', '\u0325'), ('#', '\u031a')):
        result = result.replace(old, new)
    # Collapse repeated vowels into a single vowel plus length marks.
    result = re.sub(r'([aiɯeo])\1+',
                    lambda m: m.group(0)[0] + 'ː' * (len(m.group(0)) - 1),
                    result)
    # Aspirate word-initial voiceless stops/affricates.
    result = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', result)
    return result
coding: utf-8 -*- 2 | # @Time : 2023/4/5 下午10:13 3 | # @Author : sudoskys 4 | # @File : warp.py 5 | # @Software: PyCharm 6 | from typing import List 7 | 8 | from component.nlp_utils.cut import Cut 9 | from component.nlp_utils.detect import DetectSentence 10 | 11 | Detector = DetectSentence() 12 | 13 | 14 | class Parse(object): 15 | @staticmethod 16 | def merge_cell(result: List[dict]) -> List[dict]: 17 | """ 18 | 合并句子 19 | :param result: 20 | :return: 21 | """ 22 | _merged = [] 23 | _cache = [] 24 | last_lang = None 25 | for _result in result: 26 | if _result["lang"] == last_lang: 27 | _cache.append(_result["text"]) 28 | else: 29 | if _cache: 30 | # 计算列表内文本长度 31 | _length = sum([len(_c) for _c in _cache]) 32 | _merged.append({"text": "".join(_cache), "lang": last_lang, "length": _length}) 33 | _cache = [_result["text"]] 34 | last_lang = _result["lang"] 35 | if _cache: 36 | _length = sum([len(_c) for _c in _cache]) 37 | _merged.append({"text": "".join(_cache), "lang": last_lang, "length": _length}) 38 | return _merged 39 | 40 | def create_cell(self, 41 | sentence: str, 42 | merge_same: bool = True, 43 | cell_limit: int = 150, 44 | filter_space: bool = True 45 | ) -> list: 46 | """ 47 | 分句,识别语言 48 | :param sentence: 句子 49 | :param merge_same: 是否合并相同语言的句子 50 | :param cell_limit: 单元最大长度 51 | :return: 52 | """ 53 | cut = Cut() 54 | cut_list = cut.chinese_sentence_cut(sentence) 55 | _cut_list = [] 56 | for _cut in cut_list: 57 | if len(_cut) > cell_limit: 58 | _text_list = [_cut[i:i + cell_limit] for i in range(0, len(_cut), cell_limit)] 59 | _cut_list.extend(_text_list) 60 | else: 61 | _cut_list.append(_cut) 62 | # 为每个句子标排语言 63 | _result = [] 64 | for _cut in _cut_list: 65 | _lang = Detector.detect_language(_cut) 66 | if not filter_space: 67 | _result.append({"text": _cut, "lang": _lang, "length": len(_cut)}) 68 | else: 69 | if _lang: 70 | _result.append({"text": _cut, "lang": _lang, "length": len(_cut)}) 71 | if merge_same: 72 | _result = self.merge_cell(_result) 73 | 
return _result 74 | 75 | def build_sentence(self, 76 | sentence_cell: List[dict], 77 | strip: bool = False 78 | ): 79 | # 生成句子 80 | _sentence = [] 81 | for _cell in sentence_cell: 82 | _text = _cell.get('text') 83 | _lang = _cell.get('lang') 84 | if _lang: 85 | if strip: 86 | _sentence.append(f"[{_lang}]{_text.strip()}[{_lang}]") 87 | else: 88 | _sentence.append(f"[{_lang}]{_text}[{_lang}]") 89 | return _sentence 90 | 91 | def pack_up_task(self, 92 | sentence_cell: List[dict], 93 | task_limit: int = 150, 94 | strip: bool = False 95 | ): 96 | """ 97 | 打包单元 98 | :param sentence_cell: 单元列表 99 | :param task_limit: 任务最大长度 100 | :param strip: 是否去除空格 101 | :return: 102 | """ 103 | _task_list = [] 104 | _task = [] 105 | _task_length = 0 106 | for _cell in sentence_cell: 107 | _text = _cell.get('text') 108 | _lang = _cell.get('lang') 109 | _length = _cell.get('length') 110 | if _lang: 111 | if _task_length + _length > task_limit: 112 | _task_list.append(self.build_sentence(_task, strip=strip)) 113 | _task = [] 114 | _task_length = 0 115 | _task.append(_cell) 116 | _task_length += _length 117 | if _task: 118 | _task_list.append(self.build_sentence(_task, strip=strip)) 119 | return _task_list 120 | 121 | 122 | if __name__ == '__main__': 123 | text = """ 124 | 1. 今天是个晴朗的日子,阳光明媚,空气清新。我打算去公园散步,享受这美好的一天。翻译:Today is a sunny day with bright sunshine and fresh air. I plan to take a walk in the park and enjoy this beautiful day. 125 | 126 | 2. 무엇인가를 생각하면 답답하거나 짜증나지 않고 미소 머금을 수 있는 하루였으면 좋겠습니다. 翻译:I hope to have a day where I can smile instead of feeling frustrated or annoyed when thinking about something. 127 | 128 | 3. 早上好的韩文翻译是:짜증나지. 翻译:The Korean translation for "good morning" is "안녕하세요" (annyeonghaseyo). 
129 | 130 | 饮食男女,人之大欲也。性格天生,各有千秋。不可移易之物,爱恨情仇皆起于此。人生苦短,何必怀旧?前车之鉴,后事之师。日出而作,日落而息,勤奋正是成功之母。忍耐是一种美德,约束自己才能超越自己。道德是社会文明之基石,诚信是立身之本。文章合适,不在长短,在于内容。行路难,始知世间艰辛,但愿人长久,千里共婵娟。岁月匆匆,光阴如箭,唯有心存善念,方能始终如一。 131 | 132 | 日本語が話せますので、何かお手伝いできることがありましたら、遠慮なくお申し付けください。日本には美しい自然と文化がたくさんあります。桜や紅葉の季節には多くの人々が訪れ、素晴らしい景色を楽しんでいます。また、日本の食文化も世界的に有名で、寿司やラーメン、うどんなど様々な料理があります。日本語は少し難しい言語かもしれませんが、練習することで上達します。一緒に頑張りましょう! 133 | """ 134 | import time 135 | 136 | time1 = time.time() 137 | parse = Parse() 138 | sentence_cell = parse.create_cell(text, merge_same=False, cell_limit=140) 139 | print(sentence_cell) 140 | res = parse.build_sentence(sentence_cell=sentence_cell, strip=True) 141 | print(res) 142 | res2 = parse.pack_up_task(sentence_cell=sentence_cell, task_limit=140, strip=True) 143 | print(res2) 144 | print(res2[0]) 145 | time2 = time.time() 146 | print(time2 - time1) 147 | -------------------------------------------------------------------------------- /text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 
def expand_abbreviations(text):
    """Expand common English abbreviations (e.g. "mr." -> "mister")."""
    expanded = text
    for pattern, full_word in _abbreviations:
        expanded = pattern.sub(full_word, expanded)
    return expanded
92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | 
def mark_dark_l(text):
    """Replace syllable-final ("dark") /l/ with 'ɫ': an 'l' followed only by
    non-vowel characters up to a space or the end of the string."""
    def velarize(match):
        return 'ɫ' + match.group(1)

    return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', velarize, text)
16 | 17 | **If you are looking for a production-ready TTS implementation, go to https://github.com/RVC-Boss/GPT-SoVITS** 18 | 19 | ## Advantages 💪 20 | 21 | - [x] Long Voice Generation, Support Streaming. 长语音批次推理合并。 22 | - [x] Automatic language type parsing for text, eliminating the need for language recognition segmentation. 23 | 自动识别语言类型并处理一切。 24 | - [x] Supports multiple audio formats, including ogg, wav, flac, and silk. 多格式返回写入。 25 | - [x] Multiple models, streaming inference. 多模型初始化。 26 | - [x] Additional inference settings to enable model preference settings and optimize performance. 额外的推理设置,启用模型偏好设置。 27 | - [x] Auto Convert PTH to ONNX. 自动转换pth到onnx。 28 | - [ ] Support for multiple languages, including Chinese, English, Japanese, and Korean. 多语言多模型合并支持(任务批次分发到不同模型)。 29 | 30 | ## API Documentation 📖 31 | 32 | We offer out-of-the-box call systems. 33 | 34 | - [Python SDK](docs/sdk.py) 35 | - [JavaScript SDK](docs/sdk.js) 36 | 37 | ```python 38 | client = VITS("http://127.0.0.1:9557") 39 | res = client.generate_voice(model_id="model_01", text="你好,世界!", speaker_id=0, audio_type="wav", 40 | length_scale=1.0, noise_scale=0.5, noise_scale_w=0.5, auto_parse=True) 41 | with open("output.wav", "wb") as f: 42 | for chunk in res.iter_content(chunk_size=1024): 43 | if chunk: 44 | f.write(chunk) 45 | ``` 46 | 47 | ## Running 🏃 48 | 49 | We recommend using a virtual environment to isolate the runtime environment. Because this project's dependencies may 50 | potentially disrupt your dependency library, we recommend using `pipenv` to manage the dependency package. 
51 | 52 | ### Config Server 🐚 53 | 54 | Configuration is in `.env`, including the following fields: 55 | 56 | ```dotenv 57 | VITS_SERVER_HOST=0.0.0.0 58 | VITS_SERVER_PORT=9557 59 | VITS_SERVER_RELOAD=false 60 | # VITS_SERVER_WORKERS=1 61 | # VITS_SERVER_INIT_CONFIG="https://....json" 62 | # VITS_SERVER_INIT_MODEL="https://.....pth or onnx" 63 | ``` 64 | 65 | or you can use the following command to set the environment variable: 66 | 67 | ```shell 68 | export VITS_SERVER_HOST="0.0.0.0" 69 | export VITS_SERVER_PORT="9557" 70 | export VITS_SERVER_RELOAD="false" 71 | export VITS_DISABLE_GPU="false" 72 | 73 | ``` 74 | 75 | `VITS_SERVER_RELOAD` means auto restart server when file changed. 76 | 77 | ### Running from pipenv 🐍 and pm2.json 🚀 78 | 79 | ```shell 80 | apt-get update && 81 | apt-get install -y build-essential libsndfile1 vim gcc g++ cmake 82 | apt install python3-pip 83 | pip3 install pipenv 84 | pipenv install # Create and install dependency packages 85 | pipenv shell # Activate the virtual environment 86 | python3 main.py # Run 87 | # then ctrl+c exit 88 | ``` 89 | 90 | ```shell 91 | apt install npm 92 | npm install pm2 -g 93 | pm2 start pm2.json 94 | # then the server will run in the background 95 | 96 | ``` 97 | 98 | and we have a one-click script to install `pipenv` and `npm`: 99 | 100 | ```shell 101 | curl -LO https://raw.githubusercontent.com/LlmKira/VitsServer/main/deploy_script.sh && chmod +x deploy_script.sh && ./deploy_script.sh 102 | 103 | ``` 104 | 105 | ### Building from Docker 🐋 106 | 107 | we have `docker pull sudoskys/vits-server:main` to docker hub. 108 | 109 | you can also build from Dockerfile. 110 | 111 | ```shell 112 | docker build -t . 113 | ``` 114 | 115 | where `` is the name you want to give to the image. 
Then, use the following command to start the container: 116 | 117 | ```shell 118 | docker run -d -p 9557:9557 -v /vits_model:/app/model 119 | ``` 120 | 121 | where `` is the local folder path you want to map to the /app/model directory in the container. 122 | 123 | ## Model Configuration 📁 124 | 125 | In the `model` folder, place the `model.pth`/ `model.onnx` and corresponding `model.json` files. If it is `.pth`, it 126 | will be automatically converted to `.onnx`! 127 | 128 | you can use `.env` to set `VITS_SERVER_INIT_CONFIG` and `VITS_SERVER_INIT_MODEL` to download model files. 129 | 130 | ```dotenv 131 | VITS_SERVER_INIT_CONFIG="https://....json" 132 | VITS_SERVER_INIT_MODEL="https://.....pth?trace=233 or onnx?trace=233" 133 | ``` 134 | 135 | `model` folder structure: 136 | 137 | ``` 138 | . 139 | ├── 1000_epochs.json 140 | ├── 1000_epochs.onnx 141 | ├── 1000_epochs.pth 142 | ├── 233_epochs.json 143 | ├── 233_epochs.onnx 144 | └── 233_epochs.pth 145 | ``` 146 | 147 | `Model ID` is `1000_epochs` and `233_epochs`. 148 | 149 | **when you put model files in the `model` folder, you should restart the server.** 150 | 151 | ### Model Extension Design 🔍 152 | 153 | You can add extra fields in the model configuration to obtain information such as the model name corresponding to the 154 | model ID through the API. 155 | 156 | ```json5 157 | { 158 | //... 159 | "info": { 160 | "name": "coco", 161 | "description": "a vits model", 162 | "author": "someone", 163 | "cover": "https://xxx.com/xxx.jpg", 164 | "email": "xx@ws.com" 165 | }, 166 | "infer": { 167 | "noise_scale": 0.667, 168 | "length_scale": 1.0, 169 | "noise_scale_w": 0.8 170 | } 171 | //.... 172 | } 173 | ``` 174 | 175 | `infer` is the default(prefer) inference settings for the model. 176 | 177 | `info` is the model information. 178 | 179 | ### How can I retrieve these model information? 
180 | 181 | You can access `{your_base_url}/model/list?show_speaker=True&show_ms_config=True` to obtain detailed information about 182 | model roles and configurations. 183 | 184 | ## TODO 📝 185 | 186 | - [ ] Test Silk format 187 | - [x] Docker for automatic deployment 188 | - [x] Shell script for automatic deployment 189 | 190 | ## Acknowledgements 🙏 191 | 192 | We would like to acknowledge the contributions of the following projects in the development of this project: 193 | 194 | - MoeGoe: https://github.com/CjangCjengh/MoeGoe 195 | - vits_with_chatbot: https://huggingface.co/Mahiruoshi/vits_with_chatbot 196 | - vits: https://huggingface.co/spaces/Plachta/VITS-Umamusume-voice-synthesizer 197 | - espnet: https://github.com/espnet/espnet_onnx 198 | - onnxruntime: https://onnxruntime.ai/ 199 | -------------------------------------------------------------------------------- /text/korean.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import ko_pron 4 | from jamo import h2j, j2hcj 5 | 6 | # This is a list of Korean classifiers preceded by pure Korean numerals. 
7 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' 8 | 9 | # List of (hangul, hangul divided) pairs: 10 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ 11 | ('ㄳ', 'ㄱㅅ'), 12 | ('ㄵ', 'ㄴㅈ'), 13 | ('ㄶ', 'ㄴㅎ'), 14 | ('ㄺ', 'ㄹㄱ'), 15 | ('ㄻ', 'ㄹㅁ'), 16 | ('ㄼ', 'ㄹㅂ'), 17 | ('ㄽ', 'ㄹㅅ'), 18 | ('ㄾ', 'ㄹㅌ'), 19 | ('ㄿ', 'ㄹㅍ'), 20 | ('ㅀ', 'ㄹㅎ'), 21 | ('ㅄ', 'ㅂㅅ'), 22 | ('ㅘ', 'ㅗㅏ'), 23 | ('ㅙ', 'ㅗㅐ'), 24 | ('ㅚ', 'ㅗㅣ'), 25 | ('ㅝ', 'ㅜㅓ'), 26 | ('ㅞ', 'ㅜㅔ'), 27 | ('ㅟ', 'ㅜㅣ'), 28 | ('ㅢ', 'ㅡㅣ'), 29 | ('ㅑ', 'ㅣㅏ'), 30 | ('ㅒ', 'ㅣㅐ'), 31 | ('ㅕ', 'ㅣㅓ'), 32 | ('ㅖ', 'ㅣㅔ'), 33 | ('ㅛ', 'ㅣㅗ'), 34 | ('ㅠ', 'ㅣㅜ') 35 | ]] 36 | 37 | # List of (Latin alphabet, hangul) pairs: 38 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 39 | ('a', '에이'), 40 | ('b', '비'), 41 | ('c', '시'), 42 | ('d', '디'), 43 | ('e', '이'), 44 | ('f', '에프'), 45 | ('g', '지'), 46 | ('h', '에이치'), 47 | ('i', '아이'), 48 | ('j', '제이'), 49 | ('k', '케이'), 50 | ('l', '엘'), 51 | ('m', '엠'), 52 | ('n', '엔'), 53 | ('o', '오'), 54 | ('p', '피'), 55 | ('q', '큐'), 56 | ('r', '아르'), 57 | ('s', '에스'), 58 | ('t', '티'), 59 | ('u', '유'), 60 | ('v', '브이'), 61 | ('w', '더블유'), 62 | ('x', '엑스'), 63 | ('y', '와이'), 64 | ('z', '제트') 65 | ]] 66 | 67 | # List of (ipa, lazy ipa) pairs: 68 | _ipa_to_lazy_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 69 | ('t͡ɕ', 'ʧ'), 70 | ('d͡ʑ', 'ʥ'), 71 | ('ɲ', 'n^'), 72 | ('ɕ', 'ʃ'), 73 | ('ʷ', 'w'), 74 | ('ɭ', 'l`'), 75 | ('ʎ', 'ɾ'), 76 | ('ɣ', 'ŋ'), 77 | ('ɰ', 'ɯ'), 78 | ('ʝ', 'j'), 79 | ('ʌ', 'ə'), 80 | ('ɡ', 'g'), 81 | ('\u031a', '#'), 82 | ('\u0348', '='), 83 | ('\u031e', ''), 84 | ('\u0320', ''), 85 | ('\u0339', '') 86 | ]] 87 | 88 | 89 | def latin_to_hangul(text): 90 | for regex, replacement in _latin_to_hangul: 91 | text = re.sub(regex, replacement, text) 92 | return text 93 | 94 | 95 | def divide_hangul(text): 96 | text = j2hcj(h2j(text)) 97 | for regex, replacement in _hangul_divided: 98 | text = 
re.sub(regex, replacement, text) 99 | return text 100 | 101 | 102 | def hangul_number(num, sino=True): 103 | '''Reference https://github.com/Kyubyong/g2pK''' 104 | num = re.sub(',', '', num) 105 | 106 | if num == '0': 107 | return '영' 108 | if not sino and num == '20': 109 | return '스무' 110 | 111 | digits = '123456789' 112 | names = '일이삼사오육칠팔구' 113 | digit2name = {d: n for d, n in zip(digits, names)} 114 | 115 | modifiers = '한 두 세 네 다섯 여섯 일곱 여덟 아홉' 116 | decimals = '열 스물 서른 마흔 쉰 예순 일흔 여든 아흔' 117 | digit2mod = {d: mod for d, mod in zip(digits, modifiers.split())} 118 | digit2dec = {d: dec for d, dec in zip(digits, decimals.split())} 119 | 120 | spelledout = [] 121 | for i, digit in enumerate(num): 122 | i = len(num) - i - 1 123 | if sino: 124 | if i == 0: 125 | name = digit2name.get(digit, '') 126 | elif i == 1: 127 | name = digit2name.get(digit, '') + '십' 128 | name = name.replace('일십', '십') 129 | else: 130 | if i == 0: 131 | name = digit2mod.get(digit, '') 132 | elif i == 1: 133 | name = digit2dec.get(digit, '') 134 | if digit == '0': 135 | if i % 4 == 0: 136 | last_three = spelledout[-min(3, len(spelledout)):] 137 | if ''.join(last_three) == '': 138 | spelledout.append('') 139 | continue 140 | else: 141 | spelledout.append('') 142 | continue 143 | if i == 2: 144 | name = digit2name.get(digit, '') + '백' 145 | name = name.replace('일백', '백') 146 | elif i == 3: 147 | name = digit2name.get(digit, '') + '천' 148 | name = name.replace('일천', '천') 149 | elif i == 4: 150 | name = digit2name.get(digit, '') + '만' 151 | name = name.replace('일만', '만') 152 | elif i == 5: 153 | name = digit2name.get(digit, '') + '십' 154 | name = name.replace('일십', '십') 155 | elif i == 6: 156 | name = digit2name.get(digit, '') + '백' 157 | name = name.replace('일백', '백') 158 | elif i == 7: 159 | name = digit2name.get(digit, '') + '천' 160 | name = name.replace('일천', '천') 161 | elif i == 8: 162 | name = digit2name.get(digit, '') + '억' 163 | elif i == 9: 164 | name = digit2name.get(digit, '') + '십' 165 
| elif i == 10: 166 | name = digit2name.get(digit, '') + '백' 167 | elif i == 11: 168 | name = digit2name.get(digit, '') + '천' 169 | elif i == 12: 170 | name = digit2name.get(digit, '') + '조' 171 | elif i == 13: 172 | name = digit2name.get(digit, '') + '십' 173 | elif i == 14: 174 | name = digit2name.get(digit, '') + '백' 175 | elif i == 15: 176 | name = digit2name.get(digit, '') + '천' 177 | spelledout.append(name) 178 | return ''.join(elem for elem in spelledout) 179 | 180 | 181 | def number_to_hangul(text): 182 | '''Reference https://github.com/Kyubyong/g2pK''' 183 | tokens = set(re.findall(r'(\d[\d,]*)([\uac00-\ud71f]+)', text)) 184 | for token in tokens: 185 | num, classifier = token 186 | if classifier[:2] in _korean_classifiers or classifier[0] in _korean_classifiers: 187 | spelledout = hangul_number(num, sino=False) 188 | else: 189 | spelledout = hangul_number(num, sino=True) 190 | text = text.replace(f'{num}{classifier}', f'{spelledout}{classifier}') 191 | # digit by digit for remaining digits 192 | digits = '0123456789' 193 | names = '영일이삼사오육칠팔구' 194 | for d, n in zip(digits, names): 195 | text = text.replace(d, n) 196 | return text 197 | 198 | 199 | def korean_to_lazy_ipa(text): 200 | text = latin_to_hangul(text) 201 | text = number_to_hangul(text) 202 | text = re.sub('[\uac00-\ud7af]+', lambda x: ko_pron.romanise(x.group(0), 'ipa').split('] ~ [')[0], text) 203 | for regex, replacement in _ipa_to_lazy_ipa: 204 | text = re.sub(regex, replacement, text) 205 | return text 206 | 207 | 208 | def korean_to_ipa(text): 209 | text = korean_to_lazy_ipa(text) 210 | return text.replace('ʧ', 'tʃ').replace('ʥ', 'dʑ') 211 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from pypinyin import pinyin, Style 4 | from pypinyin.style._utils import get_finals, get_initials 5 | from pypinyin_dict.phrase_pinyin_data import 
cc_cedict 6 | from pypinyin_dict.pinyin_data import kmandarin_8105 7 | 8 | # Load pinyin data 9 | kmandarin_8105.load() 10 | cc_cedict.load() 11 | 12 | 13 | def japanese_cleaners(text): 14 | from text.japanese import japanese_to_romaji_with_accent 15 | text = japanese_to_romaji_with_accent(text) 16 | text = re.sub(r'([A-Za-z])$', r'\1.', text) 17 | return text 18 | 19 | 20 | def japanese_cleaners2(text): 21 | return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…') 22 | 23 | 24 | def korean_cleaners(text): 25 | '''Pipeline for Korean text''' 26 | from text.korean import latin_to_hangul, number_to_hangul, divide_hangul 27 | text = latin_to_hangul(text) 28 | text = number_to_hangul(text) 29 | text = divide_hangul(text) 30 | text = re.sub(r'([\u3131-\u3163])$', r'\1.', text) 31 | return text 32 | 33 | 34 | def chinese_cleaners(text): 35 | '''Pipeline for Chinese text''' 36 | from text.mandarin import number_to_chinese, chinese_to_bopomofo, latin_to_bopomofo 37 | text = number_to_chinese(text) 38 | text = chinese_to_bopomofo(text) 39 | text = latin_to_bopomofo(text) 40 | text = re.sub(r'([ˉˊˇˋ˙])$', r'\1。', text) 41 | return text 42 | 43 | 44 | def zh_ja_mixture_cleaners(text): 45 | from text.mandarin import chinese_to_romaji 46 | from text.japanese import japanese_to_romaji_with_accent 47 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 48 | lambda x: chinese_to_romaji(x.group(1)) + ' ', text) 49 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_romaji_with_accent( 50 | x.group(1)).replace('ts', 'ʦ').replace('u', 'ɯ').replace('...', '…') + ' ', text) 51 | text = re.sub(r'\s+$', '', text) 52 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 53 | return text 54 | 55 | 56 | def sanskrit_cleaners(text): 57 | text = text.replace('॥', '।').replace('ॐ', 'ओम्') 58 | text = re.sub(r'([^।])$', r'\1।', text) 59 | return text 60 | 61 | 62 | def cjks_cleaners(text): 63 | from text.mandarin import chinese_to_lazy_ipa 64 | from text.japanese import japanese_to_ipa 65 | 
from text.korean import korean_to_lazy_ipa 66 | from text.sanskrit import devanagari_to_ipa 67 | from text.english import english_to_lazy_ipa 68 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 69 | lambda x: chinese_to_lazy_ipa(x.group(1)) + ' ', text) 70 | text = re.sub(r'\[JA\](.*?)\[JA\]', 71 | lambda x: japanese_to_ipa(x.group(1)) + ' ', text) 72 | text = re.sub(r'\[KO\](.*?)\[KO\]', 73 | lambda x: korean_to_lazy_ipa(x.group(1)) + ' ', text) 74 | text = re.sub(r'\[SA\](.*?)\[SA\]', 75 | lambda x: devanagari_to_ipa(x.group(1)) + ' ', text) 76 | text = re.sub(r'\[EN\](.*?)\[EN\]', 77 | lambda x: english_to_lazy_ipa(x.group(1)) + ' ', text) 78 | text = re.sub(r'\s+$', '', text) 79 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 80 | return text 81 | 82 | 83 | def cjke_cleaners(text): 84 | from text.mandarin import chinese_to_lazy_ipa 85 | from text.japanese import japanese_to_ipa 86 | from text.korean import korean_to_ipa 87 | from text.english import english_to_ipa2 88 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: chinese_to_lazy_ipa(x.group(1)).replace( 89 | 'ʧ', 'tʃ').replace('ʦ', 'ts').replace('ɥan', 'ɥæn') + ' ', text) 90 | text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)).replace('ʧ', 'tʃ').replace( 91 | 'ʦ', 'ts').replace('ɥan', 'ɥæn').replace('ʥ', 'dz') + ' ', text) 92 | text = re.sub(r'\[KO\](.*?)\[KO\]', 93 | lambda x: korean_to_ipa(x.group(1)) + ' ', text) 94 | text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)).replace('ɑ', 'a').replace( 95 | 'ɔ', 'o').replace('ɛ', 'e').replace('ɪ', 'i').replace('ʊ', 'u') + ' ', text) 96 | text = re.sub(r'\s+$', '', text) 97 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 98 | return text 99 | 100 | 101 | def cjke_cleaners2(text): 102 | from text.mandarin import chinese_to_ipa 103 | from text.japanese import japanese_to_ipa2 104 | from text.korean import korean_to_ipa 105 | from text.english import english_to_ipa2 106 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 107 | lambda x: 
chinese_to_ipa(x.group(1)) + ' ', text) 108 | text = re.sub(r'\[JA\](.*?)\[JA\]', 109 | lambda x: japanese_to_ipa2(x.group(1)) + ' ', text) 110 | text = re.sub(r'\[KO\](.*?)\[KO\]', 111 | lambda x: korean_to_ipa(x.group(1)) + ' ', text) 112 | text = re.sub(r'\[EN\](.*?)\[EN\]', 113 | lambda x: english_to_ipa2(x.group(1)) + ' ', text) 114 | text = re.sub(r'\s+$', '', text) 115 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 116 | return text 117 | 118 | 119 | def thai_cleaners(text): 120 | from text.thai import num_to_thai, latin_to_thai 121 | text = num_to_thai(text) 122 | text = latin_to_thai(text) 123 | return text 124 | 125 | 126 | def chinese_cleaners2(text): 127 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', lambda x: x.group(1) + ' ', text) 128 | result = " ".join([ 129 | p 130 | for phone in pinyin(text, style=Style.TONE3, v_to_u=True) 131 | for p in [ 132 | get_initials(phone[0], strict=True), 133 | get_finals(phone[0][:-1], strict=True) + phone[0][-1] 134 | if phone[0][-1].isdigit() 135 | else get_finals(phone[0], strict=True) 136 | if phone[0][-1].isalnum() 137 | else phone[0], 138 | ] 139 | if len(p) != 0 and not p.isdigit() 140 | ]) 141 | return result 142 | 143 | 144 | def shanghainese_cleaners(text): 145 | from text.shanghainese import shanghainese_to_ipa 146 | text = shanghainese_to_ipa(text) 147 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 148 | return text 149 | 150 | 151 | def chinese_dialect_cleaners(text): 152 | from text.mandarin import chinese_to_ipa2 153 | from text.japanese import japanese_to_ipa3 154 | from text.shanghainese import shanghainese_to_ipa 155 | from text.cantonese import cantonese_to_ipa 156 | from text.english import english_to_lazy_ipa2 157 | from text.ngu_dialect import ngu_dialect_to_ipa 158 | text = re.sub(r'\[ZH\](.*?)\[ZH\]', 159 | lambda x: chinese_to_ipa2(x.group(1)) + ' ', text) 160 | text = re.sub(r'\[JA\](.*?)\[JA\]', 161 | lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ') + ' ', text) 162 | text = 
re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5', 163 | '˧˧˦').replace( 164 | '6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e') + ' ', text) 165 | text = re.sub(r'\[GD\](.*?)\[GD\]', 166 | lambda x: cantonese_to_ipa(x.group(1)) + ' ', text) 167 | text = re.sub(r'\[EN\](.*?)\[EN\]', 168 | lambda x: english_to_lazy_ipa2(x.group(1)) + ' ', text) 169 | text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group( 170 | 1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ') + ' ', text) 171 | text = re.sub(r'\s+$', '', text) 172 | text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) 173 | return text 174 | -------------------------------------------------------------------------------- /pth2onnx.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | from typing import Union 4 | 5 | import librosa 6 | import torch 7 | 8 | import utils 9 | from onnx_infer import onnx_infer 10 | from onnx_infer.infer import commons 11 | from onnx_infer.utils.onnx_utils import RunONNX 12 | from text import text_to_sequence 13 | 14 | 15 | def get_text(text, hps): 16 | text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 17 | if hps.data.add_blank: 18 | text_norm = commons.intersperse(text_norm, 0) 19 | text_norm = torch.LongTensor(text_norm) 20 | return text_norm 21 | 22 | 23 | def to_numpy(tensor): 24 | return tensor.detach().cpu().numpy() if tensor.requires_grad \ 25 | else tensor.detach().numpy() 26 | 27 | 28 | class VitsExtractor(object): 29 | @staticmethod 30 | def write_out(model_path, obj): 31 | """ 32 | 创建一个和模型同名的 onnx 文件。写入 ByteIO 33 | """ 34 | import pathlib 35 | model_path = pathlib.Path(model_path) 36 | onnx_path = model_path.parent / f'{model_path.stem}.onnx' 37 | with open(onnx_path, 'wb') as f: 38 | f.write(obj.getvalue()) 39 | 40 | def convert_model(self, 
json_path: str, 41 | model_path: str, 42 | write_down: Union[bool, str] = None, 43 | providers=None, 44 | ) -> io.BytesIO: 45 | # Load pa from JSON file 46 | if providers is None: 47 | # providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 48 | providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] 49 | if utils.get_device() == "cpu": 50 | providers = ['CPUExecutionProvider'] 51 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 52 | hps = utils.get_hparams_from_file(json_path) 53 | 54 | # Get symbols and initialize synthesizer model 55 | symbols = hps.symbols if "symbols" in hps else [] 56 | net_g = onnx_infer.SynthesizerTrn( 57 | len(symbols), 58 | hps.data.filter_length // 2 + 1, 59 | hps.train.segment_size // hps.data.hop_length, 60 | n_speakers=hps.data.n_speakers, 61 | **hps.model) 62 | 63 | # Load model checkpoint and set to evaluation mode 64 | _ = utils.load_checkpoint(model_path, net_g, None) 65 | net_g.forward = net_g.export_forward 66 | _ = net_g.eval() 67 | scales = torch.FloatTensor([0.667, 1.0, 0.8]) 68 | # make triton dynamic shape happy 69 | scales = scales.unsqueeze(0) 70 | 71 | onnx_model = io.BytesIO() 72 | if symbols: 73 | seq = torch.randint(low=0, high=len(symbols), size=(1, 10), dtype=torch.long) 74 | seq_len = torch.IntTensor([seq.size(1)]).long() 75 | sid = torch.IntTensor([0]).long() 76 | else: 77 | hubert = torch.hub.load("bshall/hubert:main", "hubert_soft", trust_repo=True) 78 | audio16000, sampling_rate = librosa.load("sample.wav", sr=16000, mono=True) 79 | seq = hubert.units(torch.FloatTensor(audio16000).unsqueeze(0).unsqueeze(0).to("cpu")) 80 | # seq = torch.randint(low=0, high=1, size=(1, 10), dtype=torch.long) 81 | seq_len = torch.IntTensor([seq.size(1)]).long() 82 | sid = torch.IntTensor([0]).long() 83 | # seq = hubert.units(torch.FloatTensor("").unsqueeze(0).unsqueeze(0)) 84 | # seq_len = torch.IntTensor([seq.size(1)]).long() 85 | dummy_input = (seq, seq_len, scales, sid) 86 | torch.onnx.export(model=net_g, 87 | 
args=dummy_input, 88 | f=onnx_model, 89 | input_names=['input', 'input_lengths', 'scales', 'sid'], 90 | output_names=['output'], 91 | dynamic_axes={ 92 | 'input': { 93 | 0: 'batch', 94 | 1: 'phonemes' 95 | }, 96 | 'input_lengths': { 97 | 0: 'batch' 98 | }, 99 | 'scales': { 100 | 0: 'batch' 101 | }, 102 | 'sid': { 103 | 0: 'batch' 104 | }, 105 | 'output': { 106 | 0: 'batch', 107 | 1: 'audio', 108 | 2: 'audio_length' 109 | } 110 | }, 111 | opset_version=13, 112 | verbose=False) 113 | 114 | # Verify onnx precision 115 | torch_output = net_g(seq, seq_len, scales, sid) 116 | ort_inputs = { 117 | 'input': to_numpy(seq), 118 | 'input_lengths': to_numpy(seq_len), 119 | 'scales': to_numpy(scales), 120 | 'sid': to_numpy(sid), 121 | } 122 | if not symbols: 123 | # TODO 检查模型结构,似乎无法正常导出 Hubert 模型 124 | ort_inputs.pop("sid") 125 | onnx_output = RunONNX(model=onnx_model, providers=providers).run(model_input=ort_inputs) 126 | # Convert PyTorch model to ONNX format 127 | if write_down: 128 | if type(write_down) == str: 129 | with open(write_down, 'wb') as f: 130 | f.write(onnx_model.getvalue()) 131 | else: 132 | self.write_out(model_path, onnx_model) 133 | # Release memory by deleting PyTorch model 134 | del net_g 135 | return onnx_model 136 | 137 | def warp_pth(self, model_config_path: str, model_path: str = None, return_bytes: bool = False) -> Union[bytes, str]: 138 | import pathlib 139 | model_config_path = pathlib.Path(model_config_path) 140 | if model_config_path.suffix != ".json": 141 | raise ValueError("The model config path must end with .json") 142 | if not model_config_path.exists(): 143 | raise ValueError("The model config path does not exist") 144 | # ONNX 145 | onnx_model_path = model_config_path.parent / f'{model_config_path.stem}.onnx' 146 | if model_path: 147 | model_path = pathlib.Path(model_path) 148 | if model_path.suffix == ".onnx" and model_path.exists(): 149 | # 如果是 .onnx 则直接返回 150 | onnx_model_path = model_path 151 | 152 | # PTH 153 | pth_model_path = 
model_config_path.parent / f'{model_config_path.stem}.pth' 154 | # 去掉 .json 155 | if pathlib.Path(onnx_model_path).exists(): 156 | if return_bytes: 157 | with open(onnx_model_path, 'rb') as f: 158 | return f.read() 159 | return str(onnx_model_path) 160 | if pathlib.Path(pth_model_path).exists(): 161 | onnx_model_byte = self.convert_model(json_path=str(model_config_path), model_path=str(pth_model_path), 162 | write_down=True) 163 | if return_bytes: 164 | return onnx_model_byte.getvalue() 165 | return str(onnx_model_path) 166 | if True: 167 | raise ValueError("The model files do not exist") 168 | 169 | 170 | if __name__ == "__main__": 171 | model = VitsExtractor().warp_pth( 172 | model_config_path="model/1374_epochs.json", 173 | model_path="model/1374_epochs.pth", 174 | return_bytes=True 175 | ) 176 | # 测试类型 177 | print(type(model)) 178 | # 导入 onnxruntime 库测试是否可以初始化运行时 179 | import onnxruntime as ort 180 | 181 | print("onnxruntime version:", ort.__version__) 182 | _model = ort.InferenceSession(model) 183 | _model.get_outputs() 184 | del _model 185 | -------------------------------------------------------------------------------- /hubert_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from typing import Optional, Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 9 | 10 | 11 | class Hubert(nn.Module): 12 | def __init__(self, num_label_embeddings: int = 100, mask: bool = True): 13 | super().__init__() 14 | self._mask = mask 15 | self.feature_extractor = FeatureExtractor() 16 | self.feature_projection = FeatureProjection() 17 | self.positional_embedding = PositionalConvEmbedding() 18 | self.norm = nn.LayerNorm(768) 19 | self.dropout = nn.Dropout(0.1) 20 | self.encoder = TransformerEncoder( 21 | nn.TransformerEncoderLayer( 22 | 768, 12, 3072, activation="gelu", 
batch_first=True 23 | ), 24 | 12, 25 | ) 26 | self.proj = nn.Linear(768, 256) 27 | 28 | self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_()) 29 | self.label_embedding = nn.Embedding(num_label_embeddings, 256) 30 | 31 | def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 32 | mask = None 33 | if self.training and self._mask: 34 | mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2) 35 | x[mask] = self.masked_spec_embed.to(x.dtype) 36 | return x, mask 37 | 38 | def encode( 39 | self, x: torch.Tensor, layer: Optional[int] = None 40 | ) -> Tuple[torch.Tensor, torch.Tensor]: 41 | x = self.feature_extractor(x) 42 | x = self.feature_projection(x.transpose(1, 2)) 43 | x, mask = self.mask(x) 44 | x = x + self.positional_embedding(x) 45 | x = self.dropout(self.norm(x)) 46 | x = self.encoder(x, output_layer=layer) 47 | return x, mask 48 | 49 | def logits(self, x: torch.Tensor) -> torch.Tensor: 50 | logits = torch.cosine_similarity( 51 | x.unsqueeze(2), 52 | self.label_embedding.weight.unsqueeze(0).unsqueeze(0), 53 | dim=-1, 54 | ) 55 | return logits / 0.1 56 | 57 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 58 | x, mask = self.encode(x) 59 | x = self.proj(x) 60 | logits = self.logits(x) 61 | return logits, mask 62 | 63 | 64 | class HubertSoft(Hubert): 65 | def __init__(self): 66 | super().__init__() 67 | 68 | @torch.inference_mode() 69 | def units(self, wav: torch.Tensor) -> torch.Tensor: 70 | wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2)) 71 | x, _ = self.encode(wav) 72 | return self.proj(x) 73 | 74 | 75 | class FeatureExtractor(nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False) 79 | self.norm0 = nn.GroupNorm(512, 512) 80 | self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False) 81 | self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False) 82 | self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False) 83 | self.conv4 = nn.Conv1d(512, 512, 
3, 2, bias=False) 84 | self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False) 85 | self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False) 86 | 87 | def forward(self, x: torch.Tensor) -> torch.Tensor: 88 | x = F.gelu(self.norm0(self.conv0(x))) 89 | x = F.gelu(self.conv1(x)) 90 | x = F.gelu(self.conv2(x)) 91 | x = F.gelu(self.conv3(x)) 92 | x = F.gelu(self.conv4(x)) 93 | x = F.gelu(self.conv5(x)) 94 | x = F.gelu(self.conv6(x)) 95 | return x 96 | 97 | 98 | class FeatureProjection(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | self.norm = nn.LayerNorm(512) 102 | self.projection = nn.Linear(512, 768) 103 | self.dropout = nn.Dropout(0.1) 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | x = self.norm(x) 107 | x = self.projection(x) 108 | x = self.dropout(x) 109 | return x 110 | 111 | 112 | class PositionalConvEmbedding(nn.Module): 113 | def __init__(self): 114 | super().__init__() 115 | self.conv = nn.Conv1d( 116 | 768, 117 | 768, 118 | kernel_size=128, 119 | padding=128 // 2, 120 | groups=16, 121 | ) 122 | self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) 123 | 124 | def forward(self, x: torch.Tensor) -> torch.Tensor: 125 | x = self.conv(x.transpose(1, 2)) 126 | x = F.gelu(x[:, :, :-1]) 127 | return x.transpose(1, 2) 128 | 129 | 130 | class TransformerEncoder(nn.Module): 131 | def __init__( 132 | self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int 133 | ) -> None: 134 | super(TransformerEncoder, self).__init__() 135 | self.layers = nn.ModuleList( 136 | [copy.deepcopy(encoder_layer) for _ in range(num_layers)] 137 | ) 138 | self.num_layers = num_layers 139 | 140 | def forward( 141 | self, 142 | src: torch.Tensor, 143 | mask: torch.Tensor = None, 144 | src_key_padding_mask: torch.Tensor = None, 145 | output_layer: Optional[int] = None, 146 | ) -> torch.Tensor: 147 | output = src 148 | for layer in self.layers[:output_layer]: 149 | output = layer( 150 | output, src_mask=mask, 
src_key_padding_mask=src_key_padding_mask 151 | ) 152 | return output 153 | 154 | 155 | def _compute_mask( 156 | shape: Tuple[int, int], 157 | mask_prob: float, 158 | mask_length: int, 159 | device: torch.device, 160 | min_masks: int = 0, 161 | ) -> torch.Tensor: 162 | batch_size, sequence_length = shape 163 | 164 | if mask_length < 1: 165 | raise ValueError("`mask_length` has to be bigger than 0.") 166 | 167 | if mask_length > sequence_length: 168 | raise ValueError( 169 | f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" 170 | ) 171 | 172 | # compute number of masked spans in batch 173 | num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random()) 174 | num_masked_spans = max(num_masked_spans, min_masks) 175 | 176 | # make sure num masked indices <= sequence_length 177 | if num_masked_spans * mask_length > sequence_length: 178 | num_masked_spans = sequence_length // mask_length 179 | 180 | # SpecAugment mask to fill 181 | mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) 182 | 183 | # uniform distribution to sample from, make sure that offset samples are < sequence_length 184 | uniform_dist = torch.ones( 185 | (batch_size, sequence_length - (mask_length - 1)), device=device 186 | ) 187 | 188 | # get random indices to mask 189 | mask_indices = torch.multinomial(uniform_dist, num_masked_spans) 190 | 191 | # expand masked indices to masked spans 192 | mask_indices = ( 193 | mask_indices.unsqueeze(dim=-1) 194 | .expand((batch_size, num_masked_spans, mask_length)) 195 | .reshape(batch_size, num_masked_spans * mask_length) 196 | ) 197 | offsets = ( 198 | torch.arange(mask_length, device=device)[None, None, :] 199 | .expand((batch_size, num_masked_spans, mask_length)) 200 | .reshape(batch_size, num_masked_spans * mask_length) 201 | ) 202 | mask_idxs = mask_indices + offsets 203 | 204 | # scatter indices to mask 205 | mask = 
mask.scatter(1, mask_idxs, True) 206 | 207 | return mask 208 | 209 | 210 | def hubert_soft( 211 | path: str 212 | ) -> HubertSoft: 213 | r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`. 214 | Args: 215 | path (str): path of a pretrained model 216 | """ 217 | hubert = HubertSoft() 218 | checkpoint = torch.load(path) 219 | consume_prefix_in_state_dict_if_present(checkpoint, "module.") 220 | hubert.load_state_dict(checkpoint) 221 | hubert.eval() 222 | return hubert 223 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | model/*.json 6 | model/*.pth 7 | model/*.bak 8 | model/*.onnx 9 | .idea 10 | # User-specific files 11 | *.rsuser 12 | *.suo 13 | *.user 14 | *.userosscache 15 | *.sln.docstates 16 | 17 | # User-specific files (MonoDevelop/Xamarin Studio) 18 | *.userprefs 19 | 20 | # Mono auto generated files 21 | mono_crash.* 22 | 23 | # Build results 24 | [Dd]ebug/ 25 | [Dd]ebugPublic/ 26 | [Rr]elease/ 27 | [Rr]eleases/ 28 | x64/ 29 | x86/ 30 | [Ww][Ii][Nn]32/ 31 | [Aa][Rr][Mm]/ 32 | [Aa][Rr][Mm]64/ 33 | bld/ 34 | [Bb]in/ 35 | [Oo]bj/ 36 | [Oo]ut/ 37 | [Ll]og/ 38 | [Ll]ogs/ 39 | 40 | # Visual Studio 2015/2017 cache/options directory 41 | .vs/ 42 | # Uncomment if you have tasks that create the project's static files in wwwroot 43 | #wwwroot/ 44 | 45 | # Visual Studio 2017 auto generated files 46 | Generated\ Files/ 47 | 48 | # MSTest test Results 49 | [Tt]est[Rr]esult*/ 50 | [Bb]uild[Ll]og.* 51 | 52 | # NUnit 53 | *.VisualState.xml 54 | TestResult.xml 55 | nunit-*.xml 56 | 57 | # Build Results of an ATL Project 58 | [Dd]ebugPS/ 59 | [Rr]eleasePS/ 60 | dlldata.c 61 | 62 | 
# Benchmark Results 63 | BenchmarkDotNet.Artifacts/ 64 | 65 | # .NET Core 66 | project.lock.json 67 | project.fragment.lock.json 68 | artifacts/ 69 | 70 | # ASP.NET Scaffolding 71 | ScaffoldingReadMe.txt 72 | 73 | # StyleCop 74 | StyleCopReport.xml 75 | 76 | # Files built by Visual Studio 77 | *_i.c 78 | *_p.c 79 | *_h.h 80 | *.ilk 81 | *.meta 82 | *.obj 83 | *.iobj 84 | *.pch 85 | *.pdb 86 | *.ipdb 87 | *.pgc 88 | *.pgd 89 | *.rsp 90 | *.sbr 91 | *.tlb 92 | *.tli 93 | *.tlh 94 | *.tmp 95 | *.tmp_proj 96 | *_wpftmp.csproj 97 | *.log 98 | *.vspscc 99 | *.vssscc 100 | .builds 101 | *.pidb 102 | *.svclog 103 | *.scc 104 | 105 | # Chutzpah Test files 106 | _Chutzpah* 107 | 108 | # Visual C++ cache files 109 | ipch/ 110 | *.aps 111 | *.ncb 112 | *.opendb 113 | *.opensdf 114 | *.sdf 115 | *.cachefile 116 | *.VC.db 117 | *.VC.VC.opendb 118 | 119 | # Visual Studio profiler 120 | *.psess 121 | *.vsp 122 | *.vspx 123 | *.sap 124 | 125 | # Visual Studio Trace Files 126 | *.e2e 127 | 128 | # TFS 2012 Local Workspace 129 | $tf/ 130 | 131 | # Guidance Automation Toolkit 132 | *.gpState 133 | 134 | # ReSharper is a .NET coding add-in 135 | _ReSharper*/ 136 | *.[Rr]e[Ss]harper 137 | *.DotSettings.user 138 | 139 | # TeamCity is a build add-in 140 | _TeamCity* 141 | 142 | # DotCover is a Code Coverage Tool 143 | *.dotCover 144 | 145 | # AxoCover is a Code Coverage Tool 146 | .axoCover/* 147 | !.axoCover/settings.json 148 | 149 | # Coverlet is a free, cross platform Code Coverage Tool 150 | coverage*.json 151 | coverage*.xml 152 | coverage*.info 153 | 154 | # Visual Studio code coverage results 155 | *.coverage 156 | *.coveragexml 157 | 158 | # NCrunch 159 | _NCrunch_* 160 | .*crunch*.local.xml 161 | nCrunchTemp_* 162 | 163 | # MightyMoose 164 | *.mm.* 165 | AutoTest.Net/ 166 | 167 | # Web workbench (sass) 168 | .sass-cache/ 169 | 170 | # Installshield output folder 171 | [Ee]xpress/ 172 | 173 | # DocProject is a documentation generator add-in 174 | DocProject/buildhelp/ 175 | 
DocProject/Help/*.HxT 176 | DocProject/Help/*.HxC 177 | DocProject/Help/*.hhc 178 | DocProject/Help/*.hhk 179 | DocProject/Help/*.hhp 180 | DocProject/Help/Html2 181 | DocProject/Help/html 182 | 183 | # Click-Once directory 184 | publish/ 185 | 186 | # Publish Web Output 187 | *.[Pp]ublish.xml 188 | *.azurePubxml 189 | # Note: Comment the next line if you want to checkin your web deploy settings, 190 | # but database connection strings (with potential passwords) will be unencrypted 191 | *.pubxml 192 | *.publishproj 193 | 194 | # Microsoft Azure Web App publish settings. Comment the next line if you want to 195 | # checkin your Azure Web App publish settings, but sensitive information contained 196 | # in these scripts will be unencrypted 197 | PublishScripts/ 198 | 199 | # NuGet Packages 200 | *.nupkg 201 | # NuGet Symbol Packages 202 | *.snupkg 203 | # The packages folder can be ignored because of Package Restore 204 | **/[Pp]ackages/* 205 | # except build/, which is used as an MSBuild target. 
206 | !**/[Pp]ackages/build/ 207 | # Uncomment if necessary however generally it will be regenerated when needed 208 | #!**/[Pp]ackages/repositories.models 209 | # NuGet v3's project.json files produces more ignorable files 210 | *.nuget.props 211 | *.nuget.targets 212 | 213 | # Microsoft Azure Build Output 214 | csx/ 215 | *.build.csdef 216 | 217 | # Microsoft Azure Emulator 218 | ecf/ 219 | rcf/ 220 | 221 | # Windows Store app package directories and files 222 | AppPackages/ 223 | BundleArtifacts/ 224 | Package.StoreAssociation.xml 225 | _pkginfo.txt 226 | *.appx 227 | *.appxbundle 228 | *.appxupload 229 | 230 | # Visual Studio cache files 231 | # files ending in .cache can be ignored 232 | *.[Cc]ache 233 | # but keep track of directories ending in .cache 234 | !?*.[Cc]ache/ 235 | 236 | # Others 237 | ClientBin/ 238 | ~$* 239 | *~ 240 | *.dbmdl 241 | *.dbproj.schemaview 242 | *.jfm 243 | *.pfx 244 | *.publishsettings 245 | orleans.codegen.cs 246 | 247 | # Including strong name files can present a security risk 248 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 249 | #*.snk 250 | 251 | # Since there are multiple workflows, uncomment next line to ignore bower_components 252 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 253 | #bower_components/ 254 | 255 | # RIA/Silverlight projects 256 | Generated_Code/ 257 | 258 | # Backup & report files from converting an old project file 259 | # to a newer Visual Studio version. 
Backup files are not needed, 260 | # because we have git ;-) 261 | _UpgradeReport_Files/ 262 | Backup*/ 263 | UpgradeLog*.XML 264 | UpgradeLog*.htm 265 | ServiceFabricBackup/ 266 | *.rptproj.bak 267 | 268 | # SQL Server files 269 | *.mdf 270 | *.ldf 271 | *.ndf 272 | 273 | # Business Intelligence projects 274 | *.rdl.data 275 | *.bim.layout 276 | *.bim_*.settings 277 | *.rptproj.rsuser 278 | *- [Bb]ackup.rdl 279 | *- [Bb]ackup ([0-9]).rdl 280 | *- [Bb]ackup ([0-9][0-9]).rdl 281 | 282 | # Microsoft Fakes 283 | FakesAssemblies/ 284 | 285 | # GhostDoc plugin setting file 286 | *.GhostDoc.xml 287 | 288 | # Node.js Tools for Visual Studio 289 | .ntvs_analysis.dat 290 | node_modules/ 291 | 292 | # Visual Studio 6 build log 293 | *.plg 294 | 295 | # Visual Studio 6 workspace options file 296 | *.opt 297 | 298 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 299 | *.vbw 300 | 301 | # Visual Studio LightSwitch build output 302 | **/*.HTMLClient/GeneratedArtifacts 303 | **/*.DesktopClient/GeneratedArtifacts 304 | **/*.DesktopClient/ModelManifest.xml 305 | **/*.Server/GeneratedArtifacts 306 | **/*.Server/ModelManifest.xml 307 | _Pvt_Extensions 308 | 309 | # Paket dependency manager 310 | .paket/paket.exe 311 | paket-files/ 312 | 313 | # FAKE - F# Make 314 | .fake/ 315 | 316 | # CodeRush personal settings 317 | .cr/personal 318 | 319 | # Python Tools for Visual Studio (PTVS) 320 | __pycache__/ 321 | *.pyc 322 | 323 | # Cake - Uncomment if you are using it 324 | # tools/** 325 | # !tools/packages.models 326 | 327 | # Tabs Studio 328 | *.tss 329 | 330 | # Telerik's JustMock configuration file 331 | *.jmconfig 332 | 333 | # BizTalk build output 334 | *.btp.cs 335 | *.btm.cs 336 | *.odx.cs 337 | *.xsd.cs 338 | 339 | # OpenCover UI analysis results 340 | OpenCover/ 341 | 342 | # Azure Stream Analytics local run output 343 | ASALocalRun/ 344 | 345 | # MSBuild Binary and Structured Log 346 | *.binlog 347 | 348 | # NVidia Nsight GPU debugger 
configuration file 349 | *.nvuser 350 | 351 | # MFractors (Xamarin productivity tool) working folder 352 | .mfractor/ 353 | 354 | # Local History for Visual Studio 355 | .localhistory/ 356 | 357 | # BeatPulse healthcheck temp database 358 | healthchecksdb 359 | 360 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 361 | MigrationBackup/ 362 | 363 | # Ionide (cross platform F# VS Code tools) working folder 364 | .ionide/ 365 | 366 | # Fody - auto-generated XML schema 367 | FodyWeavers.xsd 368 | 369 | # build 370 | build 371 | onnx_infer/monotonic_align/core.c 372 | *.o 373 | *.so 374 | *.dll 375 | 376 | # data 377 | /config.toml 378 | /*.pth 379 | *.wav 380 | /monotonic_align/monotonic_align 381 | /resources 382 | /MoeGoe.spec 383 | /dist/MoeGoe 384 | /dist 385 | 386 | # MacOS 387 | .DS_Store 388 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/5 下午11:25 3 | # @Author : sudoskys 4 | # @File : server.py 5 | # @Software: PyCharm 6 | # -*- coding: utf-8 -*- 7 | # @Time : 12/19/22 9:09 PM 8 | # @FileName: main.py 9 | # @Software: PyCharm 10 | # @Github :sudoskys 11 | import io 12 | import pathlib 13 | import shutil 14 | from typing import Annotated 15 | 16 | import psutil 17 | from fastapi import FastAPI, HTTPException, UploadFile, File, Form 18 | from fastapi.responses import StreamingResponse 19 | from loguru import logger 20 | 21 | import utils 22 | from component.nlp_utils.detect import DetectSentence 23 | from component.warp import Parse 24 | # from celery_worker import tts_order 25 | from event import TtsGenerate, TtsSchema, InferTask, ConvertTask, ConvertSchema 26 | 27 | # 日志机器 28 | logger.add(sink='run.log', 29 | format="{time} - {level} - {message}", 30 | level="INFO", 31 | rotation="500 MB", 32 | enqueue=True) 33 | 34 | app = FastAPI() 35 | 36 | _Model_list = {} 
# FastAPI: list the loaded models and their info
@app.get("/model/list")
def tts_model(show_speaker: bool = False, show_ms_config: bool = False):
    """Return the id, info and class of every loaded model.

    Args:
        show_speaker: also include each model's speaker list and speaker count.
        show_ms_config: also include each model's raw hps_ms_config.

    Returns:
        dict: ``{"code": 0, "message": "success", "data": [...]}``
    """
    global _Model_list
    _data = []
    # Build one info entry per loaded model.
    for _model_name, _model in _Model_list.items():
        _model: TtsGenerate
        _item = {
            "model_id": _model_name,
            "model_info": _model.get_model_info(),
            "model_class": _model.model_type.value,
        }
        if show_speaker:
            _item["speaker"] = _model.get_speaker_list()
            # BUG FIX: the original read `_item["speaker_num"]: _model.n_speakers`,
            # which is a bare annotation statement — it never assigned the key,
            # so "speaker_num" was silently missing from the response.
            _item["speaker_num"] = _model.n_speakers
        if show_ms_config:
            _item["ms_config"] = _model.hps_ms_config
        _data.append(_item)
    return {"code": 0, "message": "success", "data": _data}
{"code": -1, "message": "Not Found!", "data": {}} 94 | return {"code": 0, "message": "success", "data": server_build.hps_ms_config} 95 | 96 | 97 | # 处理传入文本为Vits格式包装 98 | @app.post("/tts/parse") 99 | def tts_parse(text: str, strip: bool = False, 100 | merge_same: bool = False, cell_limit: int = 140, 101 | filter_space: bool = True): 102 | _result = {} 103 | try: 104 | parse = Parse() 105 | _merge = parse.create_cell(sentence=text, merge_same=merge_same, cell_limit=cell_limit, 106 | filter_space=filter_space) 107 | _result["detect_code"] = DetectSentence.detect_code(text) 108 | _result["parse"] = _merge 109 | _result["raw_text"] = text 110 | _result["result"] = parse.build_sentence(_merge, strip=strip) 111 | except Exception as e: 112 | logger.exception(e) 113 | # raise HTTPException(status_code=500, detail="Error When Process Text!") 114 | return {"code": -1, "message": "Error!", "data": {}} 115 | return {"code": 0, "message": "success", "data": _result} 116 | 117 | 118 | @app.post("/svc/convert") 119 | async def svc_convert(model_id: Annotated[str, Form()], 120 | noise_scale: Annotated[float, Form()] = 0.667, 121 | length_scale: Annotated[float, Form()] = 1.4, 122 | audio_type: Annotated[str, Form()] = "wav", 123 | speaker_id: Annotated[int, Form()] = 0, file: UploadFile = File(...)): 124 | global _Model_list 125 | tts_req = ConvertSchema() 126 | tts_req.model_id = model_id 127 | tts_req.noise_scale = noise_scale 128 | tts_req.length_scale = length_scale 129 | tts_req.audio_type = audio_type 130 | tts_req.speaker_id = speaker_id 131 | server_build = _Model_list.get(tts_req.model_id, None) 132 | server_build: TtsGenerate 133 | # 检查模型是否存在 134 | if not server_build: 135 | raise HTTPException(status_code=404, detail="Model Not Found!") 136 | if not server_build.hubert: 137 | raise HTTPException(status_code=404, detail="Hubert Model Not Found!") 138 | # 检查请求合法性 139 | if not file: 140 | raise HTTPException(status_code=400, detail="Text is Empty!") 141 | # if 
tts_req.audio_type not in TtsSchema().audio_type: 142 | # raise HTTPException(status_code=400, detail="Audio Type is Invalid!") 143 | conv = io.BytesIO() 144 | with open(file.filename, "wb") as buffer: 145 | shutil.copyfileobj(conv, buffer) 146 | _task = ConvertTask(infer_sample=conv, 147 | speaker_ids=tts_req.speaker_id, 148 | audio_type=tts_req.audio_type, 149 | length_scale=tts_req.length_scale, 150 | noise_scale=tts_req.noise_scale, 151 | ) 152 | # tts_order.delay(_task) 153 | _result = server_build.infer_task(_task) 154 | return StreamingResponse(_result, media_type="application/octet-stream") 155 | 156 | 157 | @app.post("/tts/generate") 158 | async def tts(tts_req: TtsSchema, auto_parse: bool = False): 159 | global _Model_list 160 | server_build = _Model_list.get(tts_req.model_id, None) 161 | server_build: TtsGenerate 162 | # 检查模型是否存在 163 | if not server_build: 164 | raise HTTPException(status_code=404, detail="Model Not Found!") 165 | 166 | # 检查模型 167 | # if server_build.hubert: 168 | # raise HTTPException(status_code=400, detail="self.n_symbols==0 and Hubert Model Be Found!") 169 | 170 | # 检查请求合法性 171 | if not tts_req.text: 172 | raise HTTPException(status_code=400, detail="Text is Empty!") 173 | # if tts_req.audio_type not in TtsSchema().audio_type: 174 | # raise HTTPException(status_code=400, detail="Audio Type is Invalid!") 175 | if auto_parse: 176 | _task = server_build.create_vits_task(c_text=tts_req.text, 177 | speaker_ids=tts_req.speaker_id, 178 | audio_type=tts_req.audio_type, 179 | length_scale=tts_req.length_scale, 180 | noise_scale=tts_req.noise_scale, 181 | noise_scale_w=tts_req.noise_scale_w 182 | ) 183 | else: 184 | _task = [InferTask(infer_sample=tts_req.text, 185 | speaker_ids=tts_req.speaker_id, 186 | audio_type=tts_req.audio_type, 187 | length_scale=tts_req.length_scale, 188 | noise_scale=tts_req.noise_scale, 189 | noise_scale_w=tts_req.noise_scale_w)] 190 | # 检查 speaker_id 合法性 191 | if tts_req.speaker_id >= server_build.n_speakers: 192 | 
raise HTTPException(status_code=400, detail="Speaker ID is Invalid!") 193 | try: 194 | server_build.load_prefer = tts_req.load_prefer 195 | _result = server_build.infer_task_bat( 196 | task_list=_task 197 | ) 198 | except Exception as e: 199 | logger.exception(e) 200 | raise HTTPException(status_code=500, detail="Error When Generate Voice!") 201 | else: 202 | _result.seek(0) 203 | return StreamingResponse(_result, media_type="application/octet-stream") 204 | -------------------------------------------------------------------------------- /text/mandarin.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | 5 | import cn2an 6 | import jieba 7 | from pypinyin import lazy_pinyin, BOPOMOFO 8 | 9 | logging.getLogger('jieba').setLevel(logging.WARNING) 10 | # 判定当前模块文件目录词库是否存在 11 | if os.path.exists(os.path.join(os.path.dirname(__file__), 'jieba/dict.txt')): 12 | jieba.set_dictionary(os.path.join(os.path.dirname(__file__), 'jieba/dict.txt')) 13 | else: 14 | raise FileNotFoundError('jieba/dict.txt not found') 15 | jieba.initialize() 16 | 17 | # List of (Latin alphabet, bopomofo) pairs: 18 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 19 | ('a', 'ㄟˉ'), 20 | ('b', 'ㄅㄧˋ'), 21 | ('c', 'ㄙㄧˉ'), 22 | ('d', 'ㄉㄧˋ'), 23 | ('e', 'ㄧˋ'), 24 | ('f', 'ㄝˊㄈㄨˋ'), 25 | ('g', 'ㄐㄧˋ'), 26 | ('h', 'ㄝˇㄑㄩˋ'), 27 | ('i', 'ㄞˋ'), 28 | ('j', 'ㄐㄟˋ'), 29 | ('k', 'ㄎㄟˋ'), 30 | ('l', 'ㄝˊㄛˋ'), 31 | ('m', 'ㄝˊㄇㄨˋ'), 32 | ('n', 'ㄣˉ'), 33 | ('o', 'ㄡˉ'), 34 | ('p', 'ㄆㄧˉ'), 35 | ('q', 'ㄎㄧㄡˉ'), 36 | ('r', 'ㄚˋ'), 37 | ('s', 'ㄝˊㄙˋ'), 38 | ('t', 'ㄊㄧˋ'), 39 | ('u', 'ㄧㄡˉ'), 40 | ('v', 'ㄨㄧˉ'), 41 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 42 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 43 | ('y', 'ㄨㄞˋ'), 44 | ('z', 'ㄗㄟˋ') 45 | ]] 46 | 47 | # List of (bopomofo, romaji) pairs: 48 | _bopomofo_to_romaji = [(re.compile('%s' % x[0]), x[1]) for x in [ 49 | ('ㄅㄛ', 'p⁼wo'), 50 | ('ㄆㄛ', 'pʰwo'), 51 | ('ㄇㄛ', 'mwo'), 52 | ('ㄈㄛ', 'fwo'), 53 | ('ㄅ', 'p⁼'), 54 | 
('ㄆ', 'pʰ'), 55 | ('ㄇ', 'm'), 56 | ('ㄈ', 'f'), 57 | ('ㄉ', 't⁼'), 58 | ('ㄊ', 'tʰ'), 59 | ('ㄋ', 'n'), 60 | ('ㄌ', 'l'), 61 | ('ㄍ', 'k⁼'), 62 | ('ㄎ', 'kʰ'), 63 | ('ㄏ', 'h'), 64 | ('ㄐ', 'ʧ⁼'), 65 | ('ㄑ', 'ʧʰ'), 66 | ('ㄒ', 'ʃ'), 67 | ('ㄓ', 'ʦ`⁼'), 68 | ('ㄔ', 'ʦ`ʰ'), 69 | ('ㄕ', 's`'), 70 | ('ㄖ', 'ɹ`'), 71 | ('ㄗ', 'ʦ⁼'), 72 | ('ㄘ', 'ʦʰ'), 73 | ('ㄙ', 's'), 74 | ('ㄚ', 'a'), 75 | ('ㄛ', 'o'), 76 | ('ㄜ', 'ə'), 77 | ('ㄝ', 'e'), 78 | ('ㄞ', 'ai'), 79 | ('ㄟ', 'ei'), 80 | ('ㄠ', 'au'), 81 | ('ㄡ', 'ou'), 82 | ('ㄧㄢ', 'yeNN'), 83 | ('ㄢ', 'aNN'), 84 | ('ㄧㄣ', 'iNN'), 85 | ('ㄣ', 'əNN'), 86 | ('ㄤ', 'aNg'), 87 | ('ㄧㄥ', 'iNg'), 88 | ('ㄨㄥ', 'uNg'), 89 | ('ㄩㄥ', 'yuNg'), 90 | ('ㄥ', 'əNg'), 91 | ('ㄦ', 'əɻ'), 92 | ('ㄧ', 'i'), 93 | ('ㄨ', 'u'), 94 | ('ㄩ', 'ɥ'), 95 | ('ˉ', '→'), 96 | ('ˊ', '↑'), 97 | ('ˇ', '↓↑'), 98 | ('ˋ', '↓'), 99 | ('˙', ''), 100 | (',', ','), 101 | ('。', '.'), 102 | ('!', '!'), 103 | ('?', '?'), 104 | ('—', '-') 105 | ]] 106 | 107 | # List of (romaji, ipa) pairs: 108 | _romaji_to_ipa = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 109 | ('ʃy', 'ʃ'), 110 | ('ʧʰy', 'ʧʰ'), 111 | ('ʧ⁼y', 'ʧ⁼'), 112 | ('NN', 'n'), 113 | ('Ng', 'ŋ'), 114 | ('y', 'j'), 115 | ('h', 'x') 116 | ]] 117 | 118 | # List of (bopomofo, ipa) pairs: 119 | _bopomofo_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 120 | ('ㄅㄛ', 'p⁼wo'), 121 | ('ㄆㄛ', 'pʰwo'), 122 | ('ㄇㄛ', 'mwo'), 123 | ('ㄈㄛ', 'fwo'), 124 | ('ㄅ', 'p⁼'), 125 | ('ㄆ', 'pʰ'), 126 | ('ㄇ', 'm'), 127 | ('ㄈ', 'f'), 128 | ('ㄉ', 't⁼'), 129 | ('ㄊ', 'tʰ'), 130 | ('ㄋ', 'n'), 131 | ('ㄌ', 'l'), 132 | ('ㄍ', 'k⁼'), 133 | ('ㄎ', 'kʰ'), 134 | ('ㄏ', 'x'), 135 | ('ㄐ', 'tʃ⁼'), 136 | ('ㄑ', 'tʃʰ'), 137 | ('ㄒ', 'ʃ'), 138 | ('ㄓ', 'ts`⁼'), 139 | ('ㄔ', 'ts`ʰ'), 140 | ('ㄕ', 's`'), 141 | ('ㄖ', 'ɹ`'), 142 | ('ㄗ', 'ts⁼'), 143 | ('ㄘ', 'tsʰ'), 144 | ('ㄙ', 's'), 145 | ('ㄚ', 'a'), 146 | ('ㄛ', 'o'), 147 | ('ㄜ', 'ə'), 148 | ('ㄝ', 'ɛ'), 149 | ('ㄞ', 'aɪ'), 150 | ('ㄟ', 'eɪ'), 151 | ('ㄠ', 'ɑʊ'), 152 | ('ㄡ', 'oʊ'), 153 | ('ㄧㄢ', 'jɛn'), 154 | ('ㄩㄢ', 'ɥæn'), 155 | ('ㄢ', 'an'), 
def chinese_to_bopomofo(text):
    """Convert Chinese text to bopomofo (zhuyin) notation.

    Chinese punctuation 、 ; : is first normalised to ',', the text is
    word-segmented with jieba, and each word containing CJK ideographs is
    transliterated with pypinyin. Non-Chinese words pass through unchanged.

    Args:
        text (str): mixed Chinese/non-Chinese input text.

    Returns:
        str: bopomofo string, Chinese words joined by single spaces.
    """
    text = text.replace('、', ',').replace(';', ',').replace(':', ',')
    words = jieba.lcut(text, cut_all=False)
    text = ''
    for word in words:
        # IMPROVEMENT: check for CJK ideographs *before* the pinyin lookup.
        # The original called lazy_pinyin() on every word and then threw the
        # result away for non-Chinese words — pure wasted work.
        if not re.search('[\u4e00-\u9fff]', word):
            text += word
            continue
        bopomofos = lazy_pinyin(word, BOPOMOFO)
        # A syllable ending on a bare bopomofo letter gets an explicit
        # first-tone mark so every syllable carries a tone symbol.
        for i in range(len(bopomofos)):
            bopomofos[i] = re.sub(r'([\u3105-\u3129])$', r'\1ˉ', bopomofos[i])
        if text != '':
            text += ' '
        text += ''.join(bopomofos)
    return text
def chinese_to_ipa2(text):
    """Convert Chinese text to the alternative IPA scheme (ipa2).

    Pipeline: digits -> Chinese numerals, Chinese -> bopomofo, Latin letters
    -> bopomofo, bopomofo -> IPA2, then glide/syllabic-consonant post-fixes.
    """
    # Run each conversion stage in order over the whole string.
    for stage in (number_to_chinese,
                  chinese_to_bopomofo,
                  latin_to_bopomofo,
                  bopomofo_to_ipa2):
        text = stage(text)
    # Post-processing: medial i/u before a vowel become the glides j/w, and
    # retroflex/dental sibilants before a tone mark (or word end) gain their
    # apical-vowel symbol.
    for pattern, repl in ((r'i([aoe])', r'j\1'),
                          (r'u([aoəe])', r'w\1'),
                          (r'([ʂɹ]ʰ?)([˩˨˧˦˥ ]+|$)', r'\1ʅ\2'),
                          (r'(sʰ?)([˩˨˧˦˥ ]+|$)', r'\1ɿ\2')):
        text = re.sub(pattern, repl, text)
    return text
def rational_quadratic_spline(inputs,
                              unnormalized_widths,
                              unnormalized_heights,
                              unnormalized_derivatives,
                              inverse=False,
                              left=0., right=1., bottom=0., top=1.,
                              min_bin_width=DEFAULT_MIN_BIN_WIDTH,
                              min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
                              min_derivative=DEFAULT_MIN_DERIVATIVE):
    """Monotonic rational-quadratic spline transform (and its inverse).

    Maps ``inputs`` in [left, right] to [bottom, top] through a piecewise
    rational-quadratic function whose bin widths/heights/derivatives are
    parameterised by the three ``unnormalized_*`` tensors (last dim = bins).

    Args:
        inputs: points to transform; every value must lie in [left, right].
        unnormalized_widths / unnormalized_heights: pre-softmax bin sizes,
            shaped like ``inputs`` plus a trailing bin dimension.
        unnormalized_derivatives: pre-softplus knot derivatives
            (num_bins + 1 knots along the last dimension).
        inverse: if True, apply the inverse mapping (height axis -> width axis).
        left, right, bottom, top: interval bounds of the transform.
        min_bin_width, min_bin_height, min_derivative: numeric floors keeping
            bins non-degenerate and derivatives positive.

    Returns:
        (outputs, logabsdet): transformed values and the log |det Jacobian|
        (negated for the inverse branch).

    Raises:
        ValueError: if any input falls outside [left, right], or the bin
            floors are too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError('Input to a transform is not within its domain')

    num_bins = unnormalized_widths.shape[-1]

    # The floors are per-bin; num_bins of them must still fit into the unit
    # interval that softmax distributes over.
    if min_bin_width * num_bins > 1.0:
        raise ValueError('Minimal bin width too large for the number of bins')
    if min_bin_height * num_bins > 1.0:
        raise ValueError('Minimal bin height too large for the number of bins')

    # Bin widths: softmax -> floor at min_bin_width -> cumulative knot
    # positions, padded with a leading 0 and rescaled from [0,1] to
    # [left, right]. End knots are pinned exactly to the interval bounds.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    # Recompute widths from the pinned knots so they sum exactly to the span.
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Knot derivatives: softplus keeps them positive, floored at min_derivative.
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Bin heights: same construction as the widths, on the [bottom, top] axis.
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin each input falls into: by height for the inverse pass,
    # by width for the forward pass. `searchsorted` is the module-local helper.
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    # Gather the per-input bin parameters (knot position, size, slope delta,
    # and the derivatives at the bin's two edge knots).
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Inverse: solve the quadratic a*root^2 + b*root + c = 0 for the
        # normalised position `root` inside the bin.
        a = (((inputs - input_cumheights) * (input_derivatives
                                             + input_derivatives_plus_one
                                             - 2 * input_delta)
              + input_heights * (input_delta - input_derivatives)))
        b = (input_heights * input_derivatives
             - (inputs - input_cumheights) * (input_derivatives
                                              + input_derivatives_plus_one
                                              - 2 * input_delta))
        c = - input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Numerically stable quadratic root: 2c / (-b - sqrt(disc)).
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                     * theta_one_minus_theta)
        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - root).pow(2))
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        # Inverse direction: negate the forward log-determinant.
        return outputs, -logabsdet
    else:
        # Forward: theta is the input's normalised position within its bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2)
                                     + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta)
                                     * theta_one_minus_theta)
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2)
                                                     + 2 * input_delta * theta_one_minus_theta
                                                     + input_derivatives * (1 - theta).pow(2))
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
min_bin_width=DEFAULT_MIN_BIN_WIDTH, 62 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 63 | min_derivative=DEFAULT_MIN_DERIVATIVE): 64 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 65 | outside_interval_mask = ~inside_interval_mask 66 | 67 | outputs = torch.zeros_like(inputs) 68 | logabsdet = torch.zeros_like(inputs) 69 | 70 | if tails == 'linear': 71 | # unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 72 | unnormalized_derivatives_ = torch.zeros( 73 | (1, 1, unnormalized_derivatives.size(2), unnormalized_derivatives.size(3) + 2)) 74 | unnormalized_derivatives_[..., 1:-1] = unnormalized_derivatives 75 | unnormalized_derivatives = unnormalized_derivatives_ 76 | constant = np.log(np.exp(1 - min_derivative) - 1) 77 | unnormalized_derivatives[..., 0] = constant 78 | unnormalized_derivatives[..., -1] = constant 79 | 80 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 81 | logabsdet[outside_interval_mask] = 0 82 | else: 83 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 84 | 85 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 86 | inputs=inputs[inside_interval_mask], 87 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 88 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 89 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 90 | inverse=inverse, 91 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 92 | min_bin_width=min_bin_width, 93 | min_bin_height=min_bin_height, 94 | min_derivative=min_derivative 95 | ) 96 | 97 | return outputs, logabsdet 98 | 99 | 100 | def rational_quadratic_spline(inputs, 101 | unnormalized_widths, 102 | unnormalized_heights, 103 | unnormalized_derivatives, 104 | inverse=False, 105 | left=0., right=1., bottom=0., top=1., 106 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 107 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 108 | 
min_derivative=DEFAULT_MIN_DERIVATIVE): 109 | if torch.min(inputs) < left or torch.max(inputs) > right: 110 | raise ValueError('Input to a transform is not within its domain') 111 | 112 | num_bins = unnormalized_widths.shape[-1] 113 | 114 | if min_bin_width * num_bins > 1.0: 115 | raise ValueError('Minimal bin width too large for the number of bins') 116 | if min_bin_height * num_bins > 1.0: 117 | raise ValueError('Minimal bin height too large for the number of bins') 118 | 119 | widths = F.softmax(unnormalized_widths, dim=-1) 120 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 121 | cumwidths = torch.cumsum(widths, dim=-1) 122 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 123 | cumwidths = (right - left) * cumwidths + left 124 | cumwidths[..., 0] = left 125 | cumwidths[..., -1] = right 126 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 127 | 128 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 129 | 130 | heights = F.softmax(unnormalized_heights, dim=-1) 131 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 132 | cumheights = torch.cumsum(heights, dim=-1) 133 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 134 | cumheights = (top - bottom) * cumheights + bottom 135 | cumheights[..., 0] = bottom 136 | cumheights[..., -1] = top 137 | heights = cumheights[..., 1:] - cumheights[..., :-1] 138 | 139 | if inverse: 140 | bin_idx = searchsorted(cumheights, inputs)[..., None] 141 | else: 142 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 143 | 144 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 145 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 146 | 147 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 148 | delta = heights / widths 149 | input_delta = delta.gather(-1, bin_idx)[..., 0] 150 | 151 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 152 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, 
bin_idx)[..., 0] 153 | 154 | input_heights = heights.gather(-1, bin_idx)[..., 0] 155 | 156 | if inverse: 157 | a = (((inputs - input_cumheights) * (input_derivatives 158 | + input_derivatives_plus_one 159 | - 2 * input_delta) 160 | + input_heights * (input_delta - input_derivatives))) 161 | b = (input_heights * input_derivatives 162 | - (inputs - input_cumheights) * (input_derivatives 163 | + input_derivatives_plus_one 164 | - 2 * input_delta)) 165 | c = - input_delta * (inputs - input_cumheights) 166 | 167 | discriminant = b.pow(2) - 4 * a * c 168 | assert (discriminant >= 0).all() 169 | 170 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 171 | outputs = root * input_bin_widths + input_cumwidths 172 | 173 | theta_one_minus_theta = root * (1 - root) 174 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 175 | * theta_one_minus_theta) 176 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 177 | + 2 * input_delta * theta_one_minus_theta 178 | + input_derivatives * (1 - root).pow(2)) 179 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 180 | 181 | return outputs, -logabsdet 182 | else: 183 | theta = (inputs - input_cumwidths) / input_bin_widths 184 | theta_one_minus_theta = theta * (1 - theta) 185 | 186 | numerator = input_heights * (input_delta * theta.pow(2) 187 | + input_derivatives * theta_one_minus_theta) 188 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 189 | * theta_one_minus_theta) 190 | outputs = input_cumheights + numerator / denominator 191 | 192 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 193 | + 2 * input_delta * theta_one_minus_theta 194 | + input_derivatives * (1 - theta).pow(2)) 195 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 196 | 197 | return outputs, logabsdet 198 | 
-------------------------------------------------------------------------------- /test/onnx_test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 11, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import logging\n", 10 | "logging.getLogger('numba').setLevel(logging.WARNING)\n", 11 | "%matplotlib inline\n", 12 | "import IPython.display as ipd\n", 13 | "import torch\n", 14 | "import commons\n", 15 | "import utils\n", 16 | "import ONNXVITS_infer\n", 17 | "from text import text_to_sequence\n", 18 | "\n", 19 | "def get_text(text, hps):\n", 20 | " text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)\n", 21 | " if hps.data.add_blank:\n", 22 | " text_norm = commons.intersperse(text_norm, 0)\n", 23 | " text_norm = torch.LongTensor(text_norm)\n", 24 | " return text_norm" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 12, 30 | "outputs": [ 31 | { 32 | "name": "stdout", 33 | "output_type": "stream", 34 | "text": [ 35 | "Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple\r\n", 36 | "Requirement already satisfied: matplotlib in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (3.7.1)\r\n", 37 | "Requirement already satisfied: contourpy>=1.0.1 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (1.0.7)\r\n", 38 | "Requirement already satisfied: numpy>=1.20 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (1.23.5)\r\n", 39 | "Requirement already satisfied: pillow>=6.2.0 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (9.4.0)\r\n", 40 | "Requirement already satisfied: fonttools>=4.22.0 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (4.39.2)\r\n", 41 | "Requirement already satisfied: pyparsing>=2.3.1 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages 
(from matplotlib) (3.0.9)\r\n", 42 | "Requirement already satisfied: python-dateutil>=2.7 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (2.8.2)\r\n", 43 | "Requirement already satisfied: packaging>=20.0 in /home/nano/.local/lib/python3.10/site-packages (from matplotlib) (23.0)\r\n", 44 | "Requirement already satisfied: cycler>=0.10 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (0.11.0)\r\n", 45 | "Requirement already satisfied: kiwisolver>=1.0.1 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from matplotlib) (1.4.4)\r\n", 46 | "Requirement already satisfied: six>=1.5 in /home/nano/miniconda3/envs/vits/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\r\n" 47 | ] 48 | } 49 | ], 50 | "source": [ 51 | "!pip install matplotlib" 52 | ], 53 | "metadata": { 54 | "collapsed": false 55 | } 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 13, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "hps = utils.get_hparams_from_file(\"model/1374_epochs.pth.json\")" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 14, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "net_g = ONNXVITS_infer.SynthesizerTrn(\n", 73 | " len(hps.symbols),\n", 74 | " hps.data.filter_length // 2 + 1,\n", 75 | " hps.train.segment_size // hps.data.hop_length,\n", 76 | " n_speakers=hps.data.n_speakers,\n", 77 | " **hps.model)\n", 78 | "_ = net_g.eval()\n", 79 | "\n", 80 | "_ = utils.load_checkpoint(\"model/1374_epochs.pth\", net_g)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 15, 86 | "metadata": {}, 87 | "outputs": [ 88 | { 89 | "ename": "ValueError", 90 | "evalue": "Model requires 3 inputs. 
Input Feed contains 1", 91 | "output_type": "error", 92 | "traceback": [ 93 | "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", 94 | "\u001B[0;31mValueError\u001B[0m Traceback (most recent call last)", 95 | "Cell \u001B[0;32mIn[15], line 7\u001B[0m\n\u001B[1;32m 5\u001B[0m x_tst_lengths \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mLongTensor([stn_tst\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m0\u001B[39m)])\n\u001B[1;32m 6\u001B[0m sid \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mLongTensor([\u001B[38;5;241m0\u001B[39m])\n\u001B[0;32m----> 7\u001B[0m audio \u001B[38;5;241m=\u001B[39m \u001B[43mnet_g\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43minfer\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx_tst\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mx_tst_lengths\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msid\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msid\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnoise_scale\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m.667\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnoise_scale_w\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.8\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mlength_scale\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m1\u001B[39;49m\u001B[43m)\u001B[49m[\u001B[38;5;241m0\u001B[39m][\u001B[38;5;241m0\u001B[39m,\u001B[38;5;241m0\u001B[39m]\u001B[38;5;241m.\u001B[39mdata\u001B[38;5;241m.\u001B[39mcpu()\u001B[38;5;241m.\u001B[39mfloat()\u001B[38;5;241m.\u001B[39mnumpy()\n\u001B[1;32m 8\u001B[0m ipd\u001B[38;5;241m.\u001B[39mdisplay(ipd\u001B[38;5;241m.\u001B[39mAudio(audio, rate\u001B[38;5;241m=\u001B[39mhps\u001B[38;5;241m.\u001B[39mdata\u001B[38;5;241m.\u001B[39msampling_rate, normalize\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m))\n", 96 | "File 
\u001B[0;32m~/Project/PycharmProjects/VitsServer/ONNXVITS_infer.py:74\u001B[0m, in \u001B[0;36mSynthesizerTrn.infer\u001B[0;34m(self, x, x_lengths, sid, noise_scale, length_scale, noise_scale_w, max_len, emotion_embedding)\u001B[0m\n\u001B[1;32m 72\u001B[0m zinput \u001B[38;5;241m=\u001B[39m (torch\u001B[38;5;241m.\u001B[39mrandn(x\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m0\u001B[39m), \u001B[38;5;241m2\u001B[39m, x\u001B[38;5;241m.\u001B[39msize(\u001B[38;5;241m2\u001B[39m))\u001B[38;5;241m.\u001B[39mto(device\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mdevice, dtype\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mdtype) \u001B[38;5;241m*\u001B[39m noise_scale_w)\n\u001B[1;32m 73\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m sid \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m---> 74\u001B[0m g \u001B[38;5;241m=\u001B[39m \u001B[43mrunonnx\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43mONNX_net/dp.onnx\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msid\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43msid\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mnumpy\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 75\u001B[0m g \u001B[38;5;241m=\u001B[39m torch\u001B[38;5;241m.\u001B[39mfrom_numpy(g)\u001B[38;5;241m.\u001B[39munsqueeze(\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 76\u001B[0m logw \u001B[38;5;241m=\u001B[39m runonnx(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mONNX_net/dp.onnx\u001B[39m\u001B[38;5;124m\"\u001B[39m, x\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mnumpy(), x_mask\u001B[38;5;241m=\u001B[39mx_mask\u001B[38;5;241m.\u001B[39mnumpy(), zin\u001B[38;5;241m=\u001B[39mzinput\u001B[38;5;241m.\u001B[39mnumpy(), g\u001B[38;5;241m=\u001B[39mg\u001B[38;5;241m.\u001B[39mnumpy())\n", 97 | "File 
\u001B[0;32m~/Project/PycharmProjects/VitsServer/ONNXVITS_utils.py:35\u001B[0m, in \u001B[0;36mrunonnx\u001B[0;34m(model_path, x, x_lengths, sid, x_mask, zin, g, y_mask, z_p)\u001B[0m\n\u001B[1;32m 32\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m z_p \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m 33\u001B[0m input_feed[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mz_p\u001B[39m\u001B[38;5;124m'\u001B[39m] \u001B[38;5;241m=\u001B[39m z_p\n\u001B[0;32m---> 35\u001B[0m outputs \u001B[38;5;241m=\u001B[39m \u001B[43msess\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43;01mNone\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minput_feed\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 36\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m outputs\n", 98 | "File \u001B[0;32m~/miniconda3/envs/vits/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py:196\u001B[0m, in \u001B[0;36mSession.run\u001B[0;34m(self, output_names, input_feed, run_options)\u001B[0m\n\u001B[1;32m 194\u001B[0m \u001B[38;5;66;03m# the graph may have optional inputs used to override initializers. allow for that.\u001B[39;00m\n\u001B[1;32m 195\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m num_inputs \u001B[38;5;241m<\u001B[39m num_required_inputs:\n\u001B[0;32m--> 196\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mModel requires \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m inputs. 
Input Feed contains \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;241m.\u001B[39mformat(num_required_inputs, num_inputs))\n\u001B[1;32m 197\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m output_names:\n\u001B[1;32m 198\u001B[0m output_names \u001B[38;5;241m=\u001B[39m [output\u001B[38;5;241m.\u001B[39mname \u001B[38;5;28;01mfor\u001B[39;00m output \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_outputs_meta]\n", 99 | "\u001B[0;31mValueError\u001B[0m: Model requires 3 inputs. Input Feed contains 1" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "text1 = get_text(\"おはようございます。\", hps)\n", 105 | "stn_tst = text1\n", 106 | "with torch.no_grad():\n", 107 | " x_tst = stn_tst.unsqueeze(0)\n", 108 | " x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n", 109 | " sid = torch.LongTensor([0])\n", 110 | " audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=.667, noise_scale_w=0.8, length_scale=1)[0][0,0].data.cpu().float().numpy()\n", 111 | "ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))" 112 | ] 113 | } 114 | ], 115 | "metadata": { 116 | "kernelspec": { 117 | "display_name": "Python 3.7.13 ('tacotron2')", 118 | "language": "python", 119 | "name": "python3" 120 | }, 121 | "language_info": { 122 | "codemirror_mode": { 123 | "name": "ipython", 124 | "version": 3 125 | }, 126 | "file_extension": ".py", 127 | "mimetype": "text/x-python", 128 | "name": "python", 129 | "nbconvert_exporter": "python", 130 | "pygments_lexer": "ipython3", 131 | "version": "3.7.13" 132 | }, 133 | "vscode": { 134 | "interpreter": { 135 | "hash": "8aad0106d9baa662dc9c45cd138d3d95e54a0f2f791dfb890dc91ac1c34ec80a" 136 | } 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 4 141 | } 142 | -------------------------------------------------------------------------------- /component/langdetect_fasttext/README.md: 
-------------------------------------------------------------------------------- 1 | # fasttext-langdetect 2 | 3 | This library is a wrapper for the language detection model trained on fasttext by Facebook. For more information, please 4 | visit: https://fasttext.cc/docs/en/language-identification.html 5 | 6 | ## Supported languages 7 | 8 | ``` 9 | af als am an ar arz as ast av az azb ba bar bcl be bg bh bn bo bpy br bs bxr ca cbk ce cebckb co cs cv cy da de diq dsb dty dv el eml en eo es et eu fa fi fr frr fy ga gd gl gn gom gu gv he hi hif hr hsb ht hu hy ia id ie ilo io is it ja jbo jv ka kk km kn ko krc ku kv kw ky la lb lez li lmo lo lrc lt lv mai mg mhr min mk ml mn mr mrj ms mt mwl my myv mzn nah nap nds ne new nl nn no oc or os pa pam pfl pl pms pnb ps pt qu rm ro ru rue sa sah sc scn sco sd sh si sk sl so sq sr su sv sw ta te tg th tk tl tr tt tyv ug uk ur uz vec vep vi vls vo wa war wuu xal xmf yi yo yue zh 10 | ``` 11 | 12 | ## Install 13 | 14 | ``` 15 | pip install fasttext-langdetect 16 | ``` 17 | 18 | ## Usage 19 | 20 | `detect` method expects UTF-8 data. `low_memory` option enables getting predictions with the compressed version of the 21 | fasttext model by sacrificing the accuracy a bit. 22 | 23 | ``` 24 | from ftlangdetect import detect 25 | 26 | result = detect(text="Bugün hava çok güzel", low_memory=False) 27 | print(result) 28 | > {'lang': 'tr', 'score': 1.00} 29 | 30 | result = detect(text="Bugün hava çok güzel", low_memory=True) 31 | print(result) 32 | > {'lang': 'tr', 'score': 0.9982126951217651} 33 | ``` 34 | 35 | ## Benchmark 36 | 37 | We benchmarked the fasttext model 38 | against [cld2](https://github.com/CLD2Owners/cld2), [langid](https://github.com/saffsd/langid.py), 39 | and [langdetect](https://github.com/Mimino666/langdetect) on Wili-2018 dataset. 
40 | 41 | | | fasttext | langid | langdetect | cld2 | 42 | |--------------------------|-------------|-------------|-------------|-----------------| 43 | | Average time (ms) | 0,158273381 | 1,726618705 | 12,44604317 | **0,028776978** | 44 | | 139 langs - not weighted | 76,8 | 61,6 | 37,6 | **80,8** | 45 | | 139 langs - pop weighted | **95,5** | 93,1 | 86,6 | 92,7 | 46 | | 44 langs - not weighted | **93,3** | 89,2 | 81,6 | 91,5 | 47 | | 44 langs - pop weighted | **96,6** | 94,8 | 89,4 | 93,4 | 48 | 49 | - `pop weighted` means recall for each language is multipled 50 | by [its number of speakers](https://en.wikipedia.org/wiki/List_of_languages_by_total_number_of_speakers). 51 | - 139 languages = all languages with ISO 639-1 2-letter code 52 | - 44 languages = top 44 languages spoken in the world 53 | 54 | #### Recall per language 55 | 56 | | lang | cld2 | fasttext | langdetect | langid | 57 | |-------------------------|-------|----------|------------|--------| 58 | | Afrikaans | 0,94 | 0,918 | 0,992 | 0,966 | 59 | | Albanian | 0,958 | 0,966 | 0,964 | 0,954 | 60 | | Amharic | 0,976 | 0,982 | 0 | 0,982 | 61 | | Arabic | 0,994 | 0,998 | 0,998 | 0,996 | 62 | | Aragonese | 0 | 0,43 | 0 | 0,788 | 63 | | Armenian | 0,966 | 0,972 | 0 | 0,968 | 64 | | Assamese | 0,946 | 0,956 | 0 | 0,14 | 65 | | Avar | 0 | 0,626 | 0 | 0 | 66 | | Aymara | 0,596 | 0 | 0 | 0 | 67 | | Azerbaijani | 0,97 | 0,988 | 0 | 0,984 | 68 | | Bashkir | 0,97 | 0,97 | 0 | 0 | 69 | | Basque | 0,978 | 0,99 | 0 | 0,962 | 70 | | Belarusian | 0,94 | 0,97 | 0 | 0,964 | 71 | | Bengali | 0,898 | 0,922 | 0,904 | 0,942 | 72 | | Bhojpuri | 0,716 | 0,15 | 0 | 0 | 73 | | Bokmål | 0,852 | 0,966 | 0,976 | 0,95 | 74 | | Bosnian | 0,422 | 0,108 | 0 | 0,054 | 75 | | Breton | 0,946 | 0,974 | 0 | 0,976 | 76 | | Bulgarian | 0,892 | 0,964 | 0,964 | 0,942 | 77 | | Burmese | 0,998 | 0,998 | 0 | 0 | 78 | | Catalan | 0,882 | 0,95 | 0,93 | 0,928 | 79 | | Central Khmer | 0,876 | 0,878 | 0 | 0,876 | 80 | | Chechen | 0 | 0,99 | 0 | 0 | 81 
| | Chuvash | 0 | 0,96 | 0 | 0 | 82 | | Cornish | 0 | 0,792 | 0 | 0 | 83 | | Corsican | 0,88 | 0,016 | 0 | 0 | 84 | | Croatian | 0,688 | 0,806 | 0,982 | 0,932 | 85 | | Czech | 0,978 | 0,986 | 0,984 | 0,982 | 86 | | Danish | 0,886 | 0,958 | 0,95 | 0,896 | 87 | | Dhivehi | 0,996 | 0,998 | 0 | 0 | 88 | | Dutch | 0,9 | 0,978 | 0,968 | 0,97 | 89 | | English | 0,992 | 1 | 0,998 | 0,986 | 90 | | Esperanto | 0,936 | 0,978 | 0 | 0,948 | 91 | | Estonian | 0,918 | 0,952 | 0,948 | 0,932 | 92 | | Faroese | 0,912 | 0 | 0 | 0,618 | 93 | | Finnish | 0,988 | 0,998 | 0,998 | 0,994 | 94 | | French | 0,946 | 0,996 | 0,99 | 0,992 | 95 | | Galician | 0,89 | 0,912 | 0 | 0,93 | 96 | | Georgian | 0,974 | 0,976 | 0 | 0,976 | 97 | | German | 0,958 | 0,984 | 0,978 | 0,978 | 98 | | Guarani | 0,968 | 0,728 | 0 | 0 | 99 | | Gujarati | 0,932 | 0,932 | 0,93 | 0,932 | 100 | | Haitian Creole | 0,988 | 0,536 | 0 | 0,99 | 101 | | Hausa | 0,976 | 0 | 0 | 0 | 102 | | Hebrew | 0,994 | 0,996 | 0,998 | 0,998 | 103 | | Hindi | 0,982 | 0,984 | 0,982 | 0,972 | 104 | | Hungarian | 0,96 | 0,988 | 0,968 | 0,986 | 105 | | Icelandic | 0,984 | 0,996 | 0 | 0,996 | 106 | | Ido | 0 | 0,76 | 0 | 0 | 107 | | Igbo | 0,798 | 0 | 0 | 0 | 108 | | Indonesian | 0,88 | 0,946 | 0,958 | 0,836 | 109 | | Interlingua | 0,27 | 0,688 | 0 | 0 | 110 | | Interlingue | 0,198 | 0,192 | 0 | 0 | 111 | | Irish | 0,968 | 0,978 | 0 | 0,984 | 112 | | Italian | 0,866 | 0,948 | 0,932 | 0,936 | 113 | | Japanese | 0,97 | 0,986 | 0,98 | 0,986 | 114 | | Javanese | 0 | 0,864 | 0 | 0,938 | 115 | | Kannada | 0,998 | 0,998 | 0,998 | 0,998 | 116 | | Kazakh | 0,978 | 0,992 | 0 | 0,916 | 117 | | Kinyarwanda | 0,86 | 0 | 0 | 0,44 | 118 | | Kirghiz | 0,974 | 0,99 | 0 | 0,408 | 119 | | Komi | 0 | 0,544 | 0 | 0 | 120 | | Korean | 0,986 | 0,99 | 0,988 | 0,99 | 121 | | Kurdish | 0 | 0,972 | 0 | 0,976 | 122 | | Lao | 0,84 | 0,842 | 0 | 0,85 | 123 | | Latin | 0,778 | 0,864 | 0 | 0,854 | 124 | | Latvian | 0,98 | 0,992 | 0,992 | 0,99 | 125 | | Limburgan | 0 | 0,324 | 
0 | 0 | 126 | | Lingala | 0,85 | 0 | 0 | 0 | 127 | | Lithuanian | 0,96 | 0,976 | 0,974 | 0,97 | 128 | | Luganda | 0,952 | 0 | 0 | 0 | 129 | | Luxembourgish | 0,864 | 0,894 | 0 | 0,93 | 130 | | Macedonian | 0,88 | 0,984 | 0,982 | 0,974 | 131 | | Malagasy | 0,99 | 0,99 | 0 | 0,988 | 132 | | Malay | 0,896 | 0,586 | 0 | 0,39 | 133 | | Malayalam | 0,988 | 0,988 | 0,988 | 0,988 | 134 | | Maltese | 0,962 | 0,966 | 0 | 0,964 | 135 | | Manx | 0,972 | 0,294 | 0 | 0 | 136 | | Maori | 0,994 | 0 | 0 | 0 | 137 | | Marathi | 0,958 | 0,966 | 0,964 | 0,942 | 138 | | Modern Greek | 0,99 | 0,992 | 0,99 | 0,992 | 139 | | Mongolian | 0,964 | 0,994 | 0 | 0,996 | 140 | | Navajo | 0 | 0 | 0 | 0 | 141 | | Nepali (macrolanguage) | 0,96 | 0,98 | 0,978 | 0,922 | 142 | | Northern Sami | 0 | 0 | 0 | 0,866 | 143 | | Norwegian Nynorsk | 0,94 | 0,79 | 0 | 0,796 | 144 | | Occitan | 0,66 | 0,48 | 0 | 0,724 | 145 | | Oriya | 0,96 | 0,958 | 0 | 0,96 | 146 | | Oromo | 0,956 | 0 | 0 | 0 | 147 | | Ossetian | 0 | 0,938 | 0 | 0 | 148 | | Panjabi | 0,994 | 0,994 | 0,994 | 0,994 | 149 | | Persian | 0,992 | 0,998 | 0,996 | 0,998 | 150 | | Polish | 0,982 | 0,998 | 0,998 | 0,992 | 151 | | Portuguese | 0,908 | 0,956 | 0,946 | 0,952 | 152 | | Pushto | 0,938 | 0,922 | 0 | 0,754 | 153 | | Quechua | 0,926 | 0,808 | 0 | 0,852 | 154 | | Romanian | 0,932 | 0,986 | 0,984 | 0,984 | 155 | | Romansh | 0,934 | 0,328 | 0 | 0 | 156 | | Russian | 0,728 | 0,986 | 0,984 | 0,988 | 157 | | Sanskrit | 0,964 | 0,976 | 0 | 0 | 158 | | Sardinian | 0 | 0,01 | 0 | 0 | 159 | | Scottish Gaelic | 0,964 | 0,942 | 0 | 0 | 160 | | Serbian | 0,942 | 0,946 | 0 | 0,902 | 161 | | Serbo-Croatian | 0 | 0,402 | 0 | 0 | 162 | | Shona | 0,844 | 0 | 0 | 0 | 163 | | Sindhi | 0,978 | 0,982 | 0 | 0 | 164 | | Sinhala | 0,962 | 0,962 | 0 | 0,962 | 165 | | Slovak | 0,964 | 0,974 | 0,982 | 0,97 | 166 | | Slovene | 0,876 | 0,966 | 0,968 | 0,946 | 167 | | Somali | 0,924 | 0,696 | 0,956 | 0 | 168 | | Spanish | 0,894 | 0,986 | 0,976 | 0,98 | 169 | | Standard 
Chinese | 0,946 | 0,984 | 0,746 | 0,978 | 170 | | Sundanese | 0,91 | 0,854 | 0 | 0 | 171 | | Swahili (macrolanguage) | 0,924 | 0,92 | 0,938 | 0,934 | 172 | | Swedish | 0,872 | 0,994 | 0,992 | 0,986 | 173 | | Tagalog | 0,928 | 0,972 | 0,974 | 0,964 | 174 | | Tajik | 0,82 | 0,85 | 0 | 0 | 175 | | Tamil | 0,992 | 0,992 | 0,992 | 0,994 | 176 | | Tatar | 0,978 | 0,984 | 0 | 0 | 177 | | Telugu | 0,958 | 0,958 | 0,958 | 0,96 | 178 | | Thai | 0,988 | 0,988 | 0,988 | 0,988 | 179 | | Tibetan | 0,986 | 0,992 | 0 | 0 | 180 | | Tongan | 0,968 | 0 | 0 | 0 | 181 | | Tswana | 0,928 | 0 | 0 | 0 | 182 | | Turkish | 0,968 | 0,986 | 0,982 | 0,976 | 183 | | Turkmen | 0,94 | 0,936 | 0 | 0 | 184 | | Uighur | 0,978 | 0,986 | 0 | 0,964 | 185 | | Ukrainian | 0,97 | 0,988 | 0,986 | 0,986 | 186 | | Urdu | 0,86 | 0,958 | 0,89 | 0,896 | 187 | | Uzbek | 0,984 | 0,99 | 0 | 0 | 188 | | Vietnamese | 0,978 | 0,986 | 0,984 | 0,984 | 189 | | Volapük | 0,994 | 0,982 | 0 | 0,986 | 190 | | Walloon | 0 | 0,664 | 0 | 0,98 | 191 | | Welsh | 0,98 | 0,992 | 0,992 | 0,984 | 192 | | Western Frisian | 0,888 | 0,956 | 0 | 0 | 193 | | Wolof | 0,926 | 0 | 0 | 0 | 194 | | Xhosa | 0,928 | 0 | 0 | 0,912 | 195 | | Yiddish | 0,956 | 0,958 | 0 | 0 | 196 | | Yoruba | 0,75 | 0,262 | 0 | 0 | 197 | 198 | ## References 199 | 200 | [1] A. Joulin, E. Grave, P. Bojanowski, T. 201 | Mikolov, [Bag of Tricks for Efficient Text Classification](https://arxiv.org/abs/1607.01759) 202 | 203 | ``` 204 | @article{joulin2016bag, 205 | title={Bag of Tricks for Efficient Text Classification}, 206 | author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Mikolov, Tomas}, 207 | journal={arXiv preprint arXiv:1607.01759}, 208 | year={2016} 209 | } 210 | ``` 211 | 212 | [2] A. Joulin, E. Grave, P. Bojanowski, M. Douze, H. Jégou, T. 
213 | Mikolov, [FastText.zip: Compressing text classification models](https://arxiv.org/abs/1612.03651) 214 | 215 | ``` 216 | @article{joulin2016fasttext, 217 | title={FastText.zip: Compressing text classification models}, 218 | author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Douze, Matthijs and J{\'e}gou, H{\'e}rve and Mikolov, Tomas}, 219 | journal={arXiv preprint arXiv:1612.03651}, 220 | year={2016} 221 | } 222 | ``` 223 | -------------------------------------------------------------------------------- /onnx_infer/infer/modules.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import Conv1d 7 | from torch.nn import functional as F 8 | from torch.nn.utils import weight_norm, remove_weight_norm 9 | 10 | from . import commons 11 | from .commons import init_weights, get_padding 12 | from .transforms import piecewise_rational_quadratic_transform 13 | 14 | LRELU_SLOPE = 0.1 15 | 16 | 17 | class LayerNorm(nn.Module): 18 | def __init__(self, channels, eps=1e-5): 19 | super().__init__() 20 | self.channels = channels 21 | self.eps = eps 22 | 23 | self.gamma = nn.Parameter(torch.ones(channels)) 24 | self.beta = nn.Parameter(torch.zeros(channels)) 25 | 26 | def forward(self, x): 27 | x = x.transpose(1, -1) 28 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 29 | return x.transpose(1, -1) 30 | 31 | 32 | class ConvReluNorm(nn.Module): 33 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 34 | super().__init__() 35 | self.in_channels = in_channels 36 | self.hidden_channels = hidden_channels 37 | self.out_channels = out_channels 38 | self.kernel_size = kernel_size 39 | self.n_layers = n_layers 40 | self.p_dropout = p_dropout 41 | assert n_layers > 1, "Number of layers should be larger than 0." 
42 | 43 | self.conv_layers = nn.ModuleList() 44 | self.norm_layers = nn.ModuleList() 45 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 46 | self.norm_layers.append(LayerNorm(hidden_channels)) 47 | self.relu_drop = nn.Sequential( 48 | nn.ReLU(), 49 | nn.Dropout(p_dropout)) 50 | for _ in range(n_layers - 1): 51 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) 52 | self.norm_layers.append(LayerNorm(hidden_channels)) 53 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 54 | self.proj.weight.data.zero_() 55 | self.proj.bias.data.zero_() 56 | 57 | def forward(self, x, x_mask): 58 | x_org = x 59 | for i in range(self.n_layers): 60 | x = self.conv_layers[i](x * x_mask) 61 | x = self.norm_layers[i](x) 62 | x = self.relu_drop(x) 63 | x = x_org + self.proj(x) 64 | return x * x_mask 65 | 66 | 67 | class DDSConv(nn.Module): 68 | """ 69 | Dilated and Depth-Separable Convolution 70 | """ 71 | 72 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 73 | super().__init__() 74 | self.channels = channels 75 | self.kernel_size = kernel_size 76 | self.n_layers = n_layers 77 | self.p_dropout = p_dropout 78 | 79 | self.drop = nn.Dropout(p_dropout) 80 | self.convs_sep = nn.ModuleList() 81 | self.convs_1x1 = nn.ModuleList() 82 | self.norms_1 = nn.ModuleList() 83 | self.norms_2 = nn.ModuleList() 84 | for i in range(n_layers): 85 | dilation = kernel_size ** i 86 | padding = (kernel_size * dilation - dilation) // 2 87 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 88 | groups=channels, dilation=dilation, padding=padding 89 | )) 90 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 91 | self.norms_1.append(LayerNorm(channels)) 92 | self.norms_2.append(LayerNorm(channels)) 93 | 94 | def forward(self, x, x_mask, g=None): 95 | if g is not None: 96 | x = x + g 97 | for i in range(self.n_layers): 98 | y = self.convs_sep[i](x * 
x_mask) 99 | y = self.norms_1[i](y) 100 | y = F.gelu(y) 101 | y = self.convs_1x1[i](y) 102 | y = self.norms_2[i](y) 103 | y = F.gelu(y) 104 | y = self.drop(y) 105 | x = x + y 106 | return x * x_mask 107 | 108 | 109 | class WN(torch.nn.Module): 110 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 111 | super(WN, self).__init__() 112 | assert (kernel_size % 2 == 1) 113 | self.hidden_channels = hidden_channels 114 | self.kernel_size = kernel_size, 115 | self.dilation_rate = dilation_rate 116 | self.n_layers = n_layers 117 | self.gin_channels = gin_channels 118 | self.p_dropout = p_dropout 119 | 120 | self.in_layers = torch.nn.ModuleList() 121 | self.res_skip_layers = torch.nn.ModuleList() 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if gin_channels != 0: 125 | cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) 126 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 127 | 128 | for i in range(n_layers): 129 | dilation = dilation_rate ** i 130 | padding = int((kernel_size * dilation - dilation) / 2) 131 | in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, 132 | dilation=dilation, padding=padding) 133 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 134 | self.in_layers.append(in_layer) 135 | 136 | # last one is not necessary 137 | if i < n_layers - 1: 138 | res_skip_channels = 2 * hidden_channels 139 | else: 140 | res_skip_channels = hidden_channels 141 | 142 | res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 143 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 144 | self.res_skip_layers.append(res_skip_layer) 145 | 146 | def forward(self, x, x_mask, g=None, **kwargs): 147 | output = torch.zeros_like(x) 148 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 149 | 150 | if g is not None: 151 | g = self.cond_layer(g) 152 | 153 | for i in 
range(self.n_layers): 154 | x_in = self.in_layers[i](x) 155 | if g is not None: 156 | cond_offset = i * 2 * self.hidden_channels 157 | g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] 158 | else: 159 | g_l = torch.zeros_like(x_in) 160 | 161 | acts = commons.fused_add_tanh_sigmoid_multiply( 162 | x_in, 163 | g_l, 164 | n_channels_tensor) 165 | acts = self.drop(acts) 166 | 167 | res_skip_acts = self.res_skip_layers[i](acts) 168 | if i < self.n_layers - 1: 169 | res_acts = res_skip_acts[:, :self.hidden_channels, :] 170 | x = (x + res_acts) * x_mask 171 | output = output + res_skip_acts[:, self.hidden_channels:, :] 172 | else: 173 | output = output + res_skip_acts 174 | return output * x_mask 175 | 176 | def remove_weight_norm(self): 177 | if self.gin_channels != 0: 178 | torch.nn.utils.remove_weight_norm(self.cond_layer) 179 | for l in self.in_layers: 180 | torch.nn.utils.remove_weight_norm(l) 181 | for l in self.res_skip_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | 184 | 185 | class ResBlock1(torch.nn.Module): 186 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 187 | super(ResBlock1, self).__init__() 188 | self.convs1 = nn.ModuleList([ 189 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 190 | padding=get_padding(kernel_size, dilation[0]))), 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 192 | padding=get_padding(kernel_size, dilation[1]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 194 | padding=get_padding(kernel_size, dilation[2]))) 195 | ]) 196 | self.convs1.apply(init_weights) 197 | 198 | self.convs2 = nn.ModuleList([ 199 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 200 | padding=get_padding(kernel_size, 1))), 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, 
dilation=1, 204 | padding=get_padding(kernel_size, 1))) 205 | ]) 206 | self.convs2.apply(init_weights) 207 | 208 | def forward(self, x, x_mask=None): 209 | for c1, c2 in zip(self.convs1, self.convs2): 210 | xt = F.leaky_relu(x, LRELU_SLOPE) 211 | if x_mask is not None: 212 | xt = xt * x_mask 213 | xt = c1(xt) 214 | xt = F.leaky_relu(xt, LRELU_SLOPE) 215 | if x_mask is not None: 216 | xt = xt * x_mask 217 | xt = c2(xt) 218 | x = xt + x 219 | if x_mask is not None: 220 | x = x * x_mask 221 | return x 222 | 223 | def remove_weight_norm(self): 224 | for l in self.convs1: 225 | remove_weight_norm(l) 226 | for l in self.convs2: 227 | remove_weight_norm(l) 228 | 229 | 230 | class ResBlock2(torch.nn.Module): 231 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 232 | super(ResBlock2, self).__init__() 233 | self.convs = nn.ModuleList([ 234 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 235 | padding=get_padding(kernel_size, dilation[0]))), 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 237 | padding=get_padding(kernel_size, dilation[1]))) 238 | ]) 239 | self.convs.apply(init_weights) 240 | 241 | def forward(self, x, x_mask=None): 242 | for c in self.convs: 243 | xt = F.leaky_relu(x, LRELU_SLOPE) 244 | if x_mask is not None: 245 | xt = xt * x_mask 246 | xt = c(xt) 247 | x = xt + x 248 | if x_mask is not None: 249 | x = x * x_mask 250 | return x 251 | 252 | def remove_weight_norm(self): 253 | for l in self.convs: 254 | remove_weight_norm(l) 255 | 256 | 257 | class Log(nn.Module): 258 | def forward(self, x, x_mask, reverse=False, **kwargs): 259 | if not reverse: 260 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 261 | logdet = torch.sum(-y, [1, 2]) 262 | return y, logdet 263 | else: 264 | x = torch.exp(x) * x_mask 265 | return x 266 | 267 | 268 | class Flip(nn.Module): 269 | def forward(self, x, *args, reverse=False, **kwargs): 270 | x = torch.flip(x, [1]) 271 | if not reverse: 272 | logdet = 
torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 273 | return x, logdet 274 | else: 275 | return x 276 | 277 | 278 | class ElementwiseAffine(nn.Module): 279 | def __init__(self, channels): 280 | super().__init__() 281 | self.channels = channels 282 | self.m = nn.Parameter(torch.zeros(channels, 1)) 283 | self.logs = nn.Parameter(torch.zeros(channels, 1)) 284 | 285 | def forward(self, x, x_mask, reverse=False, **kwargs): 286 | if not reverse: 287 | y = self.m + torch.exp(self.logs) * x 288 | y = y * x_mask 289 | logdet = torch.sum(self.logs * x_mask, [1, 2]) 290 | return y, logdet 291 | else: 292 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 293 | return x 294 | 295 | 296 | class ResidualCouplingLayer(nn.Module): 297 | def __init__(self, 298 | channels, 299 | hidden_channels, 300 | kernel_size, 301 | dilation_rate, 302 | n_layers, 303 | p_dropout=0, 304 | gin_channels=0, 305 | mean_only=False): 306 | assert channels % 2 == 0, "channels should be divisible by 2" 307 | super().__init__() 308 | self.channels = channels 309 | self.hidden_channels = hidden_channels 310 | self.kernel_size = kernel_size 311 | self.dilation_rate = dilation_rate 312 | self.n_layers = n_layers 313 | self.half_channels = channels // 2 314 | self.mean_only = mean_only 315 | 316 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 317 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, 318 | gin_channels=gin_channels) 319 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 320 | self.post.weight.data.zero_() 321 | self.post.bias.data.zero_() 322 | 323 | def forward(self, x, x_mask, g=None, reverse=False): 324 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 325 | h = self.pre(x0) * x_mask 326 | h = self.enc(h, x_mask, g=g) 327 | stats = self.post(h) * x_mask 328 | if not self.mean_only: 329 | m, logs = torch.split(stats, [self.half_channels] * 2, 1) 330 | else: 331 | m = stats 332 | logs = 
torch.zeros_like(m) 333 | 334 | if not reverse: 335 | x1 = m + x1 * torch.exp(logs) * x_mask 336 | x = torch.cat([x0, x1], 1) 337 | logdet = torch.sum(logs, [1, 2]) 338 | return x, logdet 339 | else: 340 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 341 | x = torch.cat([x0, x1], 1) 342 | return x 343 | 344 | 345 | class ConvFlow(nn.Module): 346 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 347 | super().__init__() 348 | self.in_channels = in_channels 349 | self.filter_channels = filter_channels 350 | self.kernel_size = kernel_size 351 | self.n_layers = n_layers 352 | self.num_bins = num_bins 353 | self.tail_bound = tail_bound 354 | self.half_channels = in_channels // 2 355 | 356 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 357 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 358 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 359 | self.proj.weight.data.zero_() 360 | self.proj.bias.data.zero_() 361 | 362 | def forward(self, x, x_mask, g=None, reverse=False): 363 | x0, x1 = torch.split(x, [self.half_channels] * 2, 1) 364 | h = self.pre(x0) 365 | h = self.convs(h, x_mask, g=g) 366 | h = self.proj(h) * x_mask 367 | 368 | b, c, t = x0.shape 369 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 
370 | 371 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 372 | unnormalized_heights = h[..., self.num_bins:2 * self.num_bins] / math.sqrt(self.filter_channels) 373 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 374 | 375 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 376 | unnormalized_widths, 377 | unnormalized_heights, 378 | unnormalized_derivatives, 379 | inverse=reverse, 380 | tails='linear', 381 | tail_bound=self.tail_bound 382 | ) 383 | 384 | x = torch.cat([x0, x1], 1) * x_mask 385 | logdet = torch.sum(logabsdet * x_mask, [1, 2]) 386 | if not reverse: 387 | return x, logdet 388 | else: 389 | return x 390 | --------------------------------------------------------------------------------