├── .gitattributes ├── .github └── workflows │ ├── checksum.yml │ ├── close-issue.yml │ ├── pull-format.yml │ ├── push-format.yml │ ├── unitest.yml │ └── upload-pypi.yml ├── .gitignore ├── ChatTTS ├── __init__.py ├── config │ ├── __init__.py │ └── config.py ├── core.py ├── model │ ├── __init__.py │ ├── cuda │ │ ├── __init__.py │ │ ├── patch.py │ │ └── te_llama.py │ ├── dvae.py │ ├── embed.py │ ├── gpt.py │ ├── processors.py │ ├── speaker.py │ ├── tokenizer.py │ └── velocity │ │ ├── __init__.py │ │ ├── block_manager.py │ │ ├── configs.py │ │ ├── llama.py │ │ ├── llm.py │ │ ├── llm_engine.py │ │ ├── model_loader.py │ │ ├── model_runner.py │ │ ├── output.py │ │ ├── sampler.py │ │ ├── sampling_params.py │ │ ├── scheduler.py │ │ ├── sequence.py │ │ └── worker.py ├── norm.py ├── res │ ├── __init__.py │ ├── homophones_map.json │ └── sha256_map.json └── utils │ ├── __init__.py │ ├── dl.py │ ├── gpu.py │ ├── io.py │ └── log.py ├── LICENSE ├── README.md ├── docs ├── cn │ └── README.md ├── es │ └── README.md ├── fr │ └── README.md ├── jp │ └── README.md ├── kr │ └── README.md └── ru │ └── README.md ├── examples ├── __init__.py ├── api │ ├── README.md │ ├── client.py │ ├── main.py │ ├── openai_api.py │ ├── postScript.py │ └── requirements.txt ├── cmd │ ├── run.py │ └── stream.py ├── ipynb │ ├── colab.ipynb │ └── example.ipynb ├── onnx │ ├── README.md │ ├── exporter.py │ ├── gpt.py │ └── modeling_llama.py └── web │ ├── __init__.py │ ├── ex.py │ ├── funcs.py │ └── webui.py ├── openai_api.ipynb ├── requirements.txt ├── setup.py ├── tests ├── #511.py ├── #588.py ├── #655.py └── testall.sh └── tools ├── __init__.py ├── audio ├── __init__.py ├── av.py ├── ffmpeg.py ├── np.py └── pcm.py ├── checksum ├── main.go └── tmpl.go ├── llm ├── __init__.py └── llm.py ├── logger ├── __init__.py └── log.py ├── normalizer ├── __init__.py ├── en.py └── zh.py └── seeder ├── __init__.py └── ctx.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # ignore jupyter notebooks in the language bar on github 2 | **/*.ipynb linguist-vendored 3 | *.ipynb 4 | -------------------------------------------------------------------------------- /.github/workflows/checksum.yml: -------------------------------------------------------------------------------- 1 | name: Calculate and Sync SHA256 2 | on: 3 | workflow_dispatch: 4 | 5 | jobs: 6 | checksum: 7 | runs-on: ubuntu-24.04 8 | steps: 9 | - uses: actions/checkout@v4 10 | 11 | - name: Setup Go Environment 12 | uses: actions/setup-go@v5 13 | 14 | - name: Run RVC-Models-Downloader 15 | run: | 16 | wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.10/rvcmd_linux_amd64.deb 17 | sudo apt -y install ./rvcmd_linux_amd64.deb 18 | rm -f ./rvcmd_linux_amd64.deb 19 | rvcmd -notrs -w 1 -notui assets/chtts 20 | 21 | - name: Calculate all Checksums 22 | run: go run tools/checksum/*.go 23 | 24 | - name: Commit back 25 | if: ${{ !github.head_ref }} 26 | id: commitback 27 | continue-on-error: true 28 | run: | 29 | git config --local user.name 'github-actions[bot]' 30 | git config --local user.email 'github-actions[bot]@users.noreply.github.com' 31 | git add --all 32 | git commit -m "chore(env): sync checksum on ${{github.ref_name}}" 33 | 34 | - name: Create Pull Request 35 | if: steps.commitback.outcome == 'success' 36 | continue-on-error: true 37 | uses: peter-evans/create-pull-request@v5 38 | with: 39 | delete-branch: true 40 | body: "Automatically sync checksum in .env" 41 | title: "chore(env): sync checksum on 
${{github.ref_name}}" 42 | commit-message: "chore(env): sync checksum on ${{github.ref_name}}" 43 | branch: checksum-${{github.ref_name}} 44 | -------------------------------------------------------------------------------- /.github/workflows/close-issue.yml: -------------------------------------------------------------------------------- 1 | name: Close Inactive Issues 2 | on: 3 | schedule: 4 | - cron: "0 4 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-24.04 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | exempt-issue-labels: "help wanted,following up,todo list,enhancement,algorithm,delayed,performance" 16 | days-before-issue-stale: 30 17 | days-before-issue-close: 15 18 | stale-issue-label: "stale" 19 | close-issue-message: "This issue was closed because it has been inactive for 15 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | operations-per-run: 10000 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/pull-format.yml: -------------------------------------------------------------------------------- 1 | name: Check Pull Request Format 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, reopened, synchronize] 6 | 7 | jobs: 8 | # This workflow closes invalid PR 9 | change-or-close-pr: 10 | # The type of runner that the job will run on 11 | runs-on: ubuntu-24.04 12 | permissions: write-all 13 | 14 | # Steps represent a sequence of tasks that will be executed as part of the job 15 | steps: 16 | - name: Change Base Branch 17 | if: github.event.pull_request.base.ref != 'dev' 18 | uses: actions/github-script@v4 19 | id: change-base 20 | with: 21 | github-token: ${{ secrets.GITHUB_TOKEN }} 22 | script: | 23 | const { owner, repo, number } = context.issue; 24 | const newBase = 'dev'; 25 | try { 26 | const result = await github.pulls.update({ 27 | owner, 28 | repo, 29 | pull_number: number, 30 | base: newBase 31 | }); 32 | console.log(result); 33 | return 'success'; 34 | } catch (error) { 35 | console.log(error); 36 | return 'failed'; 37 | } 38 | 39 | - name: Close PR if it is not pointed to dev Branch 40 | if: "github.event.pull_request.base.ref != 'dev' && steps.change-base.outputs.result == 'failed'" 41 | uses: superbrothers/close-pull-request@v3 42 | with: 43 | # Optional. Post a issue comment just before closing a pull request. 44 | comment: "Invalid PR to `non-dev` branch `${{ github.event.pull_request.base.ref }}`." 45 | 46 | pull-format: 47 | runs-on: ubuntu-latest 48 | permissions: 49 | contents: write 50 | 51 | continue-on-error: true 52 | 53 | steps: 54 | - name: Checkout Repo 55 | continue-on-error: true 56 | uses: actions/checkout@v4 57 | 58 | - name: Checkout PR # see https://github.com/orgs/community/discussions/24945 59 | env: 60 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 61 | run: gh pr checkout ${{ github.event.pull_request.number }} 62 | 63 | - name: Set up Python 64 | uses: actions/setup-python@v5 65 | 66 | - name: Create venv 67 | run: python3 -m venv .venv 68 | 69 | - name: Activate venv 70 | run: | 71 | . .venv/bin/activate 72 | echo PATH=$PATH >> $GITHUB_ENV 73 | 74 | - name: Install Black 75 | run: pip install "black[jupyter]" 76 | 77 | - name: Run Black 78 | # run: black $(git ls-files '*.py') 79 | run: black . 
80 | 81 | - name: Commit back 82 | env: 83 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 84 | continue-on-error: true 85 | run: | 86 | git config --local user.name 'github-actions[bot]' 87 | git config --local user.email 'github-actions[bot]@users.noreply.github.com' 88 | git add --all 89 | git commit -m "chore(format): run black on ${{github.ref_name}}" 90 | git push 91 | -------------------------------------------------------------------------------- /.github/workflows/push-format.yml: -------------------------------------------------------------------------------- 1 | name: Standardize Code Format 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | 9 | jobs: 10 | push-format: 11 | runs-on: ubuntu-latest 12 | 13 | if: "!contains(github.event.head_commit.message, 'chore(format): ') && !contains(github.event.head_commit.message, 'chore(env): ')" 14 | 15 | permissions: 16 | contents: write 17 | pull-requests: write 18 | 19 | steps: 20 | - uses: actions/checkout@v4 21 | with: 22 | ref: ${{github.ref_name}} 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v5 26 | 27 | - name: Create venv 28 | run: python3 -m venv .venv 29 | 30 | - name: Activate venv 31 | run: | 32 | . .venv/bin/activate 33 | echo PATH=$PATH >> $GITHUB_ENV 34 | 35 | - name: Install Black 36 | run: pip install "black[jupyter]" 37 | 38 | - name: Run Black 39 | # run: black $(git ls-files '*.py') 40 | run: black . 41 | 42 | - name: Commit Back 43 | continue-on-error: true 44 | id: commitback 45 | run: | 46 | git config --local user.email "github-actions[bot]@users.noreply.github.com" 47 | git config --local user.name "github-actions[bot]" 48 | git add --all 49 | git commit -m "chore(format): run black on ${{github.ref_name}}" 50 | 51 | - name: Create Pull Request 52 | if: steps.commitback.outcome == 'success' 53 | continue-on-error: true 54 | uses: peter-evans/create-pull-request@v5 55 | with: 56 | delete-branch: true 57 | body: "Automatically apply code formatter change" 58 | title: "chore(format): run black on ${{github.ref_name}}" 59 | commit-message: "chore(format): run black on ${{github.ref_name}}" 60 | branch: formatter-${{github.ref_name}} 61 | -------------------------------------------------------------------------------- /.github/workflows/unitest.yml: -------------------------------------------------------------------------------- 1 | name: Unit Test 2 | on: [ push, pull_request ] 3 | jobs: 4 | build: 5 | runs-on: ${{ matrix.os }} 6 | 7 | if: "!contains(github.event.head_commit.message, 'chore(format): ') && !contains(github.event.head_commit.message, 'chore(env): ')" 8 | 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10"] 12 | os: [ubuntu-latest] 13 | fail-fast: true 14 | 15 | steps: 16 | 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install Dependents 25 | run: | 26 | sudo apt-get install -y portaudio19-dev python3-pyaudio 27 | 28 | - name: Create venv 29 | run: python3 -m venv .venv 30 | 31 | - name: Activate venv 32 | run: | 33 | . .venv/bin/activate 34 | echo PATH=$PATH >> $GITHUB_ENV 35 | 36 | - name: Test Install 37 | run: pip install . 
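# "pip install ." exercises the packaging defined in setup.py first; the pinned
# requirements are then layered on top in the next step.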
38 | 39 | - name: Install Dependencies 40 | run: pip install -r requirements.txt 41 | 42 | - name: Run Test 43 | run: tests/testall.sh 44 | -------------------------------------------------------------------------------- /.github/workflows/upload-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Upload to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | build: 10 | runs-on: ubuntu-22.04 11 | 12 | steps: 13 | 14 | - uses: actions/checkout@v4 15 | with: 16 | ref: ${{github.ref_name}} 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | 21 | - name: Install Dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | python -m pip install --upgrade setuptools 25 | python -m pip install --upgrade wheel 26 | pip install twine 27 | 28 | - name: Build Package 29 | env: 30 | CHTTS_VER: ${{ github.ref_name }} 31 | run: | 32 | echo "Release Tag: ${{ github.ref_name }}" 33 | sed -i 's/v0.0.0/${{ github.ref_name }}/g' setup.py 34 | python setup.py sdist 35 | 36 | - name: Upload Package 37 | run: | 38 | twine upload dist/* -u "__token__" -p ${{ secrets.PYPI_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.ckpt 6 | # C extensions 7 | *.so 8 | *.pt 9 | 10 | # Distribution / packaging 11 | .Python 12 | outputs/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | asset/* 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | .idea/ 164 | 165 | # MacOS System 166 | .DS_Store 167 | 168 | # assets and configs of ChatTTS 169 | /asset 170 | /config 171 | 172 | # inferred result 173 | *.wav 174 | *.mp3 175 | -------------------------------------------------------------------------------- /ChatTTS/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Chat 2 | -------------------------------------------------------------------------------- /ChatTTS/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import Config 2 | -------------------------------------------------------------------------------- /ChatTTS/config/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass(repr=False, eq=False) 5 | class Path: 6 | vocos_ckpt_path: str = "asset/Vocos.safetensors" 7 | dvae_ckpt_path: str = "asset/DVAE.safetensors" 8 | gpt_ckpt_path: str = "asset/gpt" 9 | decoder_ckpt_path: str = "asset/Decoder.safetensors" 10 | tokenizer_path: str = "asset/tokenizer" 11 | embed_path: str = "asset/Embed.safetensors" 12 | 13 | 14 | @dataclass(repr=False, eq=False) 15 | class Decoder: 16 | idim: int = 384 17 | odim: int = 384 18 | hidden: int = 512 19 | n_layer: int = 12 20 | bn_dim: int = 128 21 | 22 | 23 | @dataclass(repr=False, eq=False) 24 | class VQ: 25 | dim: int = 1024 26 | levels: tuple = (5, 5, 5, 5) 27 | G: int = 2 28 | R: int = 2 29 | 30 | 31 | @dataclass(repr=False, eq=False) 32 | class DVAE: 33 | encoder: Decoder = Decoder( 34 | idim=512, 35 | odim=1024, 36 | hidden=256, 37 | n_layer=12, 38 | bn_dim=128, 39 | ) 40 | decoder: Decoder = Decoder( 41 | idim=512, 42 | odim=512, 43 | 
hidden=256, 44 | n_layer=12, 45 | bn_dim=128, 46 | ) 47 | vq: VQ = VQ() 48 | 49 | 50 | @dataclass(repr=False, eq=False) 51 | class GPT: 52 | hidden_size: int = 768 53 | intermediate_size: int = 3072 54 | num_attention_heads: int = 12 55 | num_hidden_layers: int = 20 56 | use_cache: bool = False 57 | max_position_embeddings: int = 4096 58 | 59 | spk_emb_dim: int = 192 60 | spk_KL: bool = False 61 | num_audio_tokens: int = 626 62 | num_text_tokens: int = 21178 63 | num_vq: int = 4 64 | 65 | 66 | @dataclass(repr=False, eq=False) 67 | class Embed: 68 | hidden_size: int = 768 69 | num_audio_tokens: int = 626 70 | num_text_tokens: int = 21178 71 | num_vq: int = 4 72 | 73 | 74 | @dataclass(repr=False, eq=False) 75 | class FeatureExtractorInitArgs: 76 | sample_rate: int = 24000 77 | n_fft: int = 1024 78 | hop_length: int = 256 79 | n_mels: int = 100 80 | padding: str = "center" 81 | 82 | 83 | @dataclass(repr=False, eq=False) 84 | class FeatureExtractor: 85 | class_path: str = "vocos.feature_extractors.MelSpectrogramFeatures" 86 | init_args: FeatureExtractorInitArgs = FeatureExtractorInitArgs() 87 | 88 | 89 | @dataclass(repr=False, eq=False) 90 | class BackboneInitArgs: 91 | input_channels: int = 100 92 | dim: int = 512 93 | intermediate_dim: int = 1536 94 | num_layers: int = 8 95 | 96 | 97 | @dataclass(repr=False, eq=False) 98 | class Backbone: 99 | class_path: str = "vocos.models.VocosBackbone" 100 | init_args: BackboneInitArgs = BackboneInitArgs() 101 | 102 | 103 | @dataclass(repr=False, eq=False) 104 | class FourierHeadInitArgs: 105 | dim: int = 512 106 | n_fft: int = 1024 107 | hop_length: int = 256 108 | padding: str = "center" 109 | 110 | 111 | @dataclass(repr=False, eq=False) 112 | class FourierHead: 113 | class_path: str = "vocos.heads.ISTFTHead" 114 | init_args: FourierHeadInitArgs = FourierHeadInitArgs() 115 | 116 | 117 | @dataclass(repr=False, eq=False) 118 | class Vocos: 119 | feature_extractor: FeatureExtractor = FeatureExtractor() 120 | backbone: Backbone = Backbone() 121 | head: FourierHead = FourierHead() 122 | 123 | 124 | @dataclass(repr=False, eq=False) 125 | class Config: 126 | path: Path = Path() 127 | decoder: Decoder = Decoder() 128 | dvae: DVAE = DVAE() 129 | gpt: GPT = GPT() 130 | embed: Embed = Embed() 131 | vocos: Vocos = Vocos() 132 | spk_stat: str = ( 133 | 
"愐穤巩噅廷戇笉屈癐媄垹垧帶爲漈塀殐慄亅倴庲舴猂瑈圐狴夥圓帍戛挠腉耐劤坽喳幾战謇聀崒栄呥倸庭燡欈杁襐褄乭埗幺爃弔摁斐捔兕佖廐舏竾豃磐姓趡佄幒爚欄豄讐皳訵仩帆投謌荃蝐叄圝伆幦抂茁呄掑斃讹傮庞爣蜀橁偐祄亥兡常爂欍扉丐浔佱僈強払伅扂蛐徴憍傞巀戺欀艂琐嗴啥値彷刂權穈扒卤俔贲庛初笂卄贐枴仭亁庛剎猢扃缐趤刁偵幪舏伌煁婐潤晍位弾舙茥穁葏蠣訑企庤刊笍橁溑僔云偁庯戚伍潉膐脴僵噔廃艅匊祂唐憴壝嗙席爥欁虁谐牴帽势弿牳蜁兀蛐傄喩丿帔刔圆衁廐罤庁促帙劢伈汄樐檄勵伴弝舑欍罅虐昴劭勅帜刼朊蕁虐蓴樑伫幨扑謪剀堐稴丵伱弐舮諸赁習俔容厱幫牶謃孄糐答嗝僊帜燲笄終瀒判久僤帘爴茇千孑冄凕佳引扐蜁歁缏裄剽儺恘爋朏眿廐呄塍嘇幻爱茠詁訐剴唭俐幾戊欀硁菐贄楕偒巡爀弎屄莐睳賙凶彎刅漄區唐溴剑劋庽舽猄煃跐夔惥伾庮舎伈罁垑坄怅业怯刁朇獁嶏覔坩俳巶爜朐潁崐萄俹凛常爺笌穀聐此夡倛帡刀匉終窏舣販侽怿扉伥贿憐忓謩姆幌犊漂慆癒却甝兎帼戏欅詂浐朔仹壭帰臷弎恇菐獤帡偖帘爞伅腂皐纤囅充幓戠伥灂丐訤戱倱弋爮嬌癁恐孄侥劬忶刓國詀桒古偩嘄庬戚茝赂监燤嘑勌幦舽持呂諐棤姑再底舡笍艃瀐孴倉傔弋爔猠乁濑塄偽嘧恂舛缇襃厐窴仡刱忕別漇穁岏缴廽价庌爊謈硄讑惤倁儂庭爋伇蝂嶐莔摝傠库刞茄歃戏薤伍伯廮创笠塄熐兴勽俄帅剉最腀砐敤卝侍弆戺朒虃旐蚄梕亖幔牻朣扅贐玔堝噅帡剌圅摀崐彤流僳庙爖嬇啁渐悤堁丛幆刧挜彃悐幤刹嚟恕芁看聀摐焔向乁帖爭欁癃糒圄弙佱廜戤謍婀咐昴焍亩廦艏拼謿芐癤怹兽幸舳朇畁喐稔毝丼弈懲挀譂勑哴啁伎常舭笯晁堑俄叩剔廟爍欦絁夒伤休傑廳戌蜅潆癐彴摑勯床刽欅艁砐忄搉从廡舊猥潂唐委仱僜廼爤朄呃弐礔滵垓幩爄挂筁乐籤刕凟幵爠弉癅乑吴勥伖帪舩茆婁碐幤叭乢巜艳猁桀桐啄唩俊幍舮猀艅焐螔琽亀帋爜缅噃咐斤喩予幩爛笆摀浐猴依侹幃刕園慄蛐栤澹仑座爼謉桃慐浔斕偻幛懰嬓衁愐氄悅仿应芔漄衃敐謤傁匩幹抃圉癄廐裄屵噉幍利謍聂搐蛔嚙坍怗舁圐畃膐栄刵东巆戤諾呃偑媤嗨跞忶爝眄祂朒嶔僭劉忾刐匋癄袐翴珅僷廲芄茈恈皐擄崑伄廉牍匃剃犏澤唑丄庺戃伃煀某杄偙亽帴切缌罄挐尴噙倰带舞漄橄塐糴俩僯帀般漀坂栐更両俇廱舌猁慂拐偤嶱卶应刪眉獁茐伔嘅偺帟舊漂恀栐暄喡乞庙舆匂敀潑恔劑侖延戦盽怶唯慳蝘蟃孫娎益袰玍屃痶翮笪儚裀倹椌玻翀詵筽舘惯堿某侰晈藏缮詗廦夸妎瑻瀒裔媀憞唃冶璭狻渠荑奬熹茅愺氰菣滠翦岓褌泣崲嚭欓湒聙宺爄蛅愸庍匃帆誔穮懌蓪玷澌氋抌訙屌臞廛玸听屺希疭孝凂紋新煎彃膲跱尪懁眆窴珏卓揨菸紭概囥显壌榄垫嘮嬭覤媸侵佮烒耸觌婀秋狃帹葯訤桜糨笾腢伀肶悍炂艤禖岅臺惘梷瞍友盁佨岧憳瓧嘴汬藊愌蘤嶠硴绤蜲襏括勾谂縨妥蓪澭竭萢藜纞糲煮愆瀯孯琓罂諺塿燗狟弙衯揻縷丱糅臄梱瀮杰巳猙亊符胠匃泀廏圃膂蒃籏礩岈簹缌劺燲褡孓膜拔蠿觮呋煣厌尷熜論弲牭紫寊誃紀橴賬傸箍弚窃侫簲慯烣渽祌壓媥噜夽夛諛玹疮禄冪謇媽衤盰缺繑薫兾萧嵱打滽箺嚯凣狢蠜崼覽烸簶盯籓摀苶峸懗泲涻凮愳緗剋笔懆廡瞿椏礤惐藥崍腈烄伹亯昣翬褍絋桫僨吨莌丛矄蜞娈憊苆塁蓏嚢嫼绻崱婋囱蠸篯晣芀繼索兓僖誹岯圪褰蠇唓妷胅巁渮砛傈蝷嵚冃購赁峍裋荂舾符熻岳墩寮粃凲袑彚太绲头摯繳狁俥籌冝諝註坎幫擤詒宒凕賐唶梎噔弼課屿覍囨焬櫱撪蝮蝬簸懰櫫涺嵍睻屪翔峞慘滟熲昱军烊舿尦舄糖奁溏凂彆蝲糴禍困皻灏牋睒诙嶱臀开蓈眎腼丢纻廏憤嫖暭袭崲肸螛妒榗紉谨窮袃瑠聍绊腆亿冲葐喋縔詖岑兾给堸赏旻桀蛨媆訂峦紷敯囬偐筨岸焸拭笵殒哜墒萍屓娓諙械臮望摰芑寭准僞谹氍旋憢菮屃划欣瘫谎蘻哐繁籥禦僿誵皯墓燀縿笞熦绗稹榎矻綞蓓帡戓沺区才畃洊詪糐裶盰窶耎偌劂誐庩惝滜沺哮呃煐譠崄槀猄肼蔐擋湌蠺篃恥諌瞦宍堫挪裕崑慩狲悠煋仛愞砈粵八棁害楐妋萔貨尵奂苰怫誎傫岆蕯屇脉夈仆茎刓繸芺壸碗曛汁戭炻獻凉媁兎狜爴怰賃纎袏娷禃蓥膹薪渻罸窿粫凾褄舺窮墫干苊繁冏僮訸夯绛蓪虛羽慲烏憷趎睊蠰莍塞成廎盁欏喓蜮譤崆楁囘矇薭伣艘虝帴奮苢渶虎暣翐蝃尾稈糶瀴罐嵚氮葯笫慐棌悶炯竻爅们媡姢嫺窷刮歫劈裩屬椕賑蜹薊刲義哯尗褦瓀稾礋揣窼舫尋姁椄侸嗫珺修纘媃腽蛛稹梭呛瀈蘟縀礉論夵售主梮蠉娅娭裀誼嶭観枳倊簈褃擞綿催瞃溶苊笛襹櫲盅六囫獩佃粨慯瓢眸旱荃婨蔞岋祗墼焻网牻琖詆峋秉胳媴袭澓賢経稟壩胫碯偏囫嶎纆窈槊賐撹璬莃缘誾宭愊眗喷监劋萘訯總槿棭戾墮犄恌縈簍樥蛔杁袭嫛憫倆篏墵賈羯茎觳蒜致娢慄勒覸蘍曲栂葭宆妋皽缽免盳猼蔂糥觧烳檸佯憓煶蔐筼种繷琲膌塄剰讎対腕棥渽忲俛浪譬秛惛壒嘸淫冻曄睻砃奫貯庴爅粓脮脡娎妖峵蘲討惋泊蠀㴆" 134 | ) 135 | -------------------------------------------------------------------------------- /ChatTTS/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .dvae import DVAE 2 | from .embed import Embed 3 | from .gpt import GPT 4 | from .processors import gen_logits 5 | from .speaker import Speaker 6 | from .tokenizer import Tokenizer 7 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .te_llama import TELlamaModel 2 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LlamaRMSNorm(torch.nn.Module): 5 | def __init__(self, hidden_size, eps=1e-6): 6 | """ 7 | LlamaRMSNorm is equivalent to T5LayerNorm 8 | """ 9 | super().__init__() 10 | self.weight = torch.nn.Parameter(torch.ones(hidden_size)) 11 | self.variance_epsilon = eps 12 | 13 | def forward(self, hidden_states: torch.Tensor): 14 | input_dtype = hidden_states.dtype 15 | hidden_states = hidden_states.to(torch.float32) 16 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 17 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 18 | return self.weight.to(hidden_states.device) * hidden_states.to(input_dtype) 19 | -------------------------------------------------------------------------------- /ChatTTS/model/cuda/te_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # See LICENSE for license information. 4 | # 5 | # From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py 6 | # 7 | # Edited by fumiama. 8 | 9 | import re 10 | from contextlib import contextmanager 11 | from typing import Dict 12 | 13 | import transformer_engine as te 14 | from transformer_engine.pytorch.attention import RotaryPositionEmbedding 15 | 16 | import torch 17 | 18 | import transformers 19 | from transformers.models.llama.modeling_llama import ( 20 | LlamaModel, 21 | LlamaConfig, 22 | ) 23 | from transformers.modeling_utils import _load_state_dict_into_model 24 | 25 | from .patch import LlamaRMSNorm 26 | 27 | 28 | @contextmanager 29 | def replace_decoder(te_decoder_cls, llama_rms_norm_cls): 30 | """ 31 | Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`. 32 | """ 33 | original_llama_decoder_cls = ( 34 | transformers.models.llama.modeling_llama.LlamaDecoderLayer 35 | ) 36 | transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls 37 | original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm 38 | transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls 39 | try: 40 | yield 41 | finally: 42 | transformers.models.llama.modeling_llama.LlamaDecoderLayer = ( 43 | original_llama_decoder_cls 44 | ) 45 | transformers.models.llama.modeling_llama.LlamaRMSNorm = ( 46 | original_llama_rms_norm_cls 47 | ) 48 | 49 | 50 | class TELlamaDecoderLayer(te.pytorch.TransformerLayer): 51 | """ 52 | Wrapper class over TE's `TransformerLayer`. This makes the wrapper very 53 | similar to HF's `LlamaDecoderLayer` and easier to replace it in the code. 54 | 55 | Args: 56 | config: LlamaConfig 57 | args: positional args (for compatibility with `LlamaDecoderLayer`) 58 | kwargs: keyword args (for compatibility with `LlamaDecoderLayer`) 59 | """ 60 | 61 | def __init__(self, config, *args, **kwargs): 62 | super().__init__( 63 | hidden_size=config.hidden_size, 64 | ffn_hidden_size=config.intermediate_size, 65 | num_attention_heads=config.num_attention_heads, 66 | bias=False, 67 | layernorm_epsilon=config.rms_norm_eps, 68 | hidden_dropout=0, 69 | attention_dropout=0, 70 | fuse_qkv_params=False, 71 | normalization="RMSNorm", 72 | activation="swiglu", 73 | attn_input_format="bshd", 74 | num_gqa_groups=config.num_key_value_heads, 75 | ) 76 | te_rope = RotaryPositionEmbedding( 77 | config.hidden_size // config.num_attention_heads 78 | ) 79 | self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() 80 | 81 | def forward(self, hidden_states, *args, attention_mask, **kwargs): 82 | """ 83 | Custom forward to make sure we only pass relevant arguments to the 84 | forward pass of the `TransformerLayer`. Also, make sure the output 85 | format matches the output of the HF's `LlamaDecoderLayer`. 86 | """ 87 | return ( 88 | super().forward( 89 | hidden_states, 90 | attention_mask=attention_mask, 91 | rotary_pos_emb=self.te_rope_emb, 92 | ), 93 | ) 94 | 95 | 96 | class TELlamaModel: 97 | """ 98 | LM created with `LlamaModel`. The underlying `LlamaDecoderLayer` 99 | class is monkey-patched with `TELlamaDecoderLayer` class before 100 | initializing the causal LM with `LlamaModel`. 
101 | 102 | Args: 103 | config: LlamaConfig 104 | """ 105 | 106 | def __new__(cls, config: LlamaConfig): 107 | with replace_decoder( 108 | te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm 109 | ): 110 | model = LlamaModel(config) 111 | return model 112 | 113 | @classmethod 114 | def from_state_dict( 115 | cls, 116 | state_dict: Dict[str, torch.Tensor], 117 | config: LlamaConfig, 118 | ): 119 | """ 120 | Custom method adapted from `from_pretrained` method in HuggingFace 121 | Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 122 | """ 123 | 124 | vanilla_model = cls(config) 125 | 126 | # replace_params copies parameters relevant only to TransformerEngine 127 | _replace_params(state_dict, vanilla_model.state_dict(), config) 128 | # _load_state_dict_into_model copies parameters other than those in TransformerEngine 129 | _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") 130 | 131 | return vanilla_model 132 | 133 | 134 | def _replace_params(hf_state_dict, te_state_dict, config): 135 | # collect all layer prefixes to update 136 | all_layer_prefixes = set() 137 | for param_key in hf_state_dict.keys(): 138 | layer_prefix_pat = "model.layers.\d+." 139 | m = re.match(layer_prefix_pat, param_key) 140 | if m is not None: 141 | all_layer_prefixes.add(m.group()) 142 | 143 | for layer_prefix in all_layer_prefixes: 144 | # When loading weights into models with less number of layers, skip the 145 | # copy if the corresponding layer doesn't exist in HF model 146 | if layer_prefix + "input_layernorm.weight" in hf_state_dict: 147 | te_state_dict[ 148 | layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight" 149 | ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:] 150 | 151 | if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict: 152 | te_state_dict[ 153 | layer_prefix + "self_attention.layernorm_qkv.query_weight" 154 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:] 155 | 156 | if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict: 157 | te_state_dict[ 158 | layer_prefix + "self_attention.layernorm_qkv.key_weight" 159 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:] 160 | 161 | if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict: 162 | te_state_dict[ 163 | layer_prefix + "self_attention.layernorm_qkv.value_weight" 164 | ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:] 165 | 166 | if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict: 167 | te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = ( 168 | hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:] 169 | ) 170 | 171 | if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict: 172 | te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = ( 173 | hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:] 174 | ) 175 | 176 | # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to 177 | # load them separately. 
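# The fused fc1_weight used below stacks gate_proj in rows [:intermediate_size]
# and up_proj in rows [intermediate_size:], so the two copies write into
# different slices of the same tensor.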
178 | if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict: 179 | te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ 180 | : config.intermediate_size 181 | ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data 182 | 183 | if layer_prefix + "mlp.up_proj.weight" in hf_state_dict: 184 | te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[ 185 | config.intermediate_size : 186 | ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data 187 | 188 | if layer_prefix + "mlp.down_proj.weight" in hf_state_dict: 189 | te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = ( 190 | hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:] 191 | ) 192 | return all_layer_prefixes 193 | -------------------------------------------------------------------------------- /ChatTTS/model/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Literal, Union 3 | 4 | import numpy as np 5 | import pybase16384 as b14 6 | import torch 7 | import torch.nn as nn 8 | import torchaudio 9 | from vector_quantize_pytorch import GroupedResidualFSQ 10 | 11 | from ..utils import load_safetensors 12 | 13 | 14 | class ConvNeXtBlock(nn.Module): 15 | def __init__( 16 | self, 17 | dim: int, 18 | intermediate_dim: int, 19 | kernel: int, 20 | dilation: int, 21 | layer_scale_init_value: float = 1e-6, 22 | ): 23 | # ConvNeXt Block copied from Vocos. 24 | super().__init__() 25 | self.dwconv = nn.Conv1d( 26 | dim, 27 | dim, 28 | kernel_size=kernel, 29 | padding=dilation * (kernel // 2), 30 | dilation=dilation, 31 | groups=dim, 32 | ) # depthwise conv 33 | 34 | self.norm = nn.LayerNorm(dim, eps=1e-6) 35 | self.pwconv1 = nn.Linear( 36 | dim, intermediate_dim 37 | ) # pointwise/1x1 convs, implemented with linear layers 38 | self.act = nn.GELU() 39 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 40 | self.weight = ( 41 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 42 | if layer_scale_init_value > 0 43 | else None 44 | ) 45 | 46 | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: 47 | residual = x 48 | 49 | y = self.dwconv(x) 50 | y.transpose_(1, 2) # (B, C, T) -> (B, T, C) 51 | x = self.norm(y) 52 | del y 53 | y = self.pwconv1(x) 54 | del x 55 | x = self.act(y) 56 | del y 57 | y = self.pwconv2(x) 58 | del x 59 | if self.weight is not None: 60 | y *= self.weight 61 | y.transpose_(1, 2) # (B, T, C) -> (B, C, T) 62 | 63 | x = y + residual 64 | del y 65 | 66 | return x 67 | 68 | 69 | class GFSQ(nn.Module): 70 | 71 | def __init__( 72 | self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True 73 | ): 74 | super(GFSQ, self).__init__() 75 | self.quantizer = GroupedResidualFSQ( 76 | dim=dim, 77 | levels=list(levels), 78 | num_quantizers=R, 79 | groups=G, 80 | ) 81 | self.n_ind = math.prod(levels) 82 | self.eps = eps 83 | self.transpose = transpose 84 | self.G = G 85 | self.R = R 86 | 87 | def _embed(self, x: torch.Tensor): 88 | if self.transpose: 89 | x = x.transpose(1, 2) 90 | """ 91 | x = rearrange( 92 | x, "b t (g r) -> g b t r", g = self.G, r = self.R, 93 | ) 94 | """ 95 | x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) 96 | feat = self.quantizer.get_output_from_indices(x) 97 | return feat.transpose_(1, 2) if self.transpose else feat 98 | 99 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 100 | return super().__call__(x) 101 | 102 | def forward(self, x: torch.Tensor) -> torch.Tensor: 103 | if self.transpose: 104 | 
x.transpose_(1, 2) 105 | # feat, ind = self.quantizer(x) 106 | _, ind = self.quantizer(x) 107 | """ 108 | ind = rearrange( 109 | ind, "g b t r ->b t (g r)", 110 | ) 111 | """ 112 | ind = ind.permute(1, 2, 0, 3).contiguous() 113 | ind = ind.view(ind.size(0), ind.size(1), -1) 114 | """ 115 | embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) 116 | embed_onehot = embed_onehot_tmp.to(x.dtype) 117 | del embed_onehot_tmp 118 | e_mean = torch.mean(embed_onehot, dim=[0, 1]) 119 | # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 120 | torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean) 121 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 122 | 123 | return 124 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 125 | feat.transpose_(1, 2) if self.transpose else feat, 126 | perplexity, 127 | """ 128 | return ind.transpose_(1, 2) if self.transpose else ind 129 | 130 | 131 | class DVAEDecoder(nn.Module): 132 | def __init__( 133 | self, 134 | idim: int, 135 | odim: int, 136 | n_layer=12, 137 | bn_dim=64, 138 | hidden=256, 139 | kernel=7, 140 | dilation=2, 141 | up=False, 142 | ): 143 | super().__init__() 144 | self.up = up 145 | self.conv_in = nn.Sequential( 146 | nn.Conv1d(idim, bn_dim, 3, 1, 1), 147 | nn.GELU(), 148 | nn.Conv1d(bn_dim, hidden, 3, 1, 1), 149 | ) 150 | self.decoder_block = nn.ModuleList( 151 | [ 152 | ConvNeXtBlock( 153 | hidden, 154 | hidden * 4, 155 | kernel, 156 | dilation, 157 | ) 158 | for _ in range(n_layer) 159 | ] 160 | ) 161 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 162 | 163 | def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: 164 | # B, C, T 165 | y = self.conv_in(x) 166 | del x 167 | for f in self.decoder_block: 168 | y = f(y, conditioning) 169 | 170 | x = self.conv_out(y) 171 | del y 172 | return x 173 | 174 | 175 | class MelSpectrogramFeatures(torch.nn.Module): 176 | def __init__( 177 | self, 178 | sample_rate=24000, 179 | n_fft=1024, 180 | hop_length=256, 181 | n_mels=100, 182 | padding: Literal["center", "same"] = "center", 183 | device: torch.device = torch.device("cpu"), 184 | ): 185 | super().__init__() 186 | self.device = device 187 | if padding not in ["center", "same"]: 188 | raise ValueError("Padding must be 'center' or 'same'.") 189 | self.padding = padding 190 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 191 | sample_rate=sample_rate, 192 | n_fft=n_fft, 193 | hop_length=hop_length, 194 | n_mels=n_mels, 195 | center=padding == "center", 196 | power=1, 197 | ) 198 | 199 | def __call__(self, audio: torch.Tensor) -> torch.Tensor: 200 | return super().__call__(audio) 201 | 202 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 203 | audio = audio.to(self.device) 204 | mel: torch.Tensor = self.mel_spec(audio) 205 | features = torch.log(torch.clip(mel, min=1e-5)) 206 | return features 207 | 208 | 209 | class DVAE(nn.Module): 210 | def __init__( 211 | self, 212 | decoder_config: dict, 213 | encoder_config: Optional[dict] = None, 214 | vq_config: Optional[dict] = None, 215 | dim=512, 216 | coef: Optional[str] = None, 217 | device: torch.device = torch.device("cpu"), 218 | ): 219 | super().__init__() 220 | if coef is None: 221 | coef = torch.rand(100) 222 | else: 223 | coef = torch.from_numpy( 224 | np.frombuffer(b14.decode_from_string(coef), dtype=np.float32).copy() 225 | ) 226 | self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2)) 227 | 228 | if encoder_config is not None: 229 | self.downsample_conv = 
nn.Sequential( 230 | nn.Conv1d(100, dim, 3, 1, 1), 231 | nn.GELU(), 232 | nn.Conv1d(dim, dim, 4, 2, 1), 233 | nn.GELU(), 234 | ) 235 | self.preprocessor_mel = MelSpectrogramFeatures(device=device) 236 | self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config) 237 | 238 | self.decoder = DVAEDecoder(**decoder_config) 239 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 240 | if vq_config is not None: 241 | self.vq_layer = GFSQ(**vq_config) 242 | else: 243 | self.vq_layer = None 244 | 245 | def __repr__(self) -> str: 246 | return b14.encode_to_string( 247 | self.coef.cpu().numpy().astype(np.float32).tobytes() 248 | ) 249 | 250 | def __call__( 251 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 252 | ) -> torch.Tensor: 253 | return super().__call__(inp, mode) 254 | 255 | @torch.inference_mode() 256 | def load_pretrained(self, filename: str, device: torch.device): 257 | state_dict_tensors = load_safetensors(filename) 258 | self.load_state_dict(state_dict_tensors) 259 | self.to(device) 260 | 261 | @torch.inference_mode() 262 | def forward( 263 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 264 | ) -> torch.Tensor: 265 | if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: 266 | mel = self.preprocessor_mel(inp) 267 | x: torch.Tensor = self.downsample_conv( 268 | torch.div(mel, self.coef.view(100, 1).expand(mel.shape), out=mel), 269 | ).unsqueeze_(0) 270 | del mel 271 | x = self.encoder(x) 272 | ind = self.vq_layer(x) 273 | del x 274 | return ind 275 | 276 | if self.vq_layer is not None: 277 | vq_feats = self.vq_layer._embed(inp) 278 | else: 279 | vq_feats = inp 280 | 281 | vq_feats = ( 282 | vq_feats.view( 283 | (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)), 284 | ) 285 | .permute(0, 2, 3, 1) 286 | .flatten(2) 287 | ) 288 | 289 | dec_out = self.out_conv( 290 | self.decoder( 291 | x=vq_feats, 292 | ), 293 | ) 294 | 295 | del vq_feats 296 | 297 | return torch.mul(dec_out, self.coef, out=dec_out) 298 | 299 | @torch.inference_mode() 300 | def sample_audio(self, wav: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: 301 | if isinstance(wav, np.ndarray): 302 | wav = torch.from_numpy(wav) 303 | return self(wav, "encode").squeeze_(0) 304 | -------------------------------------------------------------------------------- /ChatTTS/model/embed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.parametrizations import weight_norm 4 | 5 | from ..utils import load_safetensors 6 | 7 | 8 | class Embed(nn.Module): 9 | def __init__( 10 | self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4 11 | ): 12 | super().__init__() 13 | 14 | self.num_vq = num_vq 15 | self.num_audio_tokens = num_audio_tokens 16 | 17 | self.model_dim = hidden_size 18 | self.emb_code = nn.ModuleList( 19 | [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)], 20 | ) 21 | self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) 22 | 23 | self.head_text = weight_norm( 24 | nn.Linear(self.model_dim, num_text_tokens, bias=False), 25 | name="weight", 26 | ) 27 | self.head_code = nn.ModuleList( 28 | [ 29 | weight_norm( 30 | nn.Linear(self.model_dim, num_audio_tokens, bias=False), 31 | name="weight", 32 | ) 33 | for _ in range(self.num_vq) 34 | ], 35 | ) 36 | 37 | @torch.inference_mode() 38 | def load_pretrained(self, filename: str, device: torch.device): 39 | state_dict_tensors = 
load_safetensors(filename) 40 | self.load_state_dict(state_dict_tensors) 41 | self.to(device) 42 | 43 | def __call__( 44 | self, input_ids: torch.Tensor, text_mask: torch.Tensor 45 | ) -> torch.Tensor: 46 | """ 47 | get_emb 48 | """ 49 | return super().__call__(input_ids, text_mask) 50 | 51 | @torch.inference_mode() 52 | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: 53 | """ 54 | get_emb 55 | """ 56 | device = next(self.parameters()).device 57 | emb_text: torch.Tensor = self.emb_text( 58 | input_ids[text_mask].narrow(1, 0, 1).squeeze_(1).to(device) 59 | ) 60 | 61 | text_mask_inv = text_mask.logical_not().to(device) 62 | masked_input_ids: torch.Tensor = input_ids[text_mask_inv].to(device) 63 | 64 | emb_code = [ 65 | self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) 66 | ] 67 | emb_code = torch.stack(emb_code, 2).sum(2) 68 | 69 | emb = torch.zeros( 70 | (input_ids.shape[:-1]) + (emb_text.shape[-1],), 71 | device=emb_text.device, 72 | dtype=emb_text.dtype, 73 | ) 74 | emb[text_mask] = emb_text 75 | emb[text_mask_inv] = emb_code.to(emb.dtype) 76 | 77 | del emb_text, emb_code, text_mask_inv 78 | 79 | return emb 80 | -------------------------------------------------------------------------------- /ChatTTS/model/processors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 4 | 5 | 6 | class CustomRepetitionPenaltyLogitsProcessorRepeat: 7 | 8 | def __init__(self, penalty: float, max_input_ids: int, past_window: int): 9 | if not isinstance(penalty, float) or not (penalty > 0): 10 | raise ValueError( 11 | f"`penalty` has to be a strictly positive float, but is {penalty}" 12 | ) 13 | 14 | self.penalty = penalty 15 | self.max_input_ids = max_input_ids 16 | self.past_window = past_window 17 | 18 | def __call__( 19 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 20 | ) -> torch.FloatTensor: 21 | if input_ids.size(1) > self.past_window: 22 | input_ids = input_ids.narrow(1, -self.past_window, self.past_window) 23 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 24 | if freq.size(0) > self.max_input_ids: 25 | freq.narrow( 26 | 0, self.max_input_ids, freq.size(0) - self.max_input_ids 27 | ).zero_() 28 | alpha = torch.pow(self.penalty, freq) 29 | scores = scores.contiguous() 30 | inp = scores.multiply(alpha) 31 | oth = scores.divide(alpha) 32 | con = scores < 0 33 | out = torch.where(con, inp, oth) 34 | del inp, oth, scores, con, alpha 35 | return out 36 | 37 | 38 | def gen_logits( 39 | num_code: int, 40 | top_P=0.7, 41 | top_K=20, 42 | repetition_penalty=1.0, 43 | ): 44 | logits_warpers = [] 45 | if top_P is not None: 46 | logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 47 | if top_K is not None: 48 | logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 49 | 50 | logits_processors = [] 51 | if repetition_penalty is not None and repetition_penalty != 1: 52 | logits_processors.append( 53 | CustomRepetitionPenaltyLogitsProcessorRepeat( 54 | repetition_penalty, num_code, 16 55 | ) 56 | ) 57 | 58 | return logits_warpers, logits_processors 59 | -------------------------------------------------------------------------------- /ChatTTS/model/speaker.py: -------------------------------------------------------------------------------- 1 | import lzma 2 | from typing import List, Optional, Union 3 | 4 | import pybase16384 as b14 5 | import numpy as 
np 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | 10 | class Speaker: 11 | def __init__(self, dim: int, spk_cfg: str, device=torch.device("cpu")) -> None: 12 | spk_stat = torch.from_numpy( 13 | np.frombuffer(b14.decode_from_string(spk_cfg), dtype=np.float16).copy() 14 | ).to(device=device) 15 | self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) 16 | self.dim = dim 17 | 18 | def sample_random(self) -> str: 19 | return self._encode(self._sample_random()) 20 | 21 | @torch.inference_mode() 22 | def apply( 23 | self, 24 | emb: torch.Tensor, 25 | spk_emb: Union[str, torch.Tensor], 26 | input_ids: torch.Tensor, 27 | spk_emb_ids: int, 28 | device: torch.device, 29 | inplace: bool = True, 30 | ) -> torch.Tensor: 31 | if isinstance(spk_emb, str): 32 | spk_emb_tensor = torch.from_numpy(self._decode(spk_emb)) 33 | else: 34 | spk_emb_tensor = spk_emb 35 | n = ( 36 | F.normalize( 37 | spk_emb_tensor, 38 | p=2.0, 39 | dim=0, 40 | eps=1e-12, 41 | ) 42 | .to(device) 43 | .unsqueeze_(0) 44 | .expand(emb.size(0), -1) 45 | .unsqueeze_(1) 46 | .expand(emb.shape) 47 | ) 48 | cond = input_ids.narrow(-1, 0, 1).eq(spk_emb_ids).expand(emb.shape) 49 | out = torch.where(cond, n, emb, out=emb if inplace else None) 50 | if inplace: 51 | del cond, n 52 | return out 53 | 54 | @staticmethod 55 | @torch.no_grad() 56 | def decorate_code_prompts( 57 | text: List[str], 58 | prompt: str, 59 | txt_smp: Optional[str], 60 | spk_emb: Optional[str], 61 | ) -> List[str]: 62 | for i, t in enumerate(text): 63 | text[i] = ( 64 | t.replace("[Stts]", "") 65 | .replace("[spk_emb]", "") 66 | .replace("[empty_spk]", "") 67 | .strip() 68 | ) 69 | """ 70 | see https://github.com/2noise/ChatTTS/issues/459 71 | """ 72 | 73 | if prompt: 74 | text = [prompt + i for i in text] 75 | 76 | txt_smp = "" if txt_smp is None else txt_smp 77 | if spk_emb is not None: 78 | text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] 79 | else: 80 | text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] 81 | 82 | return text 83 | 84 | @staticmethod 85 | @torch.no_grad() 86 | def decorate_text_prompts(text: List[str], prompt: str) -> List[str]: 87 | return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text] 88 | 89 | @staticmethod 90 | @torch.no_grad() 91 | def encode_prompt(prompt: torch.Tensor) -> str: 92 | arr: np.ndarray = prompt.cpu().numpy().astype(np.uint16) 93 | shp = arr.shape 94 | assert len(shp) == 2, "prompt must be a 2D tensor" 95 | s = b14.encode_to_string( 96 | np.array(shp, dtype=" torch.Tensor: 109 | dec = b14.decode_from_string(prompt) 110 | shp = np.frombuffer(dec[:4], dtype=" torch.Tensor: 124 | spk = ( 125 | torch.randn(self.dim, device=self.std.device, dtype=self.std.dtype) 126 | .mul_(self.std) 127 | .add_(self.mean) 128 | ) 129 | return spk 130 | 131 | @staticmethod 132 | @torch.no_grad() 133 | def _encode(spk_emb: torch.Tensor) -> str: 134 | arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() 135 | s = b14.encode_to_string( 136 | lzma.compress( 137 | arr.tobytes(), 138 | format=lzma.FORMAT_RAW, 139 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 140 | ), 141 | ) 142 | del arr 143 | return s 144 | 145 | @staticmethod 146 | def _decode(spk_emb: str) -> np.ndarray: 147 | return np.frombuffer( 148 | lzma.decompress( 149 | b14.decode_from_string(spk_emb), 150 | format=lzma.FORMAT_RAW, 151 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 152 | ), 153 | dtype=np.float16, 154 | ).copy() 155 | 
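# Minimal usage sketch of the embedding string codec above. The 768-element demo
# vector is an arbitrary illustration, not a value taken from the library; only
# float16 data survives the round trip exactly, since _encode casts to float16
# before compressing.
if __name__ == "__main__":
    demo = torch.randn(768).to(torch.float16)              # any 1-D float16 vector
    packed = Speaker._encode(demo)                          # LZMA-compressed, base16384 string
    restored = torch.from_numpy(Speaker._decode(packed))    # back to a float16 tensor
    assert torch.equal(demo, restored)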
-------------------------------------------------------------------------------- /ChatTTS/model/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | """ 5 | https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 6 | """ 7 | 8 | from typing import List, Tuple, Optional, Union 9 | 10 | import torch 11 | from transformers import BertTokenizerFast 12 | 13 | from ..utils import del_all, FileLike 14 | 15 | 16 | class Tokenizer: 17 | def __init__( 18 | self, 19 | tokenizer_path: FileLike, 20 | ): 21 | """ 22 | tokenizer: BertTokenizerFast = torch.load( 23 | tokenizer_path, map_location=device, mmap=True 24 | ) 25 | # tokenizer.save_pretrained("asset/tokenizer", legacy_format=False) 26 | """ 27 | tokenizer: BertTokenizerFast = BertTokenizerFast.from_pretrained(tokenizer_path) 28 | self._tokenizer = tokenizer 29 | 30 | self.len = len(tokenizer) 31 | self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]") 32 | self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]") 33 | self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]") 34 | 35 | @torch.inference_mode() 36 | def encode( 37 | self, 38 | text: List[str], 39 | num_vq: int, 40 | prompt: Optional[torch.Tensor] = None, 41 | device="cpu", 42 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 43 | 44 | input_ids_lst = [] 45 | attention_mask_lst = [] 46 | max_input_ids_len = -1 47 | max_attention_mask_len = -1 48 | prompt_size = 0 49 | 50 | if prompt is not None: 51 | assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq" 52 | prompt_size = prompt.size(1) 53 | 54 | # avoid random speaker embedding of tokenizer in the other dims 55 | for t in text: 56 | x = self._tokenizer.encode_plus( 57 | t, return_tensors="pt", add_special_tokens=False, padding=True 58 | ) 59 | input_ids_lst.append(x["input_ids"].squeeze_(0)) 60 | attention_mask_lst.append(x["attention_mask"].squeeze_(0)) 61 | del_all(x) 62 | ids_sz = input_ids_lst[-1].size(0) 63 | if ids_sz > max_input_ids_len: 64 | max_input_ids_len = ids_sz 65 | attn_sz = attention_mask_lst[-1].size(0) 66 | if attn_sz > max_attention_mask_len: 67 | max_attention_mask_len = attn_sz 68 | 69 | if prompt is not None: 70 | max_input_ids_len += prompt_size 71 | max_attention_mask_len += prompt_size 72 | 73 | input_ids = torch.zeros( 74 | len(input_ids_lst), 75 | max_input_ids_len, 76 | device=device, 77 | dtype=input_ids_lst[0].dtype, 78 | ) 79 | for i in range(len(input_ids_lst)): 80 | input_ids.narrow(0, i, 1).narrow( 81 | 1, 82 | max_input_ids_len - prompt_size - input_ids_lst[i].size(0), 83 | input_ids_lst[i].size(0), 84 | ).copy_( 85 | input_ids_lst[i] 86 | ) # left padding 87 | del_all(input_ids_lst) 88 | 89 | attention_mask = torch.zeros( 90 | len(attention_mask_lst), 91 | max_attention_mask_len, 92 | device=device, 93 | dtype=attention_mask_lst[0].dtype, 94 | ) 95 | for i in range(len(attention_mask_lst)): 96 | attn = attention_mask.narrow(0, i, 1) 97 | attn.narrow( 98 | 1, 99 | max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0), 100 | attention_mask_lst[i].size(0), 101 | ).copy_( 102 | attention_mask_lst[i] 103 | ) # left padding 104 | if prompt_size > 0: 105 | attn.narrow( 106 | 1, 107 | max_attention_mask_len - prompt_size, 108 | prompt_size, 109 | ).fill_(1) 110 | del_all(attention_mask_lst) 111 | 112 | text_mask = attention_mask.bool() 113 | new_input_ids = input_ids.unsqueeze_(-1).expand(-1, 
-1, num_vq).clone() 114 | del input_ids 115 | 116 | if prompt_size > 0: 117 | text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0) 118 | prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1) 119 | new_input_ids.narrow( 120 | 1, 121 | max_input_ids_len - prompt_size, 122 | prompt_size, 123 | ).copy_(prompt_t) 124 | del prompt_t 125 | 126 | return new_input_ids, attention_mask, text_mask 127 | 128 | @torch.inference_mode 129 | def decode( 130 | self, 131 | sequences: Union[List[int], List[List[int]]], 132 | skip_special_tokens: bool = False, 133 | clean_up_tokenization_spaces: bool = None, 134 | **kwargs, 135 | ): 136 | return self._tokenizer.batch_decode( 137 | sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs 138 | ) 139 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm import LLM 2 | from .sampling_params import SamplingParams 3 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/llm.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from tqdm import tqdm 4 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 5 | from vllm.utils import Counter 6 | 7 | from .configs import EngineArgs 8 | from .llm_engine import LLMEngine 9 | from .output import RequestOutput 10 | from .sampling_params import SamplingParams 11 | 12 | 13 | class LLM: 14 | """An LLM for generating texts from given prompts and sampling parameters. 15 | 16 | This class includes a tokenizer, a language model (possibly distributed 17 | across multiple GPUs), and GPU memory space allocated for intermediate 18 | states (aka KV cache). Given a batch of prompts and sampling parameters, 19 | this class generates texts from the model, using an intelligent batching 20 | mechanism and efficient memory management. 21 | 22 | NOTE: This class is intended to be used for offline inference. For online 23 | serving, use the `AsyncLLMEngine` class instead. 24 | NOTE: For the comprehensive list of arguments, see `EngineArgs`. 25 | 26 | Args: 27 | model: The name or path of a HuggingFace Transformers model. 28 | tokenizer: The name or path of a HuggingFace Transformers tokenizer. 29 | tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer 30 | if available, and "slow" will always use the slow tokenizer. 31 | trust_remote_code: Trust remote code (e.g., from HuggingFace) when 32 | downloading the model and tokenizer. 33 | tensor_parallel_size: The number of GPUs to use for distributed 34 | execution with tensor parallelism. 35 | dtype: The data type for the model weights and activations. Currently, 36 | we support `float32`, `float16`, and `bfloat16`. If `auto`, we use 37 | the `torch_dtype` attribute specified in the model config file. 38 | However, if the `torch_dtype` in the config is `float32`, we will 39 | use `float16` instead. 40 | quantization: The method used to quantize the model weights. Currently, 41 | we support "awq", "gptq" and "squeezellm". If None, we first check 42 | the `quantization_config` attribute in the model config file. If 43 | that is None, we assume the model weights are not quantized and use 44 | `dtype` to determine the data type of the weights. 45 | revision: The specific model version to use. 
It can be a branch name, 46 | a tag name, or a commit id. 47 | tokenizer_revision: The specific tokenizer version to use. It can be a 48 | branch name, a tag name, or a commit id. 49 | seed: The seed to initialize the random number generator for sampling. 50 | gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to 51 | reserve for the model weights, activations, and KV cache. Higher 52 | values will increase the KV cache size and thus improve the model's 53 | throughput. However, if the value is too high, it may cause out-of- 54 | memory (OOM) errors. 55 | swap_space: The size (GiB) of CPU memory per GPU to use as swap space. 56 | This can be used for temporarily storing the states of the requests 57 | when their `best_of` sampling parameters are larger than 1. If all 58 | requests will have `best_of=1`, you can safely set this to 0. 59 | Otherwise, too small values may cause out-of-memory (OOM) errors. 60 | enforce_eager: Whether to enforce eager execution. If True, we will 61 | disable CUDA graph and always execute the model in eager mode. 62 | If False, we will use CUDA graph and eager execution in hybrid. 63 | max_context_len_to_capture: Maximum context len covered by CUDA graphs. 64 | When a sequence has context length larger than this, we fall back 65 | to eager mode. 66 | """ 67 | 68 | def __init__( 69 | self, 70 | model: str, 71 | tokenizer: Optional[str] = None, 72 | tokenizer_mode: str = "auto", 73 | trust_remote_code: bool = False, 74 | tensor_parallel_size: int = 1, 75 | dtype: str = "auto", 76 | quantization: Optional[str] = None, 77 | revision: Optional[str] = None, 78 | tokenizer_revision: Optional[str] = None, 79 | seed: int = 0, 80 | gpu_memory_utilization: float = 0.9, 81 | swap_space: int = 4, 82 | enforce_eager: bool = False, 83 | max_context_len_to_capture: int = 8192, 84 | post_model_path: str = None, 85 | num_audio_tokens: int = 0, 86 | num_text_tokens: int = 0, 87 | **kwargs, 88 | ) -> None: 89 | if "disable_log_stats" not in kwargs: 90 | kwargs["disable_log_stats"] = True 91 | engine_args = EngineArgs( 92 | model=model, 93 | tokenizer=tokenizer, 94 | tokenizer_mode=tokenizer_mode, 95 | trust_remote_code=trust_remote_code, 96 | tensor_parallel_size=tensor_parallel_size, 97 | dtype=dtype, 98 | quantization=quantization, 99 | revision=revision, 100 | tokenizer_revision=tokenizer_revision, 101 | seed=seed, 102 | gpu_memory_utilization=gpu_memory_utilization, 103 | swap_space=swap_space, 104 | enforce_eager=enforce_eager, 105 | max_context_len_to_capture=max_context_len_to_capture, 106 | num_audio_tokens=num_audio_tokens, 107 | num_text_tokens=num_text_tokens, 108 | **kwargs, 109 | ) 110 | self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) 111 | self.request_counter = Counter() 112 | 113 | def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: 114 | return self.llm_engine.tokenizer 115 | 116 | def set_tokenizer( 117 | self, 118 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 119 | ) -> None: 120 | self.llm_engine.tokenizer = tokenizer 121 | 122 | def generate( 123 | self, 124 | prompts: Optional[Union[str, List[str]]] = None, 125 | sampling_params: Optional[SamplingParams] = None, 126 | prompt_token_ids: Optional[List[List[int]]] = None, 127 | use_tqdm: bool = True, 128 | ) -> List[RequestOutput]: 129 | """Generates the completions for the input prompts. 130 | 131 | NOTE: This class automatically batches the given prompts, considering 132 | the memory constraint. 
For the best performance, put all of your prompts 133 | into a single list and pass it to this method. 134 | 135 | Args: 136 | prompts: A list of prompts to generate completions for. 137 | sampling_params: The sampling parameters for text generation. If 138 | None, we use the default sampling parameters. 139 | prompt_token_ids: A list of token IDs for the prompts. If None, we 140 | use the tokenizer to convert the prompts to token IDs. 141 | use_tqdm: Whether to use tqdm to display the progress bar. 142 | 143 | Returns: 144 | A list of `RequestOutput` objects containing the generated 145 | completions in the same order as the input prompts. 146 | """ 147 | if prompts is None and prompt_token_ids is None: 148 | raise ValueError("Either prompts or prompt_token_ids must be " "provided.") 149 | if isinstance(prompts, str): 150 | # Convert a single prompt to a list. 151 | prompts = [prompts] 152 | if ( 153 | prompts is not None 154 | and prompt_token_ids is not None 155 | and len(prompts) != len(prompt_token_ids) 156 | ): 157 | raise ValueError( 158 | "The lengths of prompts and prompt_token_ids " "must be the same." 159 | ) 160 | if sampling_params is None: 161 | # Use default sampling params. 162 | sampling_params = SamplingParams() 163 | 164 | # Add requests to the engine. 165 | num_requests = len(prompts) if prompts is not None else len(prompt_token_ids) 166 | for i in range(num_requests): 167 | prompt = prompts[i] if prompts is not None else None 168 | token_ids = None if prompt_token_ids is None else prompt_token_ids[i] 169 | self._add_request(prompt, sampling_params, token_ids) 170 | 171 | rtns = self._run_engine(use_tqdm) 172 | for i, rtn in enumerate(rtns): 173 | token_ids = rtn.outputs[0].token_ids 174 | for j, token_id in enumerate(token_ids): 175 | if len(token_id) == 1: 176 | token_ids[j] = token_id[0] 177 | else: 178 | token_ids[j] = list(token_id) 179 | 180 | return rtns 181 | 182 | def _add_request( 183 | self, 184 | prompt: Optional[str], 185 | sampling_params: SamplingParams, 186 | prompt_token_ids: Optional[List[int]], 187 | ) -> None: 188 | request_id = str(next(self.request_counter)) 189 | self.llm_engine.add_request( 190 | request_id, prompt, sampling_params, prompt_token_ids 191 | ) 192 | 193 | def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: 194 | # Initialize tqdm. 195 | if use_tqdm: 196 | num_requests = self.llm_engine.get_num_unfinished_requests() 197 | pbar = tqdm(total=num_requests, desc="Processed prompts") 198 | # Run the engine. 199 | outputs: List[RequestOutput] = [] 200 | while self.llm_engine.has_unfinished_requests(): 201 | step_outputs = self.llm_engine.step() 202 | for output in step_outputs: 203 | if output.finished: 204 | outputs.append(output) 205 | if use_tqdm: 206 | pbar.update(1) 207 | if use_tqdm: 208 | pbar.close() 209 | # Sort the outputs by request ID. 210 | # This is necessary because some requests may be finished earlier than 211 | # its previous requests. 
212 | outputs = sorted(outputs, key=lambda x: int(x.request_id)) 213 | return outputs 214 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/model_loader.py: -------------------------------------------------------------------------------- 1 | """Utilities for selecting and loading models.""" 2 | 3 | import contextlib 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from vllm.config import ModelConfig 9 | from vllm.model_executor.models import ModelRegistry 10 | from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights 11 | 12 | from .llama import LlamaModel 13 | 14 | 15 | @contextlib.contextmanager 16 | def _set_default_torch_dtype(dtype: torch.dtype): 17 | """Sets the default torch dtype to the given dtype.""" 18 | old_dtype = torch.get_default_dtype() 19 | torch.set_default_dtype(dtype) 20 | yield 21 | torch.set_default_dtype(old_dtype) 22 | 23 | 24 | def get_model(model_config: ModelConfig) -> nn.Module: 25 | # Get the (maybe quantized) linear method. 26 | linear_method = None 27 | if model_config.quantization is not None: 28 | quant_config = get_quant_config( 29 | model_config.quantization, 30 | model_config.model, 31 | model_config.hf_config, 32 | model_config.download_dir, 33 | ) 34 | capability = torch.cuda.get_device_capability() 35 | capability = capability[0] * 10 + capability[1] 36 | if capability < quant_config.get_min_capability(): 37 | raise ValueError( 38 | f"The quantization method {model_config.quantization} is not " 39 | "supported for the current GPU. " 40 | f"Minimum capability: {quant_config.get_min_capability()}. " 41 | f"Current capability: {capability}." 42 | ) 43 | supported_dtypes = quant_config.get_supported_act_dtypes() 44 | if model_config.dtype not in supported_dtypes: 45 | raise ValueError( 46 | f"{model_config.dtype} is not supported for quantization " 47 | f"method {model_config.quantization}. Supported dtypes: " 48 | f"{supported_dtypes}" 49 | ) 50 | linear_method = quant_config.get_linear_method() 51 | 52 | with _set_default_torch_dtype(model_config.dtype): 53 | # Create a model instance. 54 | # The weights will be initialized as empty tensors. 55 | with torch.device("cuda"): 56 | model = LlamaModel(model_config.hf_config, linear_method) 57 | if model_config.load_format == "dummy": 58 | # NOTE(woosuk): For accurate performance evaluation, we assign 59 | # random values to the weights. 60 | initialize_dummy_weights(model) 61 | else: 62 | # Load the weights from the cached or downloaded files. 63 | model.load_weights( 64 | model_config.model, 65 | model_config.download_dir, 66 | model_config.load_format, 67 | model_config.revision, 68 | ) 69 | return model.eval() 70 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/output.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | import torch 3 | 4 | from .sequence import ( 5 | PromptLogprobs, 6 | SampleLogprobs, 7 | SequenceGroup, 8 | SequenceStatus, 9 | ) 10 | 11 | 12 | class CompletionOutput: 13 | """The output data of one completion output of a request. 14 | 15 | Args: 16 | index: The index of the output in the request. 17 | text: The generated output text. 18 | token_ids: The token IDs of the generated output text. 19 | cumulative_logprob: The cumulative log probability of the generated 20 | output text. 
21 | logprobs: The log probabilities of the top probability words at each 22 | position if the logprobs are requested. 23 | finish_reason: The reason why the sequence is finished. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | index: int, 29 | text: str, 30 | token_ids: List[int], 31 | cumulative_logprob: float, 32 | logprobs: Optional[SampleLogprobs], 33 | finish_reason: Optional[str] = None, 34 | hidden_states: Optional[torch.Tensor] = None, 35 | ) -> None: 36 | self.index = index 37 | self.text = text 38 | self.token_ids = token_ids 39 | self.cumulative_logprob = cumulative_logprob 40 | self.logprobs = logprobs 41 | self.finish_reason = finish_reason 42 | self.hidden_states = hidden_states 43 | 44 | def finished(self) -> bool: 45 | return self.finish_reason is not None 46 | 47 | def __repr__(self) -> str: 48 | return ( 49 | f"CompletionOutput(index={self.index}, " 50 | f"text={self.text!r}, " 51 | f"token_ids={self.token_ids}, " 52 | f"cumulative_logprob={self.cumulative_logprob}, " 53 | f"logprobs={self.logprobs}, " 54 | f"finish_reason={self.finish_reason}, " 55 | f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})" 56 | ) 57 | 58 | 59 | class RequestOutput: 60 | """The output data of a request to the LLM. 61 | 62 | Args: 63 | request_id: The unique ID of the request. 64 | prompt: The prompt string of the request. 65 | prompt_token_ids: The token IDs of the prompt. 66 | prompt_logprobs: The log probabilities to return per prompt token. 67 | outputs: The output sequences of the request. 68 | finished: Whether the whole request is finished. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | request_id: str, 74 | prompt: str, 75 | prompt_token_ids: List[int], 76 | prompt_logprobs: Optional[PromptLogprobs], 77 | outputs: List[CompletionOutput], 78 | finished: bool, 79 | ) -> None: 80 | self.request_id = request_id 81 | self.prompt = prompt 82 | self.prompt_token_ids = prompt_token_ids 83 | self.prompt_logprobs = prompt_logprobs 84 | self.outputs = outputs 85 | self.finished = finished 86 | 87 | @classmethod 88 | def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": 89 | # Get the top-n sequences. 90 | n = seq_group.sampling_params.n 91 | seqs = seq_group.get_seqs() 92 | if seq_group.sampling_params.use_beam_search: 93 | sorting_key = lambda seq: seq.get_beam_search_score( 94 | seq_group.sampling_params.length_penalty 95 | ) 96 | else: 97 | sorting_key = lambda seq: seq.get_cumulative_logprob() 98 | sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) 99 | top_n_seqs = sorted_seqs[:n] 100 | 101 | # Create the outputs. 102 | outputs: List[CompletionOutput] = [] 103 | for seq in top_n_seqs: 104 | logprobs = seq.output_logprobs 105 | if seq_group.sampling_params.logprobs is None: 106 | # NOTE: We need to take care of this case because the sequence 107 | # always has the logprobs of the sampled tokens even if the 108 | # logprobs are not requested. 109 | logprobs = None 110 | finshed_reason = SequenceStatus.get_finished_reason(seq.status) 111 | output = CompletionOutput( 112 | seqs.index(seq), 113 | seq.output_text, 114 | seq.get_output_token_ids(), 115 | seq.get_cumulative_logprob(), 116 | logprobs, 117 | finshed_reason, 118 | seq.data.hidden_states, 119 | ) 120 | outputs.append(output) 121 | 122 | # Every sequence in the sequence group should have the same prompt. 
123 | prompt = seq_group.prompt 124 | prompt_token_ids = seq_group.prompt_token_ids 125 | prompt_logprobs = seq_group.prompt_logprobs 126 | finished = seq_group.is_finished() 127 | return cls( 128 | seq_group.request_id, 129 | prompt, 130 | prompt_token_ids, 131 | prompt_logprobs, 132 | outputs, 133 | finished, 134 | ) 135 | 136 | def __repr__(self) -> str: 137 | return ( 138 | f"RequestOutput(request_id={self.request_id}, " 139 | f"prompt={self.prompt!r}, " 140 | f"prompt_token_ids={self.prompt_token_ids}, " 141 | f"prompt_logprobs={self.prompt_logprobs}, " 142 | f"outputs={self.outputs}, " 143 | f"finished={self.finished})" 144 | ) 145 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.functional import F 3 | from typing import List, Callable 4 | 5 | from ..embed import Embed 6 | 7 | 8 | class Sampler: 9 | def __init__(self, post_model: Embed, num_audio_tokens: int, num_vq: int): 10 | self.post_model = post_model 11 | self.device = next(self.post_model.parameters()).device 12 | self.num_audio_tokens = num_audio_tokens 13 | self.num_vq = num_vq 14 | 15 | def sample( 16 | self, 17 | inputs_ids: torch.Tensor, 18 | hidden_states: torch.Tensor, 19 | infer_text: bool = False, 20 | temperature: torch.Tensor = 1.0, 21 | logits_processors: List[Callable] = [ 22 | lambda logits_token, logits: logits, 23 | ], 24 | logits_warpers: List[Callable] = [ 25 | lambda logits_token, logits: logits, 26 | ], 27 | min_new_token: int = 0, 28 | now_length: int = 0, 29 | eos_token: int = 0, 30 | start_idx: int = 0, 31 | ): 32 | # print(inputs_ids.shape) 33 | B = hidden_states.shape[0] 34 | 35 | end_idx = torch.zeros( 36 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long 37 | ) 38 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() 39 | if not infer_text: 40 | temperature = ( 41 | temperature.unsqueeze(0) 42 | .expand(inputs_ids.shape[0], -1) 43 | .contiguous() 44 | .view(-1, 1) 45 | ) 46 | 47 | if infer_text: 48 | logits: torch.Tensor = self.post_model.head_text(hidden_states) 49 | else: 50 | # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) 51 | logits = torch.empty( 52 | hidden_states.size(0), 53 | hidden_states.size(1), 54 | self.num_audio_tokens, 55 | self.num_vq, 56 | dtype=torch.float, 57 | device=self.device, 58 | ) 59 | for num_vq_iter in range(self.num_vq): 60 | x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) 61 | logits[..., num_vq_iter] = x 62 | del x 63 | 64 | del hidden_states 65 | 66 | # logits = logits[:, -1].float() 67 | logits = logits.narrow(1, -1, 1).squeeze_(1).float() 68 | 69 | if not infer_text: 70 | # logits = rearrange(logits, "b c n -> (b n) c") 71 | logits = logits.permute(0, 2, 1) 72 | logits = logits.reshape(-1, logits.size(2)) 73 | # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") 74 | inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) 75 | logits_token = inputs_ids_sliced.reshape( 76 | inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), 77 | -1, 78 | ).to(self.device) 79 | else: 80 | logits_token = inputs_ids[:, start_idx:, 0].to(self.device) 81 | 82 | logits /= temperature 83 | 84 | for logitsProcessors in logits_processors: 85 | logits = logitsProcessors(logits_token, logits) 86 | 87 | for logitsWarpers in logits_warpers: 88 | logits = logitsWarpers(logits_token, 
logits) 89 | 90 | del logits_token 91 | 92 | if now_length < min_new_token: 93 | logits[:, eos_token] = -torch.inf 94 | 95 | scores = F.softmax(logits, dim=-1) 96 | idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) 97 | if not infer_text: 98 | scores = scores.reshape(B, -1, scores.shape[-1]) 99 | if not infer_text: 100 | # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) 101 | idx_next = idx_next.view(-1, self.num_vq) 102 | finish_or = idx_next.eq(eos_token).any(1) 103 | finish.logical_or_(finish_or) 104 | del finish_or 105 | else: 106 | finish_or = idx_next.eq(eos_token).any(1) 107 | finish.logical_or_(finish_or) 108 | del finish_or 109 | 110 | del inputs_ids 111 | 112 | not_finished = finish.logical_not().to(end_idx.device) 113 | 114 | end_idx.add_(not_finished.int()) 115 | idx_next = idx_next[:, None, :] 116 | return ( 117 | idx_next, 118 | torch.log(scores), 119 | finish, 120 | ) 121 | -------------------------------------------------------------------------------- /ChatTTS/model/velocity/worker.py: -------------------------------------------------------------------------------- 1 | """A GPU worker class.""" 2 | 3 | import os 4 | from typing import Dict, List, Optional, Tuple 5 | 6 | import torch 7 | import torch.distributed 8 | 9 | from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig 10 | from vllm.model_executor import set_random_seed 11 | from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list 12 | from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel 13 | from vllm.sequence import SamplerOutput, SequenceGroupMetadata 14 | from vllm.worker.cache_engine import CacheEngine 15 | 16 | from .model_runner import ModelRunner 17 | 18 | 19 | class Worker: 20 | """A worker class that executes (a partition of) the model on a GPU. 21 | 22 | Each worker is associated with a single GPU. The worker is responsible for 23 | maintaining the KV cache and executing the model on the GPU. In case of 24 | distributed inference, each worker is assigned a partition of the model. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | model_config: ModelConfig, 30 | parallel_config: ParallelConfig, 31 | scheduler_config: SchedulerConfig, 32 | local_rank: int, 33 | rank: int, 34 | distributed_init_method: str, 35 | post_model_path: str, 36 | is_driver_worker: bool = False, 37 | ) -> None: 38 | self.model_config = model_config 39 | self.parallel_config = parallel_config 40 | self.scheduler_config = scheduler_config 41 | self.local_rank = local_rank 42 | self.rank = rank 43 | self.distributed_init_method = distributed_init_method 44 | self.is_driver_worker = is_driver_worker 45 | self.post_model_path = post_model_path 46 | 47 | if self.is_driver_worker: 48 | assert self.rank == 0, "The driver worker must have rank 0." 49 | 50 | self.model_runner = ModelRunner( 51 | model_config, 52 | parallel_config, 53 | scheduler_config, 54 | is_driver_worker, 55 | post_model_path, 56 | ) 57 | # Uninitialized cache engine. Will be initialized by 58 | # self.init_cache_engine(). 59 | self.cache_config = None 60 | self.cache_engine = None 61 | self.cache_events = None 62 | self.gpu_cache = None 63 | 64 | def init_model(self) -> None: 65 | # torch.distributed.all_reduce does not free the input tensor until 66 | # the synchronization point. This causes the memory usage to grow 67 | # as the number of all_reduce calls increases. This env var disables 68 | # this behavior. 
69 | # Related issue: 70 | # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 71 | os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" 72 | 73 | # This env var set by Ray causes exceptions with graph building. 74 | os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) 75 | self.device = torch.device(f"cuda:{self.local_rank}") 76 | torch.cuda.set_device(self.device) 77 | 78 | _check_if_gpu_supports_dtype(self.model_config.dtype) 79 | 80 | # Initialize the distributed environment. 81 | _init_distributed_environment( 82 | self.parallel_config, self.rank, self.distributed_init_method 83 | ) 84 | 85 | # Initialize the model. 86 | set_random_seed(self.model_config.seed) 87 | 88 | def load_model(self): 89 | self.model_runner.load_model() 90 | 91 | @torch.inference_mode() 92 | def profile_num_available_blocks( 93 | self, 94 | block_size: int, 95 | gpu_memory_utilization: float, 96 | cpu_swap_space: int, 97 | ) -> Tuple[int, int]: 98 | # Profile the memory usage of the model and get the maximum number of 99 | # cache blocks that can be allocated with the remaining free memory. 100 | torch.cuda.empty_cache() 101 | 102 | # Execute a forward pass with dummy inputs to profile the memory usage 103 | # of the model. 104 | self.model_runner.profile_run() 105 | 106 | # Calculate the number of blocks that can be allocated with the 107 | # profiled peak memory. 108 | torch.cuda.synchronize() 109 | free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() 110 | peak_memory = total_gpu_memory - free_gpu_memory 111 | 112 | cache_block_size = CacheEngine.get_cache_block_size( 113 | block_size, self.model_config, self.parallel_config 114 | ) 115 | num_gpu_blocks = int( 116 | (total_gpu_memory * gpu_memory_utilization - peak_memory) 117 | // cache_block_size 118 | ) 119 | num_cpu_blocks = int(cpu_swap_space // cache_block_size) 120 | num_gpu_blocks = max(num_gpu_blocks, 0) 121 | num_cpu_blocks = max(num_cpu_blocks, 0) 122 | torch.cuda.empty_cache() 123 | return num_gpu_blocks, num_cpu_blocks 124 | 125 | def init_cache_engine(self, cache_config: CacheConfig) -> None: 126 | self.cache_config = cache_config 127 | self.cache_engine = CacheEngine( 128 | self.cache_config, self.model_config, self.parallel_config 129 | ) 130 | self.cache_events = self.cache_engine.events 131 | self.gpu_cache = self.cache_engine.gpu_cache 132 | self.model_runner.set_block_size(self.cache_engine.block_size) 133 | 134 | def warm_up_model(self) -> None: 135 | if not self.model_config.enforce_eager: 136 | self.model_runner.capture_model(self.gpu_cache) 137 | # Reset the seed to ensure that the random state is not affected by 138 | # the model initialization and profiling. 139 | set_random_seed(self.model_config.seed) 140 | 141 | def cache_swap( 142 | self, 143 | blocks_to_swap_in: Dict[int, int], 144 | blocks_to_swap_out: Dict[int, int], 145 | blocks_to_copy: Dict[int, List[int]], 146 | ) -> None: 147 | # Issue cache operations. 148 | issued_cache_op = False 149 | if blocks_to_swap_in: 150 | self.cache_engine.swap_in(blocks_to_swap_in) 151 | issued_cache_op = True 152 | if blocks_to_swap_out: 153 | self.cache_engine.swap_out(blocks_to_swap_out) 154 | issued_cache_op = True 155 | if blocks_to_copy: 156 | self.cache_engine.copy(blocks_to_copy) 157 | issued_cache_op = True 158 | 159 | cache_events = self.cache_events if issued_cache_op else None 160 | 161 | # Wait for cache operations to finish. 162 | # TODO(woosuk): Profile swapping overhead and optimize if needed. 
163 | if cache_events is not None: 164 | for event in cache_events: 165 | event.wait() 166 | 167 | @torch.inference_mode() 168 | def execute_model( 169 | self, 170 | seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, 171 | blocks_to_swap_in: Optional[Dict[int, int]] = None, 172 | blocks_to_swap_out: Optional[Dict[int, int]] = None, 173 | blocks_to_copy: Optional[Dict[int, List[int]]] = None, 174 | ) -> Optional[SamplerOutput]: 175 | if self.is_driver_worker: 176 | assert seq_group_metadata_list is not None 177 | num_seq_groups = len(seq_group_metadata_list) 178 | assert blocks_to_swap_in is not None 179 | assert blocks_to_swap_out is not None 180 | assert blocks_to_copy is not None 181 | block_swapping_info = [ 182 | blocks_to_swap_in, 183 | blocks_to_swap_out, 184 | blocks_to_copy, 185 | ] 186 | broadcast_object_list([num_seq_groups] + block_swapping_info, src=0) 187 | else: 188 | # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, 189 | # blocks_to_copy (4 elements) 190 | recv_data = [None] * 4 191 | broadcast_object_list(recv_data, src=0) 192 | num_seq_groups = recv_data[0] 193 | block_swapping_info = recv_data[1:] 194 | 195 | self.cache_swap(*block_swapping_info) 196 | 197 | # If there is no input, we don't need to execute the model. 198 | if num_seq_groups == 0: 199 | return {} 200 | 201 | output = self.model_runner.execute_model( 202 | seq_group_metadata_list, self.gpu_cache 203 | ) 204 | return output 205 | 206 | 207 | def _init_distributed_environment( 208 | parallel_config: ParallelConfig, 209 | rank: int, 210 | distributed_init_method: Optional[str] = None, 211 | ) -> None: 212 | """Initialize the distributed environment.""" 213 | if torch.distributed.is_initialized(): 214 | torch_world_size = torch.distributed.get_world_size() 215 | if torch_world_size != parallel_config.world_size: 216 | raise RuntimeError( 217 | "torch.distributed is already initialized but the torch world " 218 | "size does not match parallel_config.world_size " 219 | f"({torch_world_size} vs. {parallel_config.world_size})." 220 | ) 221 | elif not distributed_init_method: 222 | raise ValueError( 223 | "distributed_init_method must be set if torch.distributed " 224 | "is not already initialized" 225 | ) 226 | else: 227 | torch.distributed.init_process_group( 228 | backend="nccl", 229 | world_size=parallel_config.world_size, 230 | rank=rank, 231 | init_method=distributed_init_method, 232 | ) 233 | 234 | # A small all_reduce for warmup. 235 | torch.distributed.all_reduce(torch.zeros(1).cuda()) 236 | initialize_model_parallel( 237 | parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size 238 | ) 239 | 240 | 241 | def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): 242 | # Check if the GPU supports the dtype. 243 | if torch_dtype == torch.bfloat16: 244 | compute_capability = torch.cuda.get_device_capability() 245 | if compute_capability[0] < 8: 246 | gpu_name = torch.cuda.get_device_name() 247 | raise ValueError( 248 | "Bfloat16 is only supported on GPUs with compute capability " 249 | f"of at least 8.0. Your {gpu_name} GPU has compute capability " 250 | f"{compute_capability[0]}.{compute_capability[1]}." 
251 | ) 252 | -------------------------------------------------------------------------------- /ChatTTS/norm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from typing import Dict, Tuple, List, Literal, Callable, Optional 5 | import sys 6 | 7 | from numba import jit 8 | import numpy as np 9 | 10 | from .utils import del_all 11 | 12 | 13 | @jit(nopython=True) 14 | def _find_index(table: np.ndarray, val: np.uint16): 15 | for i in range(table.size): 16 | if table[i] == val: 17 | return i 18 | return -1 19 | 20 | 21 | @jit(nopython=True) 22 | def _fast_replace( 23 | table: np.ndarray, text: bytes 24 | ) -> Tuple[np.ndarray, List[Tuple[str, str]]]: 25 | result = np.frombuffer(text, dtype=np.uint16).copy() 26 | replaced_words = [] 27 | for i in range(result.size): 28 | ch = result[i] 29 | p = _find_index(table[0], ch) 30 | if p >= 0: 31 | repl_char = table[1][p] 32 | result[i] = repl_char 33 | replaced_words.append((chr(ch), chr(repl_char))) 34 | return result, replaced_words 35 | 36 | 37 | @jit(nopython=True) 38 | def _split_tags(text: str) -> Tuple[List[str], List[str]]: 39 | texts: List[str] = [] 40 | tags: List[str] = [] 41 | current_text = "" 42 | current_tag = "" 43 | for c in text: 44 | if c == "[": 45 | texts.append(current_text) 46 | current_text = "" 47 | current_tag = c 48 | elif current_tag != "": 49 | current_tag += c 50 | else: 51 | current_text += c 52 | if c == "]": 53 | tags.append(current_tag) 54 | current_tag = "" 55 | if current_text != "": 56 | texts.append(current_text) 57 | return texts, tags 58 | 59 | 60 | @jit(nopython=True) 61 | def _combine_tags(texts: List[str], tags: List[str]) -> str: 62 | text = "" 63 | for t in texts: 64 | tg = "" 65 | if len(tags) > 0: 66 | tg = tags.pop(0) 67 | text += t + tg 68 | return text 69 | 70 | 71 | class Normalizer: 72 | def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)): 73 | self.logger = logger 74 | self.normalizers: Dict[str, Callable[[str], str]] = {} 75 | self.homophones_map = self._load_homophones_map(map_file_path) 76 | """ 77 | homophones_map 78 | 79 | Replace the mispronounced characters with correctly pronounced ones. 80 | 81 | Creation process of homophones_map.json: 82 | 83 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text. 84 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words. 85 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS. 86 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping. 87 | 88 | Thanks to: 89 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html) 90 | [python-pinyin](https://github.com/mozillazg/python-pinyin) 91 | 92 | """ 93 | self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be" 94 | self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. 
]") 95 | self.sub_pattern = re.compile(r"\[[\w_]+\]") 96 | self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") 97 | self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b") 98 | self.character_simplifier = str.maketrans( 99 | { 100 | ":": ",", 101 | ";": ",", 102 | "!": "。", 103 | "(": ",", 104 | ")": ",", 105 | "【": ",", 106 | "】": ",", 107 | "『": ",", 108 | "』": ",", 109 | "「": ",", 110 | "」": ",", 111 | "《": ",", 112 | "》": ",", 113 | "-": ",", 114 | ":": ",", 115 | ";": ",", 116 | "!": ".", 117 | "(": ",", 118 | ")": ",", 119 | # "[": ",", 120 | # "]": ",", 121 | ">": ",", 122 | "<": ",", 123 | "-": ",", 124 | } 125 | ) 126 | self.halfwidth_2_fullwidth = str.maketrans( 127 | { 128 | "!": "!", 129 | '"': "“", 130 | "'": "‘", 131 | "#": "#", 132 | "$": "$", 133 | "%": "%", 134 | "&": "&", 135 | "(": "(", 136 | ")": ")", 137 | ",": ",", 138 | "-": "-", 139 | "*": "*", 140 | "+": "+", 141 | ".": "。", 142 | "/": "/", 143 | ":": ":", 144 | ";": ";", 145 | "<": "<", 146 | "=": "=", 147 | ">": ">", 148 | "?": "?", 149 | "@": "@", 150 | # '[': '[', 151 | "\\": "\", 152 | # ']': ']', 153 | "^": "^", 154 | # '_': '_', 155 | "`": "`", 156 | "{": "{", 157 | "|": "|", 158 | "}": "}", 159 | "~": "~", 160 | } 161 | ) 162 | 163 | def __call__( 164 | self, 165 | text: str, 166 | do_text_normalization=True, 167 | do_homophone_replacement=True, 168 | lang: Optional[Literal["zh", "en"]] = None, 169 | ) -> str: 170 | if do_text_normalization: 171 | _lang = self._detect_language(text) if lang is None else lang 172 | if _lang in self.normalizers: 173 | texts, tags = _split_tags(text) 174 | self.logger.debug("split texts %s, tags %s", str(texts), str(tags)) 175 | texts = [self.normalizers[_lang](t) for t in texts] 176 | self.logger.debug("normed texts %s", str(texts)) 177 | text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0] 178 | self.logger.debug("combined text %s", text) 179 | if _lang == "zh": 180 | text = self._apply_half2full_map(text) 181 | invalid_characters = self._count_invalid_characters(text) 182 | if len(invalid_characters): 183 | self.logger.warning(f"found invalid characters: {invalid_characters}") 184 | text = self._apply_character_map(text) 185 | if do_homophone_replacement: 186 | arr, replaced_words = _fast_replace( 187 | self.homophones_map, 188 | text.encode(self.coding), 189 | ) 190 | if replaced_words: 191 | text = arr.tobytes().decode(self.coding) 192 | repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words]) 193 | self.logger.info(f"replace homophones: {repl_res}") 194 | if len(invalid_characters): 195 | texts, tags = _split_tags(text) 196 | self.logger.debug("split texts %s, tags %s", str(texts), str(tags)) 197 | texts = [self.reject_pattern.sub("", t) for t in texts] 198 | self.logger.debug("normed texts %s", str(texts)) 199 | text = _combine_tags(texts, tags) if len(tags) > 0 else texts[0] 200 | self.logger.debug("combined text %s", text) 201 | return text 202 | 203 | def register(self, name: str, normalizer: Callable[[str], str]) -> bool: 204 | if name in self.normalizers: 205 | self.logger.warning(f"name {name} has been registered") 206 | return False 207 | try: 208 | val = normalizer("test string 测试字符串") 209 | if not isinstance(val, str): 210 | self.logger.warning("normalizer must have caller type (str) -> str") 211 | return False 212 | except Exception as e: 213 | self.logger.warning(e) 214 | return False 215 | self.normalizers[name] = normalizer 216 | return True 217 | 218 | def unregister(self, name: str): 219 | if name in self.normalizers: 220 | 
del self.normalizers[name] 221 | 222 | def destroy(self): 223 | del_all(self.normalizers) 224 | del self.homophones_map 225 | 226 | def _load_homophones_map(self, map_file_path: str) -> np.ndarray: 227 | with open(map_file_path, "r", encoding="utf-8") as f: 228 | homophones_map: Dict[str, str] = json.load(f) 229 | map = np.empty((2, len(homophones_map)), dtype=np.uint32) 230 | for i, k in enumerate(homophones_map.keys()): 231 | map[:, i] = (ord(k), ord(homophones_map[k])) 232 | del homophones_map 233 | return map 234 | 235 | def _count_invalid_characters(self, s: str): 236 | s = self.sub_pattern.sub("", s) 237 | non_alphabetic_chinese_chars = self.reject_pattern.findall(s) 238 | return set(non_alphabetic_chinese_chars) 239 | 240 | def _apply_half2full_map(self, text: str) -> str: 241 | return text.translate(self.halfwidth_2_fullwidth) 242 | 243 | def _apply_character_map(self, text: str) -> str: 244 | return text.translate(self.character_simplifier) 245 | 246 | def _detect_language(self, sentence: str) -> Literal["zh", "en"]: 247 | chinese_chars = self.chinese_char_pattern.findall(sentence) 248 | english_words = self.english_word_pattern.findall(sentence) 249 | 250 | if len(chinese_chars) > len(english_words): 251 | return "zh" 252 | else: 253 | return "en" 254 | -------------------------------------------------------------------------------- /ChatTTS/res/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2noise/ChatTTS/1092c1ffcaa82f4bdef104c35aa5541227c3e1d7/ChatTTS/res/__init__.py -------------------------------------------------------------------------------- /ChatTTS/res/sha256_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "sha256_asset_Decoder_safetensors": "77aa55e0a977949c4733df3c6f876fa85860d3298cba63295a7bc6901729d4e0", 3 | "sha256_asset_DVAE_safetensors" : "1d0b044a8368c0513100a2eca98456b289e6be6a18b7a63be1bcaa315ea874d9", 4 | "sha256_asset_Embed_safetensors" : "2ff0be7134934155741b643b74e32fb6bf3eec41257984459b2ed60cdb4c48b0", 5 | "sha256_asset_Vocos_safetensors" : "07e5561491cce41f7f90cfdb94b2ff263ff5742c3d89339db99b17ad82cc3f44", 6 | 7 | "sha256_asset_gpt_config_json" : "0aaa1ecd96c49ad4f473459eb1982fa7ad79fa5de08cde2781bf6ad1f9a0c236", 8 | "sha256_asset_gpt_model_safetensors" : "cd0806fd971f52f6a22c923ec64982b305e817bcc41ca83417fcf9141b984a0f", 9 | 10 | "sha256_asset_tokenizer_special_tokens_map_json": "bd0ac9d9bb1657996b5c5fbcaa7d80f8de530d01a283da97f89deae5b1b8d011", 11 | "sha256_asset_tokenizer_tokenizer_config_json" : "43e9d658b554fa5ee8d8e1d763349323bfef1ed7a89c0794220ab8861387d421", 12 | "sha256_asset_tokenizer_tokenizer_json" : "843838a64e121e23e774cc75874c6fe862198d9f7dd43747914633a8fd89c20e" 13 | } 14 | -------------------------------------------------------------------------------- /ChatTTS/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dl import check_all_assets, download_all_assets 2 | from .gpu import select_device 3 | from .io import load_safetensors, get_latest_modified_file, del_all, FileLike 4 | from .log import logger 5 | -------------------------------------------------------------------------------- /ChatTTS/utils/dl.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import hashlib 4 | import requests 5 | from io import BytesIO 6 | from typing import Dict, Tuple, Optional 7 | from 
mmap import mmap, ACCESS_READ 8 | 9 | from .log import logger 10 | 11 | 12 | def sha256(fileno: int) -> str: 13 | data = mmap(fileno, 0, access=ACCESS_READ) 14 | h = hashlib.sha256(data).hexdigest() 15 | del data 16 | return h 17 | 18 | 19 | def check_model( 20 | dir_name: Path, model_name: str, hash: str, remove_incorrect=False 21 | ) -> bool: 22 | target = dir_name / model_name 23 | relname = target.as_posix() 24 | logger.get_logger().debug(f"checking {relname}...") 25 | if not os.path.exists(target): 26 | logger.get_logger().info(f"{target} not exist.") 27 | return False 28 | with open(target, "rb") as f: 29 | digest = sha256(f.fileno()) 30 | bakfile = f"{target}.bak" 31 | if digest != hash: 32 | logger.get_logger().warning(f"{target} sha256 hash mismatch.") 33 | logger.get_logger().info(f"expected: {hash}") 34 | logger.get_logger().info(f"real val: {digest}") 35 | if remove_incorrect: 36 | if not os.path.exists(bakfile): 37 | os.rename(str(target), bakfile) 38 | else: 39 | os.remove(str(target)) 40 | return False 41 | if remove_incorrect and os.path.exists(bakfile): 42 | os.remove(bakfile) 43 | return True 44 | 45 | 46 | def check_folder( 47 | base_dir: Path, 48 | *innder_dirs: str, 49 | names: Tuple[str], 50 | sha256_map: Dict[str, str], 51 | update=False, 52 | ) -> bool: 53 | key = "sha256_" 54 | current_dir = base_dir 55 | for d in innder_dirs: 56 | current_dir /= d 57 | key += f"{d}_" 58 | 59 | for model in names: 60 | menv = model.replace(".", "_") 61 | if not check_model(current_dir, model, sha256_map[f"{key}{menv}"], update): 62 | return False 63 | return True 64 | 65 | 66 | def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool: 67 | logger.get_logger().info("checking assets...") 68 | 69 | if not check_folder( 70 | base_dir, 71 | "asset", 72 | names=( 73 | "Decoder.safetensors", 74 | "DVAE.safetensors", 75 | "Embed.safetensors", 76 | "Vocos.safetensors", 77 | ), 78 | sha256_map=sha256_map, 79 | update=update, 80 | ): 81 | return False 82 | 83 | if not check_folder( 84 | base_dir, 85 | "asset", 86 | "gpt", 87 | names=( 88 | "config.json", 89 | "model.safetensors", 90 | ), 91 | sha256_map=sha256_map, 92 | update=update, 93 | ): 94 | return False 95 | 96 | if not check_folder( 97 | base_dir, 98 | "asset", 99 | "tokenizer", 100 | names=( 101 | "special_tokens_map.json", 102 | "tokenizer_config.json", 103 | "tokenizer.json", 104 | ), 105 | sha256_map=sha256_map, 106 | update=update, 107 | ): 108 | return False 109 | 110 | logger.get_logger().info("all assets are already latest.") 111 | return True 112 | 113 | 114 | def download_and_extract_tar_gz( 115 | url: str, folder: str, headers: Optional[Dict[str, str]] = None 116 | ): 117 | import tarfile 118 | 119 | logger.get_logger().info(f"downloading {url}") 120 | response = requests.get(url, headers=headers, stream=True, timeout=(10, 3)) 121 | with BytesIO() as out_file: 122 | out_file.write(response.content) 123 | out_file.seek(0) 124 | logger.get_logger().info(f"downloaded.") 125 | with tarfile.open(fileobj=out_file, mode="r:gz") as tar: 126 | tar.extractall(folder) 127 | logger.get_logger().info(f"extracted into {folder}") 128 | 129 | 130 | def download_and_extract_zip( 131 | url: str, folder: str, headers: Optional[Dict[str, str]] = None 132 | ): 133 | import zipfile 134 | 135 | logger.get_logger().info(f"downloading {url}") 136 | response = requests.get(url, headers=headers, stream=True, timeout=(10, 3)) 137 | with BytesIO() as out_file: 138 | out_file.write(response.content) 139 | out_file.seek(0) 
140 | logger.get_logger().info(f"downloaded.") 141 | with zipfile.ZipFile(out_file) as zip_ref: 142 | zip_ref.extractall(folder) 143 | logger.get_logger().info(f"extracted into {folder}") 144 | 145 | 146 | def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]): 147 | logger.get_logger().info(f"downloading {url}") 148 | response = requests.get(url, headers=headers, stream=True, timeout=(100, 3)) 149 | with open(os.path.join(folder, "dns.yaml"), "wb") as out_file: 150 | out_file.write(response.content) 151 | logger.get_logger().info(f"downloaded into {folder}") 152 | 153 | 154 | def download_all_assets(tmpdir: str, homedir: str, version="0.2.10"): 155 | import subprocess 156 | import platform 157 | 158 | archs = { 159 | "aarch64": "arm64", 160 | "armv8l": "arm64", 161 | "arm64": "arm64", 162 | "x86": "386", 163 | "i386": "386", 164 | "i686": "386", 165 | "386": "386", 166 | "x86_64": "amd64", 167 | "x64": "amd64", 168 | "amd64": "amd64", 169 | } 170 | system_type = platform.system().lower() 171 | architecture = platform.machine().lower() 172 | is_win = system_type == "windows" 173 | 174 | architecture = archs.get(architecture, None) 175 | if not architecture: 176 | logger.get_logger().error(f"architecture {architecture} is not supported") 177 | exit(1) 178 | try: 179 | BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/" 180 | suffix = "zip" if is_win else "tar.gz" 181 | RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" 182 | cmdfile = os.path.join(tmpdir, "rvcmd") 183 | if is_win: 184 | download_and_extract_zip(RVCMD_URL, tmpdir) 185 | cmdfile += ".exe" 186 | else: 187 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 188 | os.chmod(cmdfile, 0o755) 189 | subprocess.run([cmdfile, "-notui", "-w", "0", "-H", homedir, "assets/chtts"]) 190 | except Exception: 191 | BASE_URL = ( 192 | "https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/" 193 | ) 194 | suffix = "zip" if is_win else "tar.gz" 195 | RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}" 196 | download_dns_yaml( 197 | "https://gitea.seku.su/fumiama/RVC-Models-Downloader/raw/branch/main/dns.yaml", 198 | tmpdir, 199 | headers={ 200 | "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0" 201 | }, 202 | ) 203 | cmdfile = os.path.join(tmpdir, "rvcmd") 204 | if is_win: 205 | download_and_extract_zip(RVCMD_URL, tmpdir) 206 | cmdfile += ".exe" 207 | else: 208 | download_and_extract_tar_gz(RVCMD_URL, tmpdir) 209 | os.chmod(cmdfile, 0o755) 210 | subprocess.run( 211 | [ 212 | cmdfile, 213 | "-notui", 214 | "-w", 215 | "0", 216 | "-dns", 217 | os.path.join(tmpdir, "dns.yaml"), 218 | "-H", 219 | homedir, 220 | "assets/chtts", 221 | ] 222 | ) 223 | -------------------------------------------------------------------------------- /ChatTTS/utils/gpu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | try: 4 | import torch_npu 5 | except ImportError: 6 | pass 7 | 8 | from .log import logger 9 | 10 | 11 | def select_device(min_memory=2047, experimental=False): 12 | has_cuda = torch.cuda.is_available() 13 | if has_cuda or _is_torch_npu_available(): 14 | provider = torch.cuda if has_cuda else torch.npu 15 | """ 16 | Using Ascend NPU to accelerate the process of inferencing when GPU is not found. 
17 | """ 18 | dev_idx = 0 19 | max_free_memory = -1 20 | for i in range(provider.device_count()): 21 | props = provider.get_device_properties(i) 22 | free_memory = props.total_memory - provider.memory_reserved(i) 23 | if max_free_memory < free_memory: 24 | dev_idx = i 25 | max_free_memory = free_memory 26 | free_memory_mb = max_free_memory / (1024 * 1024) 27 | if free_memory_mb < min_memory: 28 | logger.get_logger().warning( 29 | f"{provider.device(dev_idx)} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU." 30 | ) 31 | device = torch.device("cpu") 32 | else: 33 | device = provider._get_device(dev_idx) 34 | elif torch.backends.mps.is_available(): 35 | """ 36 | Currently MPS is slower than CPU while needs more memory and core utility, 37 | so only enable this for experimental use. 38 | """ 39 | if experimental: 40 | # For Apple M1/M2 chips with Metal Performance Shaders 41 | logger.get_logger().warning("experimantal: found apple GPU, using MPS.") 42 | device = torch.device("mps") 43 | else: 44 | logger.get_logger().info("found Apple GPU, but use CPU.") 45 | device = torch.device("cpu") 46 | else: 47 | logger.get_logger().warning("no GPU or NPU found, use CPU instead") 48 | device = torch.device("cpu") 49 | 50 | return device 51 | 52 | 53 | def _is_torch_npu_available(): 54 | try: 55 | # will raise a AttributeError if torch_npu is not imported or a RuntimeError if no NPU found 56 | _ = torch.npu.device_count() 57 | return torch.npu.is_available() 58 | except (AttributeError, RuntimeError): 59 | return False 60 | -------------------------------------------------------------------------------- /ChatTTS/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Union, IO 4 | from dataclasses import is_dataclass 5 | 6 | from safetensors import safe_open 7 | import torch 8 | 9 | from .log import logger 10 | 11 | if hasattr(torch.serialization, "FILE_LIKE"): 12 | FileLike = torch.serialization.FILE_LIKE 13 | elif hasattr(torch.types, "FILE_LIKE"): 14 | FileLike = torch.types.FileLike 15 | else: 16 | FileLike = Union[str, os.PathLike, IO[bytes]] 17 | 18 | 19 | @torch.inference_mode() 20 | def load_safetensors(filename: str): 21 | state_dict_tensors = {} 22 | with safe_open(filename, framework="pt") as f: 23 | for k in f.keys(): 24 | state_dict_tensors[k] = f.get_tensor(k) 25 | return state_dict_tensors 26 | 27 | 28 | def get_latest_modified_file(directory): 29 | 30 | files = [os.path.join(directory, f) for f in os.listdir(directory)] 31 | if not files: 32 | logger.get_logger().log( 33 | logging.WARNING, f"no files found in the directory: {directory}" 34 | ) 35 | return None 36 | latest_file = max(files, key=os.path.getmtime) 37 | 38 | return latest_file 39 | 40 | 41 | def del_all(d: Union[dict, list]): 42 | if is_dataclass(d): 43 | for k in list(vars(d).keys()): 44 | x = getattr(d, k) 45 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 46 | del_all(x) 47 | del x 48 | delattr(d, k) 49 | elif isinstance(d, dict): 50 | lst = list(d.keys()) 51 | for k in lst: 52 | x = d.pop(k) 53 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 54 | del_all(x) 55 | del x 56 | elif isinstance(d, list): 57 | while len(d): 58 | x = d.pop() 59 | if isinstance(x, dict) or isinstance(x, list) or is_dataclass(x): 60 | del_all(x) 61 | del x 62 | else: 63 | del d 64 | -------------------------------------------------------------------------------- /ChatTTS/utils/log.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | 5 | class Logger: 6 | def __init__(self, logger=logging.getLogger(Path(__file__).parent.name)): 7 | self.logger = logger 8 | 9 | def set_logger(self, logger: logging.Logger): 10 | self.logger = logger 11 | 12 | def get_logger(self) -> logging.Logger: 13 | return self.logger 14 | 15 | 16 | logger = Logger() 17 | -------------------------------------------------------------------------------- /docs/cn/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 2noise%2FChatTTS | Trendshift 4 | 5 | # ChatTTS 6 | 一款适用于日常对话的生成式语音模型。 7 | 8 | [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) 9 | [![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) 10 | 11 | [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) 12 | [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) 13 | [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD) 14 | 15 | [**English**](../../README.md) | **简体中文** | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md) | [**한국어**](../kr/README.md) 16 | 17 |
18 | 19 | > [!NOTE] 20 | > 注意此版本可能不是最新版,所有内容请以英文版为准。 21 | 22 | ## 简介 23 | 24 | > [!Note] 25 | > 这个仓库包含算法架构和一些简单的示例。 26 | 27 | > [!Tip] 28 | > 由本仓库衍生出的用户端产品,请参见由社区维护的索引仓库 [Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)。 29 | 30 | ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转语音模型。 31 | 32 | ### 支持的语种 33 | 34 | - [x] 英语 35 | - [x] 中文 36 | - [ ] 敬请期待... 37 | 38 | ### 亮点 39 | 40 | > 你可以参考 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 上的这个视频,了解本项目的详细情况。 41 | 42 | 1. **对话式 TTS**: ChatTTS 针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音。它支持多个说话者,便于生成互动式对话。 43 | 2. **精细的控制**: 该模型可以预测和控制精细的韵律特征,包括笑声、停顿和插入语。 44 | 3. **更好的韵律**: ChatTTS 在韵律方面超越了大多数开源 TTS 模型。我们提供预训练模型以支持进一步的研究和开发。 45 | 46 | ### 数据集和模型 47 | 48 | - 主模型使用了 100,000+ 小时的中文和英文音频数据进行训练。 49 | - **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** 上的开源版本是一个在 40,000 小时数据上进行无监督微调的预训练模型。 50 | 51 | ### 路线图 52 | 53 | - [x] 开源 4 万小时基础模型和 spk_stats 文件。 54 | - [x] 支持流式语音输出。 55 | - [x] 开源 DVAE 编码器和零样本推理代码 56 | - [ ] 开源具有多情感控制功能的 4 万小时版本。 57 | - [ ] ChatTTS.cpp (欢迎在 2noise 组织中新建仓库)。 58 | 59 | ### 免责声明 60 | 61 | > [!Important] 62 | > 此仓库仅供学术用途。 63 | 64 | 本项目旨在用于教育和研究目的,不适用于任何商业或法律目的。作者不保证信息的准确性、完整性和可靠性。此仓库中使用的信息和数据仅供学术和研究目的。数据来自公开来源,作者不声称对数据拥有任何所有权或版权。 65 | 66 | ChatTTS 是一款强大的文本转语音系统。但是,负责任和道德地使用这项技术非常重要。为了限制 ChatTTS 的使用,我们在 40,000 小时模型的训练过程中添加了少量高频噪声,并使用 MP3 格式尽可能压缩音频质量,以防止恶意行为者将其用于犯罪目的。同时,我们内部训练了一个检测模型,并计划在未来开源它。 67 | 68 | ### 联系方式 69 | 70 | > 欢迎随时提交 GitHub issues/PRs。 71 | 72 | #### 合作洽谈 73 | 74 | 如需就模型和路线图进行合作洽谈,请发送邮件至 **open-source@2noise.com**。 75 | 76 | #### 线上讨论 77 | 78 | ##### 1. 官方 QQ 群 79 | 80 | - **群 1**, 808364215 (已满) 81 | - **群 2**, 230696694 (已满) 82 | - **群 3**, 933639842 (已满) 83 | - **群 4**, 608667975 84 | 85 | ##### 2. Discord 86 | 87 | 点击加入 [Discord](https://discord.gg/Ud5Jxgx5yD)。 88 | 89 | ## 体验教程 90 | 91 | ### 克隆仓库 92 | 93 | ```bash 94 | git clone https://github.com/2noise/ChatTTS 95 | cd ChatTTS 96 | ``` 97 | 98 | ### 安装依赖 99 | 100 | #### 1. 直接安装 101 | 102 | ```bash 103 | pip install --upgrade -r requirements.txt 104 | ``` 105 | 106 | #### 2. 使用 conda 安装 107 | 108 | ```bash 109 | conda create -n chattts 110 | conda activate chattts 111 | pip install -r requirements.txt 112 | ``` 113 | 114 | #### 可选 : 如果使用 NVIDIA GPU(仅限 Linux),可安装 TransformerEngine。 115 | 116 | > [!Note] 117 | > 安装过程可能耗时很长。 118 | 119 | > [!Warning] 120 | > TransformerEngine 的适配目前正在开发中,运行时可能会遇到较多问题。仅推荐出于开发目的安装。 121 | 122 | ```bash 123 | pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable 124 | ``` 125 | 126 | #### 可选 : 安装 FlashAttention-2 (主要适用于 NVIDIA GPU) 127 | 128 | > [!Note] 129 | > 支持设备列表详见 [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2). 130 | 131 | ```bash 132 | pip install flash-attn --no-build-isolation 133 | ``` 134 | 135 | ### 快速启动 136 | 137 | > 确保在执行以下命令时,处于项目根目录下。 138 | 139 | #### 1. WebUI 可视化界面 140 | 141 | ```bash 142 | python examples/web/webui.py 143 | ``` 144 | 145 | #### 2. 命令行交互 146 | 147 | > 生成的音频将保存至 `./output_audio_n.mp3` 148 | 149 | ```bash 150 | python examples/cmd/run.py "Your text 1." "Your text 2." 151 | ``` 152 | 153 | ## 开发教程 154 | 155 | ### 安装 Python 包 156 | 157 | 1. 从 PyPI 安装稳定版 158 | 159 | ```bash 160 | pip install ChatTTS 161 | ``` 162 | 163 | 2. 从 GitHub 安装最新版 164 | 165 | ```bash 166 | pip install git+https://github.com/2noise/ChatTTS 167 | ``` 168 | 169 | 3. 从本地文件夹安装开发版 170 | 171 | ```bash 172 | pip install -e . 
173 | ``` 174 | 175 | ### 基础用法 176 | 177 | ```python 178 | import ChatTTS 179 | import torch 180 | import torchaudio 181 | 182 | chat = ChatTTS.Chat() 183 | chat.load(compile=False) # Set to True for better performance 184 | 185 | texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] 186 | 187 | wavs = chat.infer(texts) 188 | 189 | torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) 190 | ``` 191 | 192 | ### 进阶用法 193 | 194 | ```python 195 | ################################### 196 | # Sample a speaker from Gaussian. 197 | 198 | rand_spk = chat.sample_random_speaker() 199 | print(rand_spk) # save it for later timbre recovery 200 | 201 | params_infer_code = ChatTTS.Chat.InferCodeParams( 202 | spk_emb = rand_spk, # add sampled speaker 203 | temperature = .3, # using custom temperature 204 | top_P = 0.7, # top P decode 205 | top_K = 20, # top K decode 206 | ) 207 | 208 | ################################### 209 | # For sentence level manual control. 210 | 211 | # use oral_(0-9), laugh_(0-2), break_(0-7) 212 | # to generate special token in text to synthesize. 213 | params_refine_text = ChatTTS.Chat.RefineTextParams( 214 | prompt='[oral_2][laugh_0][break_6]', 215 | ) 216 | 217 | wavs = chat.infer( 218 | texts, 219 | params_refine_text=params_refine_text, 220 | params_infer_code=params_infer_code, 221 | ) 222 | 223 | ################################### 224 | # For word level manual control. 225 | 226 | text = 'What is [uv_break]your favorite english food?[laugh][lbreak]' 227 | wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 228 | torchaudio.save("output2.wav", torch.from_numpy(wavs[0]), 24000) 229 | ``` 230 | 231 |
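The advanced block above prints `rand_spk` with the note "save it for later timbre recovery". As a small follow-up sketch (not part of the original README; the file name `speaker.txt` and the output name are arbitrary choices for illustration), the speaker string can simply be written to disk and passed back through `ChatTTS.Chat.InferCodeParams(spk_emb=...)` in a later session. It reuses `chat`, `rand_spk`, `torch` and `torchaudio` from the blocks above.

```python
# Sketch only: persist the sampled speaker string so the same timbre can be
# reused in a later session. "speaker.txt" is an arbitrary file name.
with open("speaker.txt", "w") as f:
    f.write(rand_spk)

# Later, after chat.load() in a new session:
with open("speaker.txt") as f:
    saved_spk = f.read().strip()

params_infer_code = ChatTTS.Chat.InferCodeParams(
    spk_emb=saved_spk,  # recover the previously sampled timbre
    temperature=0.3,
)
wavs = chat.infer(["这段话会复用之前保存的音色。"], params_infer_code=params_infer_code)
torchaudio.save("output_saved_spk.wav", torch.from_numpy(wavs[0]), 24000)
```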
232 |

**示例: 自我介绍**

233 | 234 | ```python 235 | inputs_en = """ 236 | chatTTS is a text to speech model designed for dialogue applications. 237 | [uv_break]it supports mixed language input [uv_break]and offers multi speaker 238 | capabilities with precise control over prosodic elements like 239 | [uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation. 240 | [uv_break]it delivers natural and expressive speech,[uv_break]so please 241 | [uv_break] use the project responsibly at your own risk.[uv_break] 242 | """.replace('\n', '') # English is still experimental. 243 | 244 | params_refine_text = ChatTTS.Chat.RefineTextParams( 245 | prompt='[oral_2][laugh_0][break_4]', 246 | ) 247 | 248 | audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text) 249 | torchaudio.save("output3.wav", torch.from_numpy(audio_array_en[0]), 24000) 250 | ``` 251 | 252 | 253 | 254 | 259 | 264 | 265 | 266 | 271 | 276 | 277 |
255 | 256 | **男性音色** 257 | 258 | 260 | 261 | **女性音色** 262 | 263 |
267 | 268 | [男性音色](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1) 269 | 270 | 272 | 273 | [女性音色](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd) 274 | 275 |
278 | 279 |
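The roadmap earlier in this README marks streaming audio output as implemented, but this page does not show how to call it. The sketch below assumes an `infer(..., stream=True)` generator interface; the flag and the chunk format are assumptions, so treat `examples/cmd/stream.py` in this repository as the authoritative reference. It reuses `chat`, `torch` and `torchaudio` from the usage blocks above.

```python
import numpy as np

# Assumption: with stream=True, chat.infer yields partial results as they are
# decoded, indexed per input text. The chunk format may differ between versions,
# so check examples/cmd/stream.py before relying on this sketch.
chunks = []
for partial in chat.infer(["这是一个流式合成的示例文本。"], stream=True):
    wav_chunk = partial[0]
    if wav_chunk is not None and len(wav_chunk) > 0:
        chunks.append(wav_chunk)

wav = np.concatenate(chunks, axis=-1)
tensor = torch.from_numpy(wav)
if tensor.dim() == 1:
    tensor = tensor.unsqueeze(0)  # torchaudio expects (channels, samples)
torchaudio.save("output_stream.wav", tensor, 24000)
```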
280 | 281 | ## 常见问题 282 | 283 | #### 1. 我需要多少 VRAM? 推理速度如何? 284 | 285 | 对于 30 秒的音频片段,至少需要 4GB 的 GPU 内存。 对于 4090 GPU,它可以每秒生成大约 7 个语义 token 对应的音频。实时因子 (RTF) 约为 0.3。 286 | 287 | #### 2. 模型稳定性不够好,存在多个说话者或音频质量差等问题。 288 | 289 | 这是一个通常发生在自回归模型(例如 bark 和 valle)中的问题,通常很难避免。可以尝试多个样本以找到合适的结果。 290 | 291 | #### 3. 除了笑声,我们还能控制其他东西吗?我们能控制其他情绪吗? 292 | 293 | 在当前发布的模型中,可用的 token 级控制单元是 `[laugh]`, `[uv_break]` 和 `[lbreak]`。未来的版本中,我们可能会开源具有更多情绪控制功能的模型。 294 | 295 | ## 致谢 296 | 297 | - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) 和 [valle](https://arxiv.org/abs/2301.02111) 通过自回归式系统展示了非凡的 TTS 效果。 298 | - [fish-speech](https://github.com/fishaudio/fish-speech) 揭示了 GVQ 作为 LLM 建模的音频分词器的能力。 299 | - [vocos](https://github.com/gemelo-ai/vocos) vocos 被用作预训练声码器。 300 | 301 | ## 特别鸣谢 302 | 303 | - [wlu-audio lab](https://audio.westlake.edu.cn/) 对于早期算法实验的支持。 304 | 305 | ## 贡献者列表 306 | 307 | [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors) 308 | 309 | ## 项目浏览量 310 | 311 |
312 | 313 | ![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs) 314 | 315 |
316 | -------------------------------------------------------------------------------- /docs/jp/README.md: -------------------------------------------------------------------------------- 1 | # ChatTTS 2 | > [!NOTE] 3 | > 以下の内容は最新情報ではない可能性がありますのでご了承ください。全ての内容は英語版に基準することになります。 4 | 5 | [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) 6 | 7 | [**English**](../../README.md) | [**简体中文**](../cn/README.md) | **日本語** | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md) | [**한국어**](../kr/README.md) 8 | 9 | ChatTTSは、LLMアシスタントなどの対話シナリオ用に特別に設計されたテキストから音声へのモデルです。英語と中国語の両方をサポートしています。私たちのモデルは、中国語と英語で構成される100,000時間以上でトレーニングされています。**[HuggingFace](https://huggingface.co/2Noise/ChatTTS)**でオープンソース化されているバージョンは、40,000時間の事前トレーニングモデルで、SFTは行われていません。 10 | 11 | モデルやロードマップについての正式なお問い合わせは、**open-source@2noise.com**までご連絡ください。QQグループ:808364215に参加してディスカッションすることもできます。GitHubでの問題提起も歓迎します。 12 | 13 | --- 14 | ## ハイライト 15 | 1. **会話型TTS**: ChatTTSは対話ベースのタスクに最適化されており、自然で表現豊かな音声合成を実現します。複数の話者をサポートし、対話型の会話を容易にします。 16 | 2. **細かい制御**: このモデルは、笑い、一時停止、間投詞などの細かい韻律特徴を予測および制御することができます。 17 | 3. **より良い韻律**: ChatTTSは、韻律の面でほとんどのオープンソースTTSモデルを超えています。さらなる研究と開発をサポートするために、事前トレーニングされたモデルを提供しています。 18 | 19 | モデルの詳細な説明については、**[Bilibiliのビデオ](https://www.bilibili.com/video/BV1zn4y1o7iV)**を参照してください。 20 | 21 | --- 22 | 23 | ## 免責事項 24 | 25 | このリポジトリは学術目的のみのためです。教育および研究用途にのみ使用され、商業的または法的な目的には使用されません。著者は情報の正確性、完全性、または信頼性を保証しません。このリポジトリで使用される情報およびデータは、学術および研究目的のみのためのものです。データは公開されているソースから取得され、著者はデータに対する所有権または著作権を主張しません。 26 | 27 | ChatTTSは強力なテキストから音声へのシステムです。しかし、この技術を責任を持って、倫理的に利用することが非常に重要です。ChatTTSの使用を制限するために、40,000時間のモデルのトレーニング中に少量の高周波ノイズを追加し、MP3形式を使用して音質を可能な限り圧縮しました。これは、悪意のあるアクターが潜在的に犯罪目的で使用することを防ぐためです。同時に、私たちは内部的に検出モデルをトレーニングしており、将来的にオープンソース化する予定です。 28 | 29 | --- 30 | ## 使用方法 31 | 32 |

**基本的な使用方法**

33 | 34 | ```python 35 | import ChatTTS 36 | from IPython.display import Audio 37 | import torch 38 | import torchaudio 39 | chat = ChatTTS.Chat() 40 | chat.load(compile=False) # より良いパフォーマンスのためにTrueに設定 41 | 42 | texts = ["ここにテキストを入力してください",] 43 | 44 | wavs = chat.infer(texts) 45 | 46 | torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) 47 | ``` 48 | 49 | 
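The basic-usage block above imports `Audio` from `IPython.display` but never calls it. As a small optional sketch (not in the original text), the generated waveform can be played inline in a notebook; 24000 Hz matches the rate passed to `torchaudio.save` above.

```python
import numpy as np

# Inline playback in a Jupyter/Colab notebook. wavs[0] is the waveform produced
# by chat.infer above, sampled at 24 kHz (the same rate used by torchaudio.save).
Audio(np.asarray(wavs[0]).squeeze(), rate=24000)
```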

**高度な使用方法**

50 | 51 | ```python 52 | ################################### 53 | # ガウス分布から話者をサンプリングします。 54 | 55 | rand_spk = chat.sample_random_speaker() 56 | print(rand_spk) # save it for later timbre recovery 57 | 58 | params_infer_code = { 59 | 'spk_emb': rand_spk, # サンプリングされた話者を追加 60 | 'temperature': .3, # カスタム温度を使用 61 | 'top_P': 0.7, # トップPデコード 62 | 'top_K': 20, # トップKデコード 63 | } 64 | 65 | ################################### 66 | # 文レベルの手動制御のために。 67 | 68 | # 特別なトークンを生成するためにテキストにoral_(0-9)、laugh_(0-2)、break_(0-7)を使用します。 69 | params_refine_text = { 70 | 'prompt': '[oral_2][laugh_0][break_6]' 71 | } 72 | 73 | wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 74 | 75 | ################################### 76 | # 単語レベルの手動制御のために。 77 | text = 'あなたの好きな英語の食べ物は何ですか?[uv_break][laugh][lbreak]' 78 | wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 79 | torchaudio.save("output2.wav", torch.from_numpy(wav[0]), 24000) 80 | ``` 81 | 82 | 
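This page still shows the older dict-style parameters. The current English and Chinese READMEs pass the same settings as dataclasses; a minimal equivalent of the block above would be the following sketch (it reuses `chat`, `rand_spk` and `texts` defined above, and the English README should be treated as authoritative if the two disagree).

```python
# Sketch of the dataclass-style parameters used by the current English README;
# it reuses chat, rand_spk and texts from the block above.
params_infer_code = ChatTTS.Chat.InferCodeParams(
    spk_emb=rand_spk,   # the sampled speaker
    temperature=0.3,
    top_P=0.7,
    top_K=20,
)
params_refine_text = ChatTTS.Chat.RefineTextParams(
    prompt='[oral_2][laugh_0][break_6]',
)
wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code)
```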
83 |

例:自己紹介

84 | 85 | ```python 86 | inputs_jp = """ 87 | ChatTTSは、対話アプリケーション用に設計されたテキストから音声へのモデルです。 88 | [uv_break]混合言語入力をサポートし[uv_break]、韻律要素[laugh]の正確な制御を提供します 89 | [uv_break]笑い[laugh]、[uv_break]一時停止、[uv_break]およびイントネーション。[uv_break]自然で表現豊かな音声を提供します 90 | [uv_break]したがって、自己責任でプロジェクトを責任を持って使用してください。[uv_break] 91 | """.replace('\n', '') # 英語はまだ実験的です。 92 | 93 | params_refine_text = { 94 | 'prompt': '[oral_2][laugh_0][break_4]' 95 | } 96 | audio_array_jp = chat.infer(inputs_jp, params_refine_text=params_refine_text) 97 | torchaudio.save("output3.wav", torch.from_numpy(audio_array_jp[0]), 24000) 98 | ``` 99 | [男性話者](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1) 100 | 101 | [女性話者](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd) 102 |
103 | 104 | --- 105 | ## ロードマップ 106 | - [x] 40k時間のベースモデルとspk_statsファイルをオープンソース化 107 | - [ ] VQエンコーダーとLoraトレーニングコードをオープンソース化 108 | - [ ] テキストをリファインせずにストリーミングオーディオ生成* 109 | - [ ] 複数の感情制御を備えた40k時間バージョンをオープンソース化 110 | - [ ] ChatTTS.cppもしかしたら?(PRや新しいリポジトリが歓迎されます。) 111 | 112 | ---- 113 | ## FAQ 114 | 115 | ##### VRAMはどれくらい必要ですか?推論速度はどうですか? 116 | 30秒のオーディオクリップには、少なくとも4GBのGPUメモリが必要です。4090 GPUの場合、約7つの意味トークンに対応するオーディオを1秒あたり生成できます。リアルタイムファクター(RTF)は約0.3です。 117 | 118 | ##### モデルの安定性が十分でなく、複数の話者や音質が悪いという問題があります。 119 | 120 | これは、自己回帰モデル(barkおよびvalleの場合)で一般的に発生する問題です。一般的に避けるのは難しいです。複数のサンプルを試して、適切な結果を見つけることができます。 121 | 122 | ##### 笑い以外に何か制御できますか?他の感情を制御できますか? 123 | 124 | 現在リリースされているモデルでは、トークンレベルの制御ユニットは[laugh]、[uv_break]、および[lbreak]のみです。将来のバージョンでは、追加の感情制御機能を備えたモデルをオープンソース化する可能性があります。 125 | 126 | --- 127 | ## 謝辞 128 | - [bark](https://github.com/suno-ai/bark)、[XTTSv2](https://github.com/coqui-ai/TTS)、および[valle](https://arxiv.org/abs/2301.02111)は、自己回帰型システムによる顕著なTTS結果を示しました。 129 | - [fish-speech](https://github.com/fishaudio/fish-speech)は、LLMモデリングのためのオーディオトークナイザーとしてのGVQの能力を明らかにしました。 130 | - 事前トレーニングされたボコーダーとして使用される[vocos](https://github.com/gemelo-ai/vocos)。 131 | 132 | --- 133 | ## 特別感謝 134 | - 初期のアルゴリズム実験をサポートしてくれた[wlu-audio lab](https://audio.westlake.edu.cn/)。 135 | -------------------------------------------------------------------------------- /docs/kr/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 2noise%2FChatTTS | Trendshift 4 | 5 | # ChatTTS 6 | 일상 대화를 위한 생성형 음성 모델입니다. 7 | 8 | [![Licence](https://img.shields.io/github/license/2noise/ChatTTS?style=for-the-badge)](https://github.com/2noise/ChatTTS/blob/main/LICENSE) 9 | [![PyPI](https://img.shields.io/pypi/v/ChatTTS.svg?style=for-the-badge&color=green)](https://pypi.org/project/ChatTTS) 10 | 11 | [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) 12 | [![Open In Colab](https://img.shields.io/badge/Colab-F9AB00?style=for-the-badge&logo=googlecolab&color=525252)](https://colab.research.google.com/github/2noise/ChatTTS/blob/main/examples/ipynb/colab.ipynb) 13 | [![Discord](https://img.shields.io/badge/Discord-7289DA?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/Ud5Jxgx5yD) 14 | 15 | [**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | [**Русский**](../ru/README.md) | [**Español**](../es/README.md) | [**Français**](../fr/README.md) | **한국어** 16 | 17 |
18 | 19 | > [!NOTE] 20 | > 이 문서는 최신 버전이 아닐 수 있습니다. [영어 문서](../../README.md)를 기준으로 작업하는 것을 권장합니다. 21 | 22 | ## 프로젝트 소개 23 | 24 | > [!Note] 25 | > 이 저장소에는 알고리즘 구조와 간단한 예시들이 포함되어 있습니다. 26 | 27 | > [!Tip] 28 | > 이 프로젝트에서 파생된 프로젝트는 커뮤니티가 유지 관리하는 커뮤니티[Awesome-ChatTTS](https://github.com/libukai/Awesome-ChatTTS)를 참조하시길 바랍니다. 29 | 30 | ChatTTS는 대화 기반 작업(예: LLM 어시스턴트)을 위해 설계된 텍스트-음성 변환(TTS) 모델입니다. 31 | 32 | ### 지원 언어 33 | 34 | - [x] 영어 35 | - [x] 중국어 36 | - [ ] 계속 추가 예정... 37 | 38 | ### 프로젝트 특징 39 | 40 | > 이 프로젝트의 내용은 **[Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)**에서 제공되는 비디오를 참조하시길 바랍니다. 41 | 42 | 1. **대화형 TTS**: ChatTTS는 대화 기반 작업에 최적화되어 자연스럽고 표현력 있는 음성 합성을 구현합니다. 다중 화자를 지원하여 상호작용적인 대화를 가능하게 합니다. 43 | 2. **세밀한 제어**: 이 모델은 웃음, 일시 정지, 삽입어 등 세밀한 운율적 특징을 예측하고 제어할 수 있습니다. 44 | 3. **향상된 운율**: ChatTTS는 운율 측면에서 대부분의 오픈 소스 TTS 모델을 능가하며, 추가 연구와 개발을 지원하기 위해 사전 훈련된 모델을 제공합니다. 45 | 46 | ### 데이터셋 및 모델 47 | > [!Important] 48 | > 공개된 모델은 학술 목적으로만 사용 가능합니다. 49 | 50 | - 주요 모델은 100,000+ 시간의 중국어 및 영어 오디오 데이터를 사용하여 훈련되었습니다. 51 | - **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)**에서 제공되는 오픈 소스 버전은 40,000시간의 사전 훈련된 모델로, SFT가 적용되지 않았습니다. 52 | 53 | ### 로드맵 54 | - [x] 40,000시간 기반 모델과 spk_stats 파일 오픈 소스화. 55 | - [x] 스트리밍 오디오 생성. 56 | - [x] DVAE 인코더와 제로 샷 추론 코드 오픈 소스화. 57 | - [ ] 다중 감정 제어 기능. 58 | - [ ] ChatTTS.cpp (`2noise` 조직 내의 새로운 저장소를 환영합니다.) 59 | 60 | ### 라이선스 61 | 62 | #### 코드 63 | 코드는 `AGPLv3+` 라이선스를 따릅니다. 64 | 65 | #### 모델 66 | 모델은 `CC BY-NC 4.0` 라이선스로 공개되었습니다. 이 모델은 교육 및 연구 목적으로만 사용되며, 상업적 또는 불법적 목적으로 사용되어서는 안 됩니다. 저자들은 정보의 정확성, 완전성, 신뢰성을 보장하지 않습니다. 이 저장소에서 사용된 정보와 데이터는 학술 및 연구 목적으로만 사용되며, 공개적으로 이용 가능한 출처에서 얻은 데이터입니다. 저자들은 데이터에 대한 소유권 또는 저작권을 주장하지 않습니다. 67 | 68 | ### 면책 조항 69 | 70 | ChatTTS는 강력한 텍스트-음성 변환 시스템입니다. 그렇기에 기술을 책임감 있고 윤리적으로 사용하는 것은 아주 중요합니다. ChatTTS의 악용을 방지하기 위해 40,000시간 모델의 훈련 중 소량의 고주파 노이즈를 추가하고, 오디오 품질을 최대한 압축하여 MP3 형식으로 제공했습니다. 또한, 우리는 내부적으로 탐지 모델을 훈련했으며, 추후 이를 오픈 소스화할 계획입니다. 71 | 72 | ### 연락처 73 | > GitHub 이슈/PR은 언제든지 환영합니다. 74 | 75 | #### 공식 문의 76 | 모델 및 로드맵에 대한 공식적인 문의는 **open-source@2noise.com**으로 연락해 주십시오. 77 | 78 | #### 온라인 채팅 79 | ##### 1. QQ Group (Chinese Social APP) 80 | - **Group 1**, 808364215 81 | - **Group 2**, 230696694 82 | - **Group 3**, 933639842 83 | - **Group 4**, 608667975 84 | 85 | ##### 2. Discord 서버 86 | [이곳](https://discord.gg/Ud5Jxgx5yD)를 클릭하여 참여하십시오. 87 | 88 | ## 시작하기 89 | ### 레포지토리 클론 90 | ```bash 91 | git clone https://github.com/2noise/ChatTTS 92 | cd ChatTTS 93 | ``` 94 | 95 | ### 의존성 설치 96 | #### 1. 직접 설치 97 | ```bash 98 | pip install --upgrade -r requirements.txt 99 | ``` 100 | 101 | #### 2. Conda에서 설치 102 | ```bash 103 | conda create -n chattts 104 | conda activate chattts 105 | pip install -r requirements.txt 106 | ``` 107 | 108 | #### 선택사항: vLLM 설치 (Linux 전용) 109 | ```bash 110 | pip install safetensors vllm==0.2.7 torchaudio 111 | ``` 112 | 113 | #### 권장되지 않는 선택사항: NVIDIA GPU 사용 시 TransformerEngine 설치 (Linux 전용) 114 | > [!Warning] 115 | > 설치하지 마십시오! 116 | > TransformerEngine의 적응 작업은 현재 개발 중이며, 아직 제대로 작동하지 않습니다. 117 | > 개발 목적으로만 설치하십시오. 자세한 내용은 #672 및 #676에서 확인할 수 있습니다. 118 | 119 | > [!Note] 120 | > 설치 과정은 매우 느립니다. 121 | 122 | ```bash 123 | pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable 124 | ``` 125 | 126 | #### 권장되지 않는 선택사항: FlashAttention-2 설치 (주로 NVIDIA GPU) 127 | > [!Warning] 128 | > 설치하지 마십시오! 129 | > 현재 FlashAttention-2는 [이 이슈](https://github.com/huggingface/transformers/issues/26990)에 따르면 생성 속도를 저하시킵니다. 130 | > 개발 목적으로만 설치하십시오. 
131 | 132 | > [!Note] 133 | > 지원되는 장치는 [Hugging Face 문서](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2)에서 확인할 수 있습니다. 134 | 135 | ```bash 136 | pip install flash-attn --no-build-isolation 137 | ``` 138 | 139 | ### 빠른 시작 140 | > 아래 명령어를 실행할 때 반드시 프로젝트 루트 디렉토리에서 실행하십시오. 141 | 142 | #### 1. WebUI 실행 143 | ```bash 144 | python examples/web/webui.py 145 | ``` 146 | 147 | #### 2. 커맨드 라인에서 추론 148 | > 오디오는 `./output_audio_n.mp3`에 저장됩니다. 149 | 150 | ```bash 151 | python examples/cmd/run.py "Your text 1." "Your text 2." 152 | ``` 153 | 154 | ## 설치 방법 155 | 156 | 1. PyPI에서 안정 버전 설치 157 | ```bash 158 | pip install ChatTTS 159 | ``` 160 | 161 | 2. GitHub에서 최신 버전 설치 162 | ```bash 163 | pip install git+https://github.com/2noise/ChatTTS 164 | ``` 165 | 166 | 3. 로컬 디렉토리에서 개발 모드로 설치 167 | ```bash 168 | pip install -e . 169 | ``` 170 | 171 | ### 기본 사용법 172 | 173 | ```python 174 | import ChatTTS 175 | import torch 176 | import torchaudio 177 | 178 | chat = ChatTTS.Chat() 179 | chat.load(compile=False) # 성능 향상을 위해 True로 설정 가능 180 | 181 | texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] 182 | 183 | wavs = chat.infer(texts) 184 | 185 | for i in range(len(wavs)): 186 | """ 187 | torchaudio의 버전에 따라 첫 번째 줄이 작동할 수 있고, 다른 버전에서는 두 번째 줄이 작동할 수 있습니다. 188 | """ 189 | try: 190 | torchaudio.save(f"basic_output{i}.wav", torch.from_numpy(wavs[i]).unsqueeze(0), 24000) 191 | except: 192 | torchaudio.save(f"basic_output{i}.wav", torch.from_numpy(wavs[i]), 24000) 193 | ``` 194 | 195 | ### Advanced Usage 196 | 197 | ```python 198 | ################################### 199 | # Sample a speaker from Gaussian. 200 | 201 | rand_spk = chat.sample_random_speaker() 202 | print(rand_spk) # save it for later timbre recovery 203 | 204 | params_infer_code = ChatTTS.Chat.InferCodeParams( 205 | spk_emb = rand_spk, # add sampled speaker 206 | temperature = .3, # using custom temperature 207 | top_P = 0.7, # top P decode 208 | top_K = 20, # top K decode 209 | ) 210 | 211 | ################################### 212 | # For sentence level manual control. 213 | 214 | # use oral_(0-9), laugh_(0-2), break_(0-7) 215 | # to generate special token in text to synthesize. 216 | params_refine_text = ChatTTS.Chat.RefineTextParams( 217 | prompt='[oral_2][laugh_0][break_6]', 218 | ) 219 | 220 | wavs = chat.infer( 221 | texts, 222 | params_refine_text=params_refine_text, 223 | params_infer_code=params_infer_code, 224 | ) 225 | 226 | ################################### 227 | # For word level manual control. 228 | 229 | text = 'What is [uv_break]your favorite english food?[laugh][lbreak]' 230 | wavs = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 231 | """ 232 | In some versions of torchaudio, the first line works but in other versions, so does the second line. 233 | """ 234 | try: 235 | torchaudio.save("word_level_output.wav", torch.from_numpy(wavs[0]).unsqueeze(0), 24000) 236 | except: 237 | torchaudio.save("word_level_output.wav", torch.from_numpy(wavs[0]), 24000) 238 | ``` 239 | 240 |
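The roadmap above marks streaming audio generation as done. When `stream=True` is passed, `chat.infer` returns a generator that yields partial waveforms per input text instead of a finished list (this is the pattern used in `examples/cmd/run.py` and `examples/web/funcs.py`). A rough sketch of consuming it, reusing `chat`, `texts`, and `params_infer_code` from the snippet above:

```python
import numpy as np

# stream=True turns infer() into a generator of partial waveforms.
wav_gen = chat.infer(
    texts,
    stream=True,
    params_infer_code=params_infer_code,
)

chunks = []
for partial in wav_gen:
    # each yielded item is indexed per input text; entries can be None or empty
    if partial[0] is not None and len(partial[0]) > 0:
        chunks.append(partial[0])

full_wav = np.concatenate(chunks)  # assembled audio for the first text
```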
241 |

Example: self introduction

242 | 243 | ```python 244 | inputs_en = """ 245 | chat T T S is a text to speech model designed for dialogue applications. 246 | [uv_break]it supports mixed language input [uv_break]and offers multi speaker 247 | capabilities with precise control over prosodic elements like 248 | [uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation. 249 | [uv_break]it delivers natural and expressive speech,[uv_break]so please 250 | [uv_break] use the project responsibly at your own risk.[uv_break] 251 | """.replace('\n', '') # English is still experimental. 252 | 253 | params_refine_text = ChatTTS.Chat.RefineTextParams( 254 | prompt='[oral_2][laugh_0][break_4]', 255 | ) 256 | 257 | audio_array_en = chat.infer(inputs_en, params_refine_text=params_refine_text) 258 | torchaudio.save("self_introduction_output.wav", torch.from_numpy(audio_array_en[0]), 24000) 259 | ``` 260 | 261 | 262 | 263 | 268 | 273 | 274 | 275 | 280 | 285 | 286 |
264 | 265 | **male speaker** 266 | 267 | 269 | 270 | **female speaker** 271 | 272 |
276 | 277 | [male speaker](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1) 278 | 279 | 281 | 282 | [female speaker](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd) 283 | 284 |
287 | 288 |
289 | 290 | ## FAQ 291 | 292 | #### 1. VRAM이 얼마나 필요한가요? 추론 속도는 어느 정도인가요? 293 | 30초 길이의 오디오 클립을 생성하려면 최소 4GB의 GPU 메모리가 필요합니다. 4090 GPU의 경우 초당 약 7개의 의미 토큰에 해당하는 오디오를 생성할 수 있습니다. 실시간 인자(RTF)는 약 0.3입니다. 294 | 295 | #### 2. 모델의 안정성은 불안정하며, 화자가 많은 경우 및 오디오 품질이 저하되는 이슈 존재. 296 | 297 | 이는 일반적으로 autoregressive 모델(bark 및 valle 등)에서 발생하는 불가피한 문제입니다. 현재로선 여러 번 샘플링하여 적절한 결과를 찾는 것이 최선입니다. 298 | 299 | #### 3. 웃음 뿐 아니라 다른 감정도 표현할 수 있나요? 300 | 301 | 현재 공개된 모델에서는 제어 가능한 토큰은 `[laugh]`, `[uv_break]`, `[lbreak]`입니다. 향후 버전의 모델에서는 추가적인 감정 제어 기능 포함하여 오픈 소스로 제공할 계획입니다. 302 | 303 | ## 감사의 인사 304 | - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS), [valle](https://arxiv.org/abs/2301.02111)는 autoregressive 방식의 시스템으로 뛰어난 TTS 성능을 보여주었습니다. 305 | - [fish-speech](https://github.com/fishaudio/fish-speech)는 LLM 모델링을 위한 오디오 토크나이저로서 GVQ의 능력을 보여주었습니다. 306 | - [vocos](https://github.com/gemelo-ai/vocos)는 사전 훈련된 vocoder로 사용되었습니다. 307 | 308 | ## 특별 감사 309 | - 초기 알고리즘 실험을 위한 [wlu-audio lab](https://audio.westlake.edu.cn/)에 감사의 말씀을 전합니다. 310 | 311 | ## 모든 기여자들의 노고에 감사드립니다 312 | [![contributors](https://contrib.rocks/image?repo=2noise/ChatTTS)](https://github.com/2noise/ChatTTS/graphs/contributors) 313 | 314 |
315 | 316 | ![counter](https://counter.seku.su/cmoe?name=chattts&theme=mbs) 317 | 318 |
319 | -------------------------------------------------------------------------------- /docs/ru/README.md: -------------------------------------------------------------------------------- 1 | # ChatTTS 2 | > [!NOTE] 3 | > Следующая информация может быть не самой последней, пожалуйста, смотрите английскую версию для актуальных данных. 4 | 5 | [![Huggingface](https://img.shields.io/badge/🤗%20-Models-yellow.svg?style=for-the-badge)](https://huggingface.co/2Noise/ChatTTS) 6 | 7 | [**English**](../../README.md) | [**简体中文**](../cn/README.md) | [**日本語**](../jp/README.md) | **Русский** | [**Español**](../es/README.md) | [**Français**](../fr/README.md) | [**한국어**](../kr/README.md) 8 | 9 | ChatTTS - это модель преобразования текста в речь, специально разработанная для диалоговых сценариев, таких как помощник LLM. Она поддерживает как английский, так и китайский языки. Наша модель обучена на более чем 100 000 часах английского и китайского языков. Открытая версия на **[HuggingFace](https://huggingface.co/2Noise/ChatTTS)** - это предварительно обученная модель с 40 000 часами без SFT. 10 | 11 | Для официальных запросов о модели и плане развития, пожалуйста, свяжитесь с нами по адресу **open-source@2noise.com**. Вы можете присоединиться к нашей группе QQ: 808364215 для обсуждения. Добавление вопросов на GitHub также приветствуется. 12 | 13 | --- 14 | ## Особенности 15 | 1. **Диалоговый TTS**: ChatTTS оптимизирован для задач, основанных на диалогах, что позволяет создавать натуральную и выразительную речь. Он поддерживает несколько говорящих, облегчая интерактивные беседы. 16 | 2. **Тонкий контроль**: Модель может предсказывать и контролировать тонкие просодические особенности, включая смех, паузы и вставные слова. 17 | 3. **Лучшая просодия**: ChatTTS превосходит большинство открытых моделей TTS с точки зрения просодии. Мы предоставляем предварительно обученные модели для поддержки дальнейших исследований и разработок. 18 | 19 | Для подробного описания модели вы можете обратиться к **[видео на Bilibili](https://www.bilibili.com/video/BV1zn4y1o7iV)** 20 | 21 | --- 22 | 23 | ## Отказ от ответственности 24 | 25 | Этот репозиторий предназначен только для академических целей. Он предназначен для образовательного и исследовательского использования и не должен использоваться в коммерческих или юридических целях. Авторы не гарантируют точность, полноту или надежность информации. Информация и данные, использованные в этом репозитории, предназначены только для академических и исследовательских целей. Данные получены из общедоступных источников, и авторы не заявляют о каких-либо правах собственности или авторских правах на данные. 26 | 27 | ChatTTS - мощная система преобразования текста в речь. Однако очень важно использовать эту технологию ответственно и этично. Чтобы ограничить использование ChatTTS, мы добавили небольшое количество высокочастотного шума во время обучения модели на 40 000 часов и сжали качество аудио как можно больше с помощью формата MP3, чтобы предотвратить возможное использование злоумышленниками в преступных целях. В то же время мы внутренне обучили модель обнаружения и планируем открыть ее в будущем. 28 | 29 | --- 30 | ## Использование 31 | 32 |

Базовое использование

33 | 34 | ```python 35 | import ChatTTS 36 | from IPython.display import Audio 37 | import torch 38 | import torchaudio 39 | 40 | chat = ChatTTS.Chat() 41 | chat.load(compile=False) # Установите значение True для лучшей производительности 42 | 43 | texts = ["ВВЕДИТЕ ВАШ ТЕКСТ ЗДЕСЬ",] 44 | 45 | wavs = chat.infer(texts) 46 | 47 | torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) 48 | ``` 49 |

Продвинутое использование

50 | 51 | ```python 52 | ################################### 53 | # Выборка говорящего из Гауссиана. 54 | 55 | rand_spk = chat.sample_random_speaker() 56 | print(rand_spk) # save it for later timbre recovery 57 | 58 | params_infer_code = { 59 | 'spk_emb': rand_spk, # добавить выбранного говорящего 60 | 'temperature': .3, # использовать пользовательскую температуру 61 | 'top_P': 0.7, # декодирование top P 62 | 'top_K': 20, # декодирование top K 63 | } 64 | 65 | ################################### 66 | # Для контроля на уровне предложений. 67 | 68 | # используйте oral_(0-9), laugh_(0-2), break_(0-7) 69 | # для генерации специального токена в тексте для синтеза. 70 | params_refine_text = { 71 | 'prompt': '[oral_2][laugh_0][break_6]' 72 | } 73 | 74 | wav = chat.infer(texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 75 | 76 | ################################### 77 | # Для контроля на уровне слов. 78 | text = 'Какая ваша любимая английская еда?[uv_break][laugh][lbreak]' 79 | wav = chat.infer(text, skip_refine_text=True, params_refine_text=params_refine_text, params_infer_code=params_infer_code) 80 | torchaudio.save("output2.wav", torch.from_numpy(wav[0]), 24000) 81 | ``` 82 | 83 |
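Instead of keeping the embedding itself, the speaker can also be made reproducible by fixing the random seed before sampling, which is how `examples/api/main.py` and the WebUI handle it. A minimal sketch (the seed value is arbitrary):

```python
import torch

# Fixing the seed makes sample_random_speaker() deterministic across runs.
torch.manual_seed(42)
rand_spk = chat.sample_random_speaker()

params_infer_code = {
    'spk_emb': rand_spk,
    'temperature': .3,
    'top_P': 0.7,
    'top_K': 20,
}
```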
84 |

Пример: самопрезентация

85 | 86 | ```python 87 | inputs_ru = """ 88 | ChatTTS - это модель преобразования текста в речь, разработанная для диалоговых приложений. 89 | [uv_break]Она поддерживает смешанный языковой ввод [uv_break]и предлагает возможности множественных говорящих 90 | с точным контролем над просодическими элементами [laugh]как [uv_break]смех[laugh], [uv_break]паузы, [uv_break]и интонацию. 91 | [uv_break]Она обеспечивает натуральную и выразительную речь,[uv_break]поэтому, пожалуйста, 92 | [uv_break] используйте проект ответственно и на свой страх и риск.[uv_break] 93 | """.replace('\n', '') # Русский язык все еще находится в экспериментальной стадии. 94 | 95 | params_refine_text = { 96 | 'prompt': '[oral_2][laugh_0][break_4]' 97 | } 98 | audio_array_ru = chat.infer(inputs_ru, params_refine_text=params_refine_text) 99 | torchaudio.save("output3.wav", torch.from_numpy(audio_array_ru[0]), 24000) 100 | ``` 101 | [мужской говорящий](https://github.com/2noise/ChatTTS/assets/130631963/e0f51251-db7f-4d39-a0e9-3e095bb65de1) 102 | 103 | [женский говорящий](https://github.com/2noise/ChatTTS/assets/130631963/f5dcdd01-1091-47c5-8241-c4f6aaaa8bbd) 104 |
105 | 106 | --- 107 | ## План развития 108 | - [x] Открыть исходный код базовой модели на 40 тысяч часов и файла spk_stats 109 | - [ ] Открыть исходный код кодировщика VQ и кода обучения Lora 110 | - [ ] Потоковая генерация аудио без уточнения текста* 111 | - [ ] Открыть исходный код версии на 40 тысяч часов с управлением множественными эмоциями 112 | - [ ] ChatTTS.cpp возможно? (PR или новый репозиторий приветствуются.) 113 | 114 | ---- 115 | ## Часто задаваемые вопросы 116 | 117 | ##### Сколько VRAM мне нужно? Как насчет скорости инференса? 118 | Для 30-секундного аудиоклипа требуется как минимум 4 ГБ памяти GPU. Для GPU 4090, он может генерировать аудио, соответствующее примерно 7 семантическим токенам в секунду. Фактор реального времени (RTF) составляет около 0.3. 119 | 120 | ##### Стабильность модели кажется недостаточно хорошей, возникают проблемы с множественными говорящими или плохим качеством аудио. 121 | 122 | Это проблема, которая обычно возникает с авторегрессивными моделями (для bark и valle). Это обычно трудно избежать. Можно попробовать несколько образцов, чтобы найти подходящий результат. 123 | 124 | ##### Помимо смеха, можем ли мы контролировать что-то еще? Можем ли мы контролировать другие эмоции? 125 | 126 | В текущей выпущенной модели единственными элементами управления на уровне токенов являются [laugh], [uv_break] и [lbreak]. В будущих версиях мы можем открыть модели с дополнительными возможностями контроля эмоций. 127 | 128 | --- 129 | ## Благодарности 130 | - [bark](https://github.com/suno-ai/bark), [XTTSv2](https://github.com/coqui-ai/TTS) и [valle](https://arxiv.org/abs/2301.02111) демонстрируют замечательный результат TTS с помощью системы авторегрессивного стиля. 131 | - [fish-speech](https://github.com/fishaudio/fish-speech) показывает возможности GVQ как аудио токенизатора для моделирования LLM. 132 | - [vocos](https://github.com/gemelo-ai/vocos), который используется в качестве предварительно обученного вокодера. 133 | 134 | --- 135 | ## Особая благодарность 136 | - [wlu-audio lab](https://audio.westlake.edu.cn/) за ранние эксперименты с алгоритмами. 137 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2noise/ChatTTS/1092c1ffcaa82f4bdef104c35aa5541227c3e1d7/examples/__init__.py -------------------------------------------------------------------------------- /examples/api/README.md: -------------------------------------------------------------------------------- 1 | # Generating voice with ChatTTS via API 2 | 3 | ## Install requirements 4 | 5 | Install `FastAPI` and `requests`: 6 | 7 | ``` 8 | pip install -r examples/api/requirements.txt 9 | ``` 10 | 11 | ## Run API server 12 | 13 | ``` 14 | fastapi dev examples/api/main.py --host 0.0.0.0 --port 8000 15 | ``` 16 | 17 | ## Run openAI_API server 18 | 19 | ``` 20 | fastapi dev examples/api/openai_api.py --host 0.0.0.0 --port 8000 21 | ``` 22 | ## Generate audio using requests 23 | 24 | ``` 25 | python examples/api/client.py 26 | ``` 27 | 28 | mp3 audio files will be saved to the `output` directory. 
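For a quick look at the request shape, here is a trimmed-down sketch of what `examples/api/client.py` sends; `client.py` remains the authoritative example and documents the full `params_refine_text` / `params_infer_code` payload:

```python
import io
import zipfile

import requests

body = {
    "text": ["Hello from the ChatTTS API."],
    "stream": False,
    "skip_refine_text": True,
    "use_decoder": True,
    # only a subset of fields is shown; see client.py for the complete payload
    "params_infer_code": {"prompt": "[speed_5]", "temperature": 0.3, "top_P": 0.1, "top_K": 20},
}

resp = requests.post("http://localhost:8000/generate_voice", json=body)
resp.raise_for_status()

# the server responds with a zip archive containing one mp3 per input text
with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
    zf.extractall("./output")
```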
29 | -------------------------------------------------------------------------------- /examples/api/client.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import zipfile 4 | from io import BytesIO 5 | 6 | import requests 7 | 8 | chattts_service_host = os.environ.get("CHATTTS_SERVICE_HOST", "localhost") 9 | chattts_service_port = os.environ.get("CHATTTS_SERVICE_PORT", "8000") 10 | 11 | CHATTTS_URL = f"http://{chattts_service_host}:{chattts_service_port}/generate_voice" 12 | 13 | 14 | # main infer params 15 | body = { 16 | "text": [ 17 | "四川美食确实以辣闻名,但也有不辣的选择。", 18 | "比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。", 19 | ], 20 | "stream": False, 21 | "lang": None, 22 | "skip_refine_text": True, 23 | "refine_text_only": False, 24 | "use_decoder": True, 25 | "audio_seed": 12345678, 26 | "text_seed": 87654321, 27 | "do_text_normalization": True, 28 | "do_homophone_replacement": False, 29 | } 30 | 31 | # refine text params 32 | params_refine_text = { 33 | "prompt": "", 34 | "top_P": 0.7, 35 | "top_K": 20, 36 | "temperature": 0.7, 37 | "repetition_penalty": 1, 38 | "max_new_token": 384, 39 | "min_new_token": 0, 40 | "show_tqdm": True, 41 | "ensure_non_empty": True, 42 | "stream_batch": 24, 43 | } 44 | body["params_refine_text"] = params_refine_text 45 | 46 | # infer code params 47 | params_infer_code = { 48 | "prompt": "[speed_5]", 49 | "top_P": 0.1, 50 | "top_K": 20, 51 | "temperature": 0.3, 52 | "repetition_penalty": 1.05, 53 | "max_new_token": 2048, 54 | "min_new_token": 0, 55 | "show_tqdm": True, 56 | "ensure_non_empty": True, 57 | "stream_batch": True, 58 | "spk_emb": None, 59 | } 60 | body["params_infer_code"] = params_infer_code 61 | 62 | 63 | try: 64 | response = requests.post(CHATTTS_URL, json=body) 65 | response.raise_for_status() 66 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref: 67 | # save files for each request in a different folder 68 | dt = datetime.datetime.now() 69 | ts = int(dt.timestamp()) 70 | tgt = f"./output/{ts}/" 71 | os.makedirs(tgt, 0o755) 72 | zip_ref.extractall(tgt) 73 | print("Extracted files into", tgt) 74 | 75 | except requests.exceptions.RequestException as e: 76 | print(f"Request Error: {e}") 77 | -------------------------------------------------------------------------------- /examples/api/main.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import sys 4 | import zipfile 5 | 6 | from fastapi import FastAPI 7 | from fastapi.responses import StreamingResponse 8 | 9 | 10 | if sys.platform == "darwin": 11 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 12 | 13 | now_dir = os.getcwd() 14 | sys.path.append(now_dir) 15 | 16 | from typing import Optional 17 | 18 | import ChatTTS 19 | 20 | from tools.audio import pcm_arr_to_mp3_view 21 | from tools.logger import get_logger 22 | import torch 23 | 24 | 25 | from pydantic import BaseModel 26 | from fastapi.exceptions import RequestValidationError 27 | from fastapi.responses import JSONResponse 28 | from tools.normalizer.en import normalizer_en_nemo_text 29 | from tools.normalizer.zh import normalizer_zh_tn 30 | 31 | logger = get_logger("Command") 32 | 33 | app = FastAPI() 34 | 35 | 36 | @app.on_event("startup") 37 | async def startup_event(): 38 | global chat 39 | 40 | chat = ChatTTS.Chat(get_logger("ChatTTS")) 41 | chat.normalizer.register("en", normalizer_en_nemo_text()) 42 | chat.normalizer.register("zh", normalizer_zh_tn()) 43 | 44 | logger.info("Initializing 
ChatTTS...") 45 | if chat.load(source="huggingface"): 46 | logger.info("Models loaded successfully.") 47 | else: 48 | logger.error("Models load failed.") 49 | sys.exit(1) 50 | 51 | 52 | @app.exception_handler(RequestValidationError) 53 | async def validation_exception_handler(request, exc: RequestValidationError): 54 | logger.error(f"Validation error: {exc.errors()}") 55 | return JSONResponse(status_code=422, content={"detail": exc.errors()}) 56 | 57 | 58 | class ChatTTSParams(BaseModel): 59 | text: list[str] 60 | stream: bool = False 61 | lang: Optional[str] = None 62 | skip_refine_text: bool = False 63 | refine_text_only: bool = False 64 | use_decoder: bool = True 65 | do_text_normalization: bool = True 66 | do_homophone_replacement: bool = False 67 | params_refine_text: ChatTTS.Chat.RefineTextParams = None 68 | params_infer_code: ChatTTS.Chat.InferCodeParams 69 | 70 | 71 | @app.post("/generate_voice") 72 | async def generate_voice(params: ChatTTSParams): 73 | logger.info("Text input: %s", str(params.text)) 74 | 75 | # audio seed 76 | if params.params_infer_code.manual_seed is not None: 77 | torch.manual_seed(params.params_infer_code.manual_seed) 78 | params.params_infer_code.spk_emb = chat.sample_random_speaker() 79 | 80 | # text seed for text refining 81 | if params.params_refine_text: 82 | text = chat.infer( 83 | text=params.text, skip_refine_text=False, refine_text_only=True 84 | ) 85 | logger.info(f"Refined text: {text}") 86 | else: 87 | # no text refining 88 | text = params.text 89 | 90 | logger.info("Use speaker:") 91 | logger.info(params.params_infer_code.spk_emb) 92 | 93 | logger.info("Start voice inference.") 94 | wavs = chat.infer( 95 | text=text, 96 | stream=params.stream, 97 | lang=params.lang, 98 | skip_refine_text=params.skip_refine_text, 99 | use_decoder=params.use_decoder, 100 | do_text_normalization=params.do_text_normalization, 101 | do_homophone_replacement=params.do_homophone_replacement, 102 | params_infer_code=params.params_infer_code, 103 | params_refine_text=params.params_refine_text, 104 | ) 105 | logger.info("Inference completed.") 106 | 107 | # zip all of the audio files together 108 | buf = io.BytesIO() 109 | with zipfile.ZipFile( 110 | buf, "a", compression=zipfile.ZIP_DEFLATED, allowZip64=False 111 | ) as f: 112 | for idx, wav in enumerate(wavs): 113 | f.writestr(f"{idx}.mp3", pcm_arr_to_mp3_view(wav)) 114 | logger.info("Audio generation successful.") 115 | buf.seek(0) 116 | 117 | response = StreamingResponse(buf, media_type="application/zip") 118 | response.headers["Content-Disposition"] = "attachment; filename=audio_files.zip" 119 | return response 120 | -------------------------------------------------------------------------------- /examples/api/postScript.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import zipfile 5 | from io import BytesIO 6 | 7 | import requests 8 | 9 | chattts_service_host = os.environ.get("CHATTTS_SERVICE_HOST", "127.0.0.1") 10 | chattts_service_port = os.environ.get("CHATTTS_SERVICE_PORT", "9900") 11 | 12 | CHATTTS_URL = f"http://{chattts_service_host}:{chattts_service_port}/generate_voice" 13 | 14 | 15 | def parse_arguments(): 16 | parser = argparse.ArgumentParser(description="HTTP client for ChatTTS service") 17 | parser.add_argument( 18 | "--text", type=str, nargs="+", required=True, help="Text to synthesize" 19 | ) 20 | parser.add_argument( 21 | "--audio_seed", type=int, required=True, help="Audio generation seed" 22 | ) 23 | 
parser.add_argument( 24 | "--text_seed", type=int, required=True, help="Text generation seed" 25 | ) 26 | parser.add_argument( 27 | "--stream", type=bool, default=False, help="Enable/disable streaming" 28 | ) 29 | parser.add_argument("--lang", type=str, default=None, help="Language code for text") 30 | parser.add_argument( 31 | "--skip_refine_text", type=bool, default=True, help="Skip text refinement" 32 | ) 33 | parser.add_argument( 34 | "--refine_text_only", type=bool, default=False, help="Only refine text" 35 | ) 36 | parser.add_argument( 37 | "--use_decoder", type=bool, default=True, help="Use decoder during inference" 38 | ) 39 | parser.add_argument( 40 | "--do_text_normalization", 41 | type=bool, 42 | default=True, 43 | help="Enable text normalization", 44 | ) 45 | parser.add_argument( 46 | "--do_homophone_replacement", 47 | type=bool, 48 | default=False, 49 | help="Enable homophone replacement", 50 | ) 51 | parser.add_argument( 52 | "--tgt", 53 | type=str, 54 | default="./output", 55 | help="Target directory to save output files", 56 | ) 57 | parser.add_argument( 58 | "--filename", 59 | type=str, 60 | default="test.mp3", 61 | help="Target directory to save output files", 62 | ) 63 | 64 | # Refinement text parameters 65 | parser.add_argument( 66 | "--refine_prompt", type=str, default="", help="Prompt for text refinement" 67 | ) 68 | parser.add_argument( 69 | "--refine_top_P", 70 | type=float, 71 | default=0.7, 72 | help="Top P value for text refinement", 73 | ) 74 | parser.add_argument( 75 | "--refine_top_K", type=int, default=20, help="Top K value for text refinement" 76 | ) 77 | parser.add_argument( 78 | "--refine_temperature", 79 | type=float, 80 | default=0.7, 81 | help="Temperature for text refinement", 82 | ) 83 | parser.add_argument( 84 | "--refine_repetition_penalty", 85 | type=float, 86 | default=1.0, 87 | help="Repetition penalty for text refinement", 88 | ) 89 | parser.add_argument( 90 | "--refine_max_new_token", 91 | type=int, 92 | default=384, 93 | help="Max new tokens for text refinement", 94 | ) 95 | parser.add_argument( 96 | "--refine_min_new_token", 97 | type=int, 98 | default=0, 99 | help="Min new tokens for text refinement", 100 | ) 101 | parser.add_argument( 102 | "--refine_show_tqdm", 103 | type=bool, 104 | default=True, 105 | help="Show progress bar for text refinement", 106 | ) 107 | parser.add_argument( 108 | "--refine_ensure_non_empty", 109 | type=bool, 110 | default=True, 111 | help="Ensure non-empty output", 112 | ) 113 | parser.add_argument( 114 | "--refine_stream_batch", 115 | type=int, 116 | default=24, 117 | help="Stream batch size for refinement", 118 | ) 119 | 120 | # Infer code parameters 121 | parser.add_argument( 122 | "--infer_prompt", type=str, default="[speed_5]", help="Prompt for inference" 123 | ) 124 | parser.add_argument( 125 | "--infer_top_P", type=float, default=0.1, help="Top P value for inference" 126 | ) 127 | parser.add_argument( 128 | "--infer_top_K", type=int, default=20, help="Top K value for inference" 129 | ) 130 | parser.add_argument( 131 | "--infer_temperature", type=float, default=0.3, help="Temperature for inference" 132 | ) 133 | parser.add_argument( 134 | "--infer_repetition_penalty", 135 | type=float, 136 | default=1.05, 137 | help="Repetition penalty for inference", 138 | ) 139 | parser.add_argument( 140 | "--infer_max_new_token", 141 | type=int, 142 | default=2048, 143 | help="Max new tokens for inference", 144 | ) 145 | parser.add_argument( 146 | "--infer_min_new_token", 147 | type=int, 148 | default=0, 149 | help="Min 
new tokens for inference", 150 | ) 151 | parser.add_argument( 152 | "--infer_show_tqdm", 153 | type=bool, 154 | default=True, 155 | help="Show progress bar for inference", 156 | ) 157 | parser.add_argument( 158 | "--infer_ensure_non_empty", 159 | type=bool, 160 | default=True, 161 | help="Ensure non-empty output", 162 | ) 163 | parser.add_argument( 164 | "--infer_stream_batch", 165 | type=bool, 166 | default=True, 167 | help="Stream batch for inference", 168 | ) 169 | parser.add_argument( 170 | "--infer_spk_emb", 171 | type=str, 172 | default=None, 173 | help="Speaker embedding for inference", 174 | ) 175 | 176 | return parser.parse_args() 177 | 178 | 179 | def main(): 180 | args = parse_arguments() 181 | 182 | # Main infer params 183 | body = { 184 | "text": args.text, 185 | "stream": args.stream, 186 | "lang": args.lang, 187 | "filename": args.filename, 188 | "skip_refine_text": args.skip_refine_text, 189 | "refine_text_only": args.refine_text_only, 190 | "use_decoder": args.use_decoder, 191 | "audio_seed": args.audio_seed, 192 | "text_seed": args.text_seed, 193 | "do_text_normalization": args.do_text_normalization, 194 | "do_homophone_replacement": args.do_homophone_replacement, 195 | } 196 | # Refinement text parameters 197 | params_refine_text = { 198 | "prompt": args.refine_prompt, 199 | "top_P": args.refine_top_P, 200 | "top_K": args.refine_top_K, 201 | "temperature": args.refine_temperature, 202 | "repetition_penalty": args.refine_repetition_penalty, 203 | "max_new_token": args.refine_max_new_token, 204 | "min_new_token": args.refine_min_new_token, 205 | "show_tqdm": args.refine_show_tqdm, 206 | "ensure_non_empty": args.refine_ensure_non_empty, 207 | "stream_batch": args.refine_stream_batch, 208 | } 209 | body["params_refine_text"] = params_refine_text 210 | 211 | # Infer code parameters 212 | params_infer_code = { 213 | "prompt": args.infer_prompt, 214 | "top_P": args.infer_top_P, 215 | "top_K": args.infer_top_K, 216 | "temperature": args.infer_temperature, 217 | "repetition_penalty": args.infer_repetition_penalty, 218 | "max_new_token": args.infer_max_new_token, 219 | "min_new_token": args.infer_min_new_token, 220 | "show_tqdm": args.infer_show_tqdm, 221 | "ensure_non_empty": args.infer_ensure_non_empty, 222 | "stream_batch": args.infer_stream_batch, 223 | "spk_emb": args.infer_spk_emb, 224 | } 225 | body["params_infer_code"] = params_infer_code 226 | 227 | try: 228 | response = requests.post(CHATTTS_URL, json=body) 229 | response.raise_for_status() 230 | with zipfile.ZipFile(BytesIO(response.content), "r") as zip_ref: 231 | tgt = args.tgt 232 | # filename=args.filename 233 | os.makedirs(tgt, exist_ok=True) 234 | zip_ref.extractall(tgt) 235 | print(f"Extracted files:{tgt}/{filename}") 236 | # print(tgt) 237 | except requests.exceptions.RequestException as e: 238 | print(f"Request Error: {e}") 239 | 240 | 241 | if __name__ == "__main__": 242 | main() 243 | -------------------------------------------------------------------------------- /examples/api/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | requests 3 | -------------------------------------------------------------------------------- /examples/cmd/run.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | from typing import Optional, List 10 | import argparse 
11 | 12 | import numpy as np 13 | 14 | import ChatTTS 15 | 16 | from tools.logger import get_logger 17 | from tools.audio import pcm_arr_to_mp3_view 18 | from tools.normalizer.en import normalizer_en_nemo_text 19 | from tools.normalizer.zh import normalizer_zh_tn 20 | 21 | logger = get_logger("Command") 22 | 23 | 24 | def save_mp3_file(wav, index): 25 | data = pcm_arr_to_mp3_view(wav) 26 | mp3_filename = f"output_audio_{index}.mp3" 27 | with open(mp3_filename, "wb") as f: 28 | f.write(data) 29 | logger.info(f"Audio saved to {mp3_filename}") 30 | 31 | 32 | def load_normalizer(chat: ChatTTS.Chat): 33 | # try to load normalizer 34 | try: 35 | chat.normalizer.register("en", normalizer_en_nemo_text()) 36 | except ValueError as e: 37 | logger.error(e) 38 | except BaseException: 39 | logger.warning("Package nemo_text_processing not found!") 40 | logger.warning( 41 | "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing", 42 | ) 43 | try: 44 | chat.normalizer.register("zh", normalizer_zh_tn()) 45 | except ValueError as e: 46 | logger.error(e) 47 | except BaseException: 48 | logger.warning("Package WeTextProcessing not found!") 49 | logger.warning( 50 | "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing", 51 | ) 52 | 53 | 54 | def main( 55 | texts: List[str], 56 | spk: Optional[str] = None, 57 | stream: bool = False, 58 | source: str = "local", 59 | custom_path: str = "", 60 | ): 61 | logger.info("Text input: %s", str(texts)) 62 | 63 | chat = ChatTTS.Chat(get_logger("ChatTTS")) 64 | logger.info("Initializing ChatTTS...") 65 | load_normalizer(chat) 66 | 67 | is_load = False 68 | if os.path.isdir(custom_path) and source == "custom": 69 | is_load = chat.load(source="custom", custom_path=custom_path) 70 | else: 71 | is_load = chat.load(source=source) 72 | 73 | if is_load: 74 | logger.info("Models loaded successfully.") 75 | else: 76 | logger.error("Models load failed.") 77 | sys.exit(1) 78 | 79 | if spk is None: 80 | spk = chat.sample_random_speaker() 81 | logger.info("Use speaker:") 82 | print(spk) 83 | 84 | logger.info("Start inference.") 85 | wavs = chat.infer( 86 | texts, 87 | stream, 88 | params_infer_code=ChatTTS.Chat.InferCodeParams( 89 | spk_emb=spk, 90 | ), 91 | ) 92 | logger.info("Inference completed.") 93 | # Save each generated wav file to a local file 94 | if stream: 95 | wavs_list = [] 96 | for index, wav in enumerate(wavs): 97 | if stream: 98 | for i, w in enumerate(wav): 99 | save_mp3_file(w, (i + 1) * 1000 + index) 100 | wavs_list.append(wav) 101 | else: 102 | save_mp3_file(wav, index) 103 | if stream: 104 | for index, wav in enumerate(np.concatenate(wavs_list, axis=1)): 105 | save_mp3_file(wav, index) 106 | logger.info("Audio generation successful.") 107 | 108 | 109 | if __name__ == "__main__": 110 | r""" 111 | python -m examples.cmd.run \ 112 | --source custom --custom_path ../../models/2Noise/ChatTTS 你好喲 ":)" 113 | """ 114 | logger.info("Starting ChatTTS commandline demo...") 115 | parser = argparse.ArgumentParser( 116 | description="ChatTTS Command", 117 | usage='[--spk xxx] [--stream] [--source ***] [--custom_path XXX] "Your text 1." 
" Your text 2."', 118 | ) 119 | parser.add_argument( 120 | "--spk", 121 | help="Speaker (empty to sample a random one)", 122 | type=Optional[str], 123 | default=None, 124 | ) 125 | parser.add_argument( 126 | "--stream", 127 | help="Use stream mode", 128 | action="store_true", 129 | ) 130 | parser.add_argument( 131 | "--source", 132 | help="source form [ huggingface(hf download), local(ckpt save to asset dir), custom(define) ]", 133 | type=str, 134 | default="local", 135 | ) 136 | parser.add_argument( 137 | "--custom_path", 138 | help="custom defined model path(include asset ckpt dir)", 139 | type=str, 140 | default="", 141 | ) 142 | parser.add_argument( 143 | "texts", 144 | help="Original text", 145 | default=["YOUR TEXT HERE"], 146 | nargs=argparse.REMAINDER, 147 | ) 148 | args = parser.parse_args() 149 | logger.info(args) 150 | main(args.texts, args.spk, args.stream, args.source, args.custom_path) 151 | logger.info("ChatTTS process finished.") 152 | -------------------------------------------------------------------------------- /examples/cmd/stream.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | 5 | from tools.audio import float_to_int16 6 | 7 | 8 | # 流式推理数据获取器,支持流式获取音频编码字节流 9 | class ChatStreamer: 10 | def __init__(self, base_block_size=8000): 11 | self.base_block_size = base_block_size 12 | 13 | # stream状态更新。数据量不足的stream,先存一段时间,直到拿到足够数据,监控小块数据情况 14 | @staticmethod 15 | def _update_stream(history_stream_wav, new_stream_wav, thre): 16 | if history_stream_wav is not None: 17 | result_stream = np.concatenate([history_stream_wav, new_stream_wav], axis=1) 18 | is_keep_next = result_stream.shape[0] * result_stream.shape[1] < thre 19 | if random.random() > 0.1: 20 | print( 21 | "update_stream", 22 | is_keep_next, 23 | [i.shape if i is not None else None for i in result_stream], 24 | ) 25 | else: 26 | result_stream = new_stream_wav 27 | is_keep_next = result_stream.shape[0] * result_stream.shape[1] < thre 28 | 29 | return result_stream, is_keep_next 30 | 31 | # 已推理batch数据保存 32 | @staticmethod 33 | def _accum(accum_wavs, stream_wav): 34 | if accum_wavs is None: 35 | accum_wavs = stream_wav 36 | else: 37 | accum_wavs = np.concatenate([accum_wavs, stream_wav], axis=1) 38 | return accum_wavs 39 | 40 | # batch stream数据格式转化 41 | @staticmethod 42 | def batch_stream_formatted(stream_wav, output_format="PCM16_byte"): 43 | if output_format in ("PCM16_byte", "PCM16"): 44 | format_data = float_to_int16(stream_wav) 45 | else: 46 | format_data = stream_wav 47 | return format_data 48 | 49 | # 数据格式转化 50 | @staticmethod 51 | def formatted(data, output_format="PCM16_byte"): 52 | if output_format == "PCM16_byte": 53 | format_data = data.astype(" 1e-6).sum() 83 | if n_valid_texts == 0: 84 | continue 85 | else: 86 | block_thre = n_valid_texts * self.base_block_size 87 | stream_wav, is_keep_next = ChatStreamer._update_stream( 88 | history_stream_wav, stream_wav, block_thre 89 | ) 90 | # 数据量不足,先保存状态 91 | if is_keep_next: 92 | history_stream_wav = stream_wav 93 | continue 94 | # 数据量足够,执行写入操作 95 | else: 96 | history_stream_wav = None 97 | stream_wav = ChatStreamer.batch_stream_formatted( 98 | stream_wav, output_format 99 | ) 100 | article_streamwavs = ChatStreamer._accum( 101 | article_streamwavs, stream_wav 102 | ) 103 | # 写入当前句子 104 | if ChatStreamer.checkvoice(stream_wav[curr_sentence_index]): 105 | for sub_wav in ChatStreamer._subgen( 106 | stream_wav[curr_sentence_index] 107 | ): 108 | if ChatStreamer.checkvoice(sub_wav): 109 | 
yield ChatStreamer.formatted(sub_wav, output_format) 110 | # 当前句子已写入完成,直接写下一个句子已经推理完成的部分 111 | elif curr_sentence_index < n_texts - 1: 112 | curr_sentence_index += 1 113 | print("add next sentence") 114 | finish_stream_wavs = article_streamwavs[curr_sentence_index] 115 | 116 | for sub_wav in ChatStreamer._subgen(finish_stream_wavs): 117 | if ChatStreamer.checkvoice(sub_wav): 118 | yield ChatStreamer.formatted(sub_wav, output_format) 119 | 120 | # streamchat遍历完毕,在外层把剩余结果写入 121 | else: 122 | break 123 | # 本轮剩余最后一点数据写入 124 | if is_keep_next: 125 | if len(list(filter(lambda x: x is not None, stream_wav))) > 0: 126 | stream_wav = ChatStreamer.batch_stream_formatted( 127 | stream_wav, output_format 128 | ) 129 | if ChatStreamer.checkvoice(stream_wav[curr_sentence_index]): 130 | 131 | for sub_wav in ChatStreamer._subgen( 132 | stream_wav[curr_sentence_index] 133 | ): 134 | if ChatStreamer.checkvoice(sub_wav): 135 | yield ChatStreamer.formatted(sub_wav, output_format) 136 | article_streamwavs = ChatStreamer._accum( 137 | article_streamwavs, stream_wav 138 | ) 139 | # 把已经完成推理的下几轮剩余数据写入 140 | for i_text in range(curr_sentence_index + 1, n_texts): 141 | finish_stream_wavs = article_streamwavs[i_text] 142 | 143 | for sub_wav in ChatStreamer._subgen(finish_stream_wavs): 144 | if ChatStreamer.checkvoice(sub_wav): 145 | yield ChatStreamer.formatted(sub_wav, output_format) 146 | 147 | # 流式播放接口 148 | def play(self, streamchat, wait=5): 149 | import pyaudio # please install it manually 150 | 151 | p = pyaudio.PyAudio() 152 | print(p.get_device_count()) 153 | # 设置音频流参数 154 | FORMAT = pyaudio.paInt16 # 16位深度 155 | CHANNELS = 1 # 单声道 156 | RATE = 24000 # 采样率 157 | CHUNK = 1024 # 每块音频数据大小 158 | 159 | # 打开输出流(扬声器) 160 | stream_out = p.open( 161 | format=FORMAT, 162 | channels=CHANNELS, 163 | rate=RATE, 164 | output=True, 165 | ) 166 | 167 | first_prefill_size = wait * RATE 168 | prefill_bytes = b"" 169 | meet = False 170 | for i in self.generate(streamchat, output_format="PCM16_byte"): 171 | if not meet: 172 | prefill_bytes += i 173 | if len(prefill_bytes) > first_prefill_size: 174 | meet = True 175 | stream_out.write(prefill_bytes) 176 | else: 177 | stream_out.write(i) 178 | if not meet: 179 | stream_out.write(prefill_bytes) 180 | 181 | stream_out.stop_stream() 182 | stream_out.close() 183 | 184 | 185 | if __name__ == "__main__": 186 | import ChatTTS 187 | 188 | # 加载 ChatTTS 189 | chat = ChatTTS.Chat() 190 | chat.load(compile=False) 191 | 192 | rand_spk = chat.sample_random_speaker() 193 | params_infer_code = ChatTTS.Chat.InferCodeParams( 194 | spk_emb=rand_spk, # add sampled speaker 195 | temperature=0.3, # using custom temperature 196 | top_P=0.7, # top P decode 197 | top_K=20, # top K decode 198 | ) 199 | 200 | # 获取ChatTTS 流式推理generator 201 | streamchat = chat.infer( 202 | [ 203 | "总结一下,AI Agent是大模型功能的扩展,让AI更接近于通用人工智能,也就是我们常说的AGI。", 204 | "你太聪明啦。", 205 | "举个例子,大模型可能可以写代码,但它不能独立完成一个完整的软件开发项目。这时候,AI Agent就根据大模型的智能,结合记忆和规划,一步步实现从需求分析到产品上线。", 206 | ], 207 | skip_refine_text=True, 208 | stream=True, 209 | params_infer_code=params_infer_code, 210 | ) 211 | # 先存放一部分,存的差不多了再播放,适合生成速度比较慢的cpu玩家使用 212 | ChatStreamer().play(streamchat, wait=5) 213 | -------------------------------------------------------------------------------- /examples/onnx/README.md: -------------------------------------------------------------------------------- 1 | # Export onnx or JIT models for deployment 2 | 3 | ## Run `pip install onnx -U`. 4 | 5 | ## Export GPT 6 | 7 | 3. 
Run `python examples/onnx/exporter.py --gpt` 8 | 9 | 10 | ## Export other models 11 | Run `python examples/onnx/exporter.py --decoder --vocos` 12 | 13 | ## Reference 14 | [Run LLMs on Sophon TPU](https://github.com/sophgo/LLM-TPU) -------------------------------------------------------------------------------- /examples/onnx/gpt.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.utils.parametrizations import weight_norm 7 | 8 | from modeling_llama import LlamaModel, LlamaConfig 9 | 10 | 11 | class GPT(nn.Module): 12 | def __init__( 13 | self, 14 | gpt_config: dict, 15 | num_audio_tokens: int = 626, 16 | num_text_tokens: int = 21178, 17 | num_vq=4, 18 | use_flash_attn=False, 19 | device=torch.device("cpu"), 20 | logger=logging.getLogger(__name__), 21 | ): 22 | super().__init__() 23 | 24 | self.logger = logger 25 | 26 | self.device = device 27 | self.device_gpt = device if "mps" not in str(device) else torch.device("cpu") 28 | 29 | self.num_vq = num_vq 30 | self.num_audio_tokens = num_audio_tokens 31 | 32 | self.use_flash_attn = use_flash_attn 33 | 34 | self.gpt, self.llama_config = self._build_llama(gpt_config, self.device_gpt) 35 | self.is_te_llama = False 36 | self.model_dim = int(self.gpt.config.hidden_size) 37 | self.emb_code = nn.ModuleList( 38 | [ 39 | nn.Embedding( 40 | num_audio_tokens, 41 | self.model_dim, 42 | device=self.device_gpt, 43 | ) 44 | for _ in range(num_vq) 45 | ], 46 | ) 47 | self.emb_text = nn.Embedding( 48 | num_text_tokens, self.model_dim, device=self.device_gpt 49 | ) 50 | 51 | self.head_text = weight_norm( 52 | nn.Linear( 53 | self.model_dim, 54 | num_text_tokens, 55 | bias=False, 56 | device=device, 57 | ), 58 | name="weight", 59 | ) 60 | self.head_code = nn.ModuleList( 61 | [ 62 | weight_norm( 63 | nn.Linear( 64 | self.model_dim, 65 | num_audio_tokens, 66 | bias=False, 67 | device=device, 68 | ), 69 | name="weight", 70 | ) 71 | for _ in range(self.num_vq) 72 | ], 73 | ) 74 | 75 | def from_pretrained(self, file_path: str): 76 | self.load_state_dict( 77 | torch.load(file_path, weights_only=True, mmap=True), strict=False 78 | ) 79 | 80 | def _build_llama( 81 | self, 82 | config: dict, 83 | device: torch.device, 84 | ) -> Tuple[LlamaModel, LlamaConfig]: 85 | 86 | llama_config = LlamaConfig(**config) 87 | model = LlamaModel(llama_config) 88 | del model.embed_tokens 89 | return model.to(device), llama_config 90 | -------------------------------------------------------------------------------- /examples/web/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2noise/ChatTTS/1092c1ffcaa82f4bdef104c35aa5541227c3e1d7/examples/web/__init__.py -------------------------------------------------------------------------------- /examples/web/ex.py: -------------------------------------------------------------------------------- 1 | ex = [ 2 | [ 3 | "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。", 4 | 0.3, 5 | 0.7, 6 | 20, 7 | 2, 8 | 42, 9 | True, 10 | ], 11 | [ 12 | "What is your favorite english food?", 13 | 0.5, 14 | 0.5, 15 | 10, 16 | 245, 17 | 531, 18 | True, 19 | ], 20 | [ 21 | "chat T T S is a text to speech model designed for dialogue applications. 
[uv_break]it supports mixed language input [uv_break]and offers multi speaker capabilities with precise control over prosodic elements like [uv_break]laughter[uv_break][laugh], [uv_break]pauses, [uv_break]and intonation. [uv_break]it delivers natural and expressive speech,[uv_break]so please[uv_break] use the project responsibly at your own risk.[uv_break]", 22 | 0.8, 23 | 0.4, 24 | 7, 25 | 70, 26 | 165, 27 | False, 28 | ], 29 | ] 30 | -------------------------------------------------------------------------------- /examples/web/funcs.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | from time import sleep 4 | 5 | import gradio as gr 6 | 7 | import sys 8 | 9 | sys.path.append("..") 10 | sys.path.append("../..") 11 | from tools.audio import float_to_int16, has_ffmpeg_installed, load_audio 12 | from tools.logger import get_logger 13 | 14 | logger = get_logger(" WebUI ") 15 | 16 | from tools.seeder import TorchSeedContext 17 | from tools.normalizer import normalizer_en_nemo_text, normalizer_zh_tn 18 | 19 | import ChatTTS 20 | 21 | chat = ChatTTS.Chat(get_logger("ChatTTS")) 22 | 23 | custom_path: Optional[str] = None 24 | 25 | has_interrupted = False 26 | is_in_generate = False 27 | 28 | seed_min = 1 29 | seed_max = 4294967295 30 | 31 | use_mp3 = has_ffmpeg_installed() 32 | if not use_mp3: 33 | logger.warning("no ffmpeg installed, use wav file output") 34 | 35 | # 音色选项:用于预置合适的音色 36 | voices = { 37 | "Default": {"seed": 2}, 38 | "Timbre1": {"seed": 1111}, 39 | "Timbre2": {"seed": 2222}, 40 | "Timbre3": {"seed": 3333}, 41 | "Timbre4": {"seed": 4444}, 42 | "Timbre5": {"seed": 5555}, 43 | "Timbre6": {"seed": 6666}, 44 | "Timbre7": {"seed": 7777}, 45 | "Timbre8": {"seed": 8888}, 46 | "Timbre9": {"seed": 9999}, 47 | } 48 | 49 | 50 | def generate_seed(): 51 | return gr.update(value=random.randint(seed_min, seed_max)) 52 | 53 | 54 | # 返回选择音色对应的seed 55 | def on_voice_change(vocie_selection): 56 | return voices.get(vocie_selection)["seed"] 57 | 58 | 59 | def on_audio_seed_change(audio_seed_input): 60 | with TorchSeedContext(audio_seed_input): 61 | rand_spk = chat.sample_random_speaker() 62 | return rand_spk 63 | 64 | 65 | def load_chat(cust_path: Optional[str], coef: Optional[str]) -> bool: 66 | if cust_path == None: 67 | ret = chat.load(coef=coef) 68 | else: 69 | logger.info("local model path: %s", cust_path) 70 | ret = chat.load("custom", custom_path=cust_path, coef=coef) 71 | global custom_path 72 | custom_path = cust_path 73 | if ret: 74 | try: 75 | chat.normalizer.register("en", normalizer_en_nemo_text()) 76 | except ValueError as e: 77 | logger.error(e) 78 | except: 79 | logger.warning("Package nemo_text_processing not found!") 80 | logger.warning( 81 | "Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing", 82 | ) 83 | try: 84 | chat.normalizer.register("zh", normalizer_zh_tn()) 85 | except ValueError as e: 86 | logger.error(e) 87 | except: 88 | logger.warning("Package WeTextProcessing not found!") 89 | logger.warning( 90 | "Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing", 91 | ) 92 | return ret 93 | 94 | 95 | def reload_chat(coef: Optional[str]) -> str: 96 | global is_in_generate 97 | 98 | if is_in_generate: 99 | gr.Warning("Cannot reload when generating!") 100 | return coef 101 | 102 | chat.unload() 103 | gr.Info("Model unloaded.") 104 | if len(coef) != 230: 105 | gr.Warning("Ingore invalid DVAE coefficient.") 106 | coef = None 107 | try: 108 | 
global custom_path 109 | ret = load_chat(custom_path, coef) 110 | except Exception as e: 111 | raise gr.Error(str(e)) 112 | if not ret: 113 | raise gr.Error("Unable to load model.") 114 | gr.Info("Reload succeess.") 115 | return chat.coef 116 | 117 | 118 | def on_upload_sample_audio(sample_audio_input: Optional[str]) -> str: 119 | if sample_audio_input is None: 120 | return "" 121 | sample_audio = load_audio(sample_audio_input, 24000) 122 | spk_smp = chat.sample_audio_speaker(sample_audio) 123 | del sample_audio 124 | return spk_smp 125 | 126 | 127 | def _set_generate_buttons(generate_button, interrupt_button, is_reset=False): 128 | return gr.update( 129 | value=generate_button, visible=is_reset, interactive=is_reset 130 | ), gr.update(value=interrupt_button, visible=not is_reset, interactive=not is_reset) 131 | 132 | 133 | def refine_text( 134 | text, 135 | text_seed_input, 136 | refine_text_flag, 137 | temperature, 138 | top_P, 139 | top_K, 140 | split_batch, 141 | ): 142 | global chat 143 | 144 | if not refine_text_flag: 145 | sleep(1) # to skip fast answer of loading mark 146 | return text 147 | 148 | text = chat.infer( 149 | text, 150 | skip_refine_text=False, 151 | refine_text_only=True, 152 | params_refine_text=ChatTTS.Chat.RefineTextParams( 153 | temperature=temperature, 154 | top_P=top_P, 155 | top_K=top_K, 156 | manual_seed=text_seed_input, 157 | ), 158 | split_text=split_batch > 0, 159 | ) 160 | 161 | return text[0] if isinstance(text, list) else text 162 | 163 | 164 | def generate_audio( 165 | text, 166 | temperature, 167 | top_P, 168 | top_K, 169 | spk_emb_text: str, 170 | stream, 171 | audio_seed_input, 172 | sample_text_input, 173 | sample_audio_code_input, 174 | split_batch, 175 | ): 176 | global chat, has_interrupted 177 | 178 | if not text or has_interrupted or not spk_emb_text.startswith("蘁淰"): 179 | return None 180 | 181 | params_infer_code = ChatTTS.Chat.InferCodeParams( 182 | spk_emb=spk_emb_text, 183 | temperature=temperature, 184 | top_P=top_P, 185 | top_K=top_K, 186 | manual_seed=audio_seed_input, 187 | ) 188 | 189 | if sample_text_input and sample_audio_code_input: 190 | params_infer_code.txt_smp = sample_text_input 191 | params_infer_code.spk_smp = sample_audio_code_input 192 | params_infer_code.spk_emb = None 193 | 194 | wav = chat.infer( 195 | text, 196 | skip_refine_text=True, 197 | params_infer_code=params_infer_code, 198 | stream=stream, 199 | split_text=split_batch > 0, 200 | max_split_batch=split_batch, 201 | ) 202 | if stream: 203 | for gen in wav: 204 | audio = gen[0] 205 | if audio is not None and len(audio) > 0: 206 | yield 24000, float_to_int16(audio).T 207 | del audio 208 | else: 209 | yield 24000, float_to_int16(wav[0]).T 210 | 211 | 212 | def interrupt_generate(): 213 | global chat, has_interrupted 214 | 215 | has_interrupted = True 216 | chat.interrupt() 217 | 218 | 219 | def set_buttons_before_generate(generate_button, interrupt_button): 220 | global has_interrupted, is_in_generate 221 | 222 | has_interrupted = False 223 | is_in_generate = True 224 | 225 | return _set_generate_buttons( 226 | generate_button, 227 | interrupt_button, 228 | ) 229 | 230 | 231 | def set_buttons_after_generate(generate_button, interrupt_button, audio_output): 232 | global has_interrupted, is_in_generate 233 | 234 | is_in_generate = False 235 | 236 | return _set_generate_buttons( 237 | generate_button, 238 | interrupt_button, 239 | audio_output is not None or has_interrupted, 240 | ) 241 | -------------------------------------------------------------------------------- 
/examples/web/webui.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | import argparse 10 | 11 | import gradio as gr 12 | 13 | from funcs import * 14 | from ex import ex 15 | 16 | 17 | def main(): 18 | 19 | with gr.Blocks() as demo: 20 | gr.Markdown("# ChatTTS WebUI") 21 | gr.Markdown("- **GitHub Repo**: https://github.com/2noise/ChatTTS") 22 | gr.Markdown("- **HuggingFace Repo**: https://huggingface.co/2Noise/ChatTTS") 23 | 24 | with gr.Row(): 25 | with gr.Column(scale=2): 26 | text_input = gr.Textbox( 27 | label="Input Text", 28 | lines=4, 29 | max_lines=4, 30 | placeholder="Please Input Text...", 31 | value=ex[0][0], 32 | interactive=True, 33 | ) 34 | sample_text_input = gr.Textbox( 35 | label="Sample Text", 36 | lines=4, 37 | max_lines=4, 38 | placeholder="If Sample Audio and Sample Text are available, the Speaker Embedding will be disabled.", 39 | interactive=True, 40 | ) 41 | with gr.Column(): 42 | with gr.Tab(label="Sample Audio"): 43 | sample_audio_input = gr.Audio( 44 | value=None, 45 | type="filepath", 46 | interactive=True, 47 | show_label=False, 48 | waveform_options=gr.WaveformOptions( 49 | sample_rate=24000, 50 | ), 51 | scale=1, 52 | ) 53 | with gr.Tab(label="Sample Audio Code"): 54 | sample_audio_code_input = gr.Textbox( 55 | lines=12, 56 | max_lines=12, 57 | show_label=False, 58 | placeholder="Paste the Code copied before after uploading Sample Audio.", 59 | interactive=True, 60 | ) 61 | 62 | with gr.Row(): 63 | refine_text_checkbox = gr.Checkbox( 64 | label="Refine text", value=ex[0][6], interactive=True 65 | ) 66 | temperature_slider = gr.Slider( 67 | minimum=0.00001, 68 | maximum=1.0, 69 | step=0.00001, 70 | value=ex[0][1], 71 | label="Audio Temperature", 72 | interactive=True, 73 | ) 74 | top_p_slider = gr.Slider( 75 | minimum=0.1, 76 | maximum=0.9, 77 | step=0.05, 78 | value=ex[0][2], 79 | label="top_P", 80 | interactive=True, 81 | ) 82 | top_k_slider = gr.Slider( 83 | minimum=1, 84 | maximum=20, 85 | step=1, 86 | value=ex[0][3], 87 | label="top_K", 88 | interactive=True, 89 | ) 90 | 91 | with gr.Row(): 92 | voice_selection = gr.Dropdown( 93 | label="Timbre", 94 | choices=voices.keys(), 95 | value="Default", 96 | interactive=True, 97 | ) 98 | audio_seed_input = gr.Number( 99 | value=ex[0][4], 100 | label="Audio Seed", 101 | interactive=True, 102 | minimum=seed_min, 103 | maximum=seed_max, 104 | ) 105 | generate_audio_seed = gr.Button("\U0001f3b2", interactive=True) 106 | text_seed_input = gr.Number( 107 | value=ex[0][5], 108 | label="Text Seed", 109 | interactive=True, 110 | minimum=seed_min, 111 | maximum=seed_max, 112 | ) 113 | generate_text_seed = gr.Button("\U0001f3b2", interactive=True) 114 | 115 | with gr.Row(): 116 | spk_emb_text = gr.Textbox( 117 | label="Speaker Embedding", 118 | max_lines=3, 119 | show_copy_button=True, 120 | interactive=True, 121 | scale=2, 122 | ) 123 | dvae_coef_text = gr.Textbox( 124 | label="DVAE Coefficient", 125 | max_lines=3, 126 | show_copy_button=True, 127 | interactive=True, 128 | scale=2, 129 | ) 130 | reload_chat_button = gr.Button("Reload", scale=1, interactive=True) 131 | 132 | with gr.Row(): 133 | auto_play_checkbox = gr.Checkbox( 134 | label="Auto Play", value=False, scale=1, interactive=True 135 | ) 136 | stream_mode_checkbox = gr.Checkbox( 137 | label="Stream Mode", 138 | value=False, 139 | scale=1, 140 | 
interactive=True, 141 | ) 142 | split_batch_slider = gr.Slider( 143 | minimum=0, 144 | maximum=100, 145 | step=1, 146 | value=4, 147 | label="Split Batch", 148 | interactive=True, 149 | ) 150 | generate_button = gr.Button( 151 | "Generate", scale=2, variant="primary", interactive=True 152 | ) 153 | interrupt_button = gr.Button( 154 | "Interrupt", 155 | scale=2, 156 | variant="stop", 157 | visible=False, 158 | interactive=False, 159 | ) 160 | 161 | text_output = gr.Textbox( 162 | label="Output Text", 163 | interactive=False, 164 | show_copy_button=True, 165 | ) 166 | 167 | sample_audio_input.change( 168 | fn=on_upload_sample_audio, 169 | inputs=sample_audio_input, 170 | outputs=sample_audio_code_input, 171 | ).then(fn=lambda: gr.Info("Sample Audio Code generated in the other tab.")) 172 | 173 | # Use Gradio's callback mechanism to update the numeric seed input 174 | voice_selection.change( 175 | fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input 176 | ) 177 | 178 | generate_audio_seed.click(generate_seed, outputs=audio_seed_input) 179 | 180 | generate_text_seed.click(generate_seed, outputs=text_seed_input) 181 | 182 | audio_seed_input.change( 183 | on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text 184 | ) 185 | 186 | reload_chat_button.click( 187 | reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text 188 | ) 189 | 190 | interrupt_button.click(interrupt_generate) 191 | 192 | @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox]) 193 | def make_audio(autoplay, stream): 194 | audio_output = gr.Audio( 195 | label="Output Audio", 196 | value=None, 197 | format="mp3" if use_mp3 and not stream else "wav", 198 | autoplay=autoplay, 199 | streaming=stream, 200 | interactive=False, 201 | show_label=True, 202 | waveform_options=gr.WaveformOptions( 203 | sample_rate=24000, 204 | ), 205 | ) 206 | generate_button.click( 207 | fn=set_buttons_before_generate, 208 | inputs=[generate_button, interrupt_button], 209 | outputs=[generate_button, interrupt_button], 210 | ).then( 211 | refine_text, 212 | inputs=[ 213 | text_input, 214 | text_seed_input, 215 | refine_text_checkbox, 216 | temperature_slider, 217 | top_p_slider, 218 | top_k_slider, 219 | split_batch_slider, 220 | ], 221 | outputs=text_output, 222 | ).then( 223 | generate_audio, 224 | inputs=[ 225 | text_output, 226 | temperature_slider, 227 | top_p_slider, 228 | top_k_slider, 229 | spk_emb_text, 230 | stream_mode_checkbox, 231 | audio_seed_input, 232 | sample_text_input, 233 | sample_audio_code_input, 234 | split_batch_slider, 235 | ], 236 | outputs=audio_output, 237 | ).then( 238 | fn=set_buttons_after_generate, 239 | inputs=[generate_button, interrupt_button, audio_output], 240 | outputs=[generate_button, interrupt_button], 241 | ) 242 | 243 | gr.Examples( 244 | examples=ex, 245 | inputs=[ 246 | text_input, 247 | temperature_slider, 248 | top_p_slider, 249 | top_k_slider, 250 | audio_seed_input, 251 | text_seed_input, 252 | refine_text_checkbox, 253 | ], 254 | ) 255 | 256 | parser = argparse.ArgumentParser(description="ChatTTS demo Launch") 257 | parser.add_argument( 258 | "--server_name", type=str, default="0.0.0.0", help="server name" 259 | ) 260 | parser.add_argument("--server_port", type=int, default=8080, help="server port") 261 | parser.add_argument("--root_path", type=str, help="root path") 262 | parser.add_argument("--custom_path", type=str, help="custom model path") 263 | parser.add_argument("--coef", type=str, help="custom dvae coefficient") 264 | args = parser.parse_args() 265 | 266 | logger.info("loading ChatTTS model...") 267 | 268
| if load_chat(args.custom_path, args.coef): 269 | logger.info("Models loaded successfully.") 270 | else: 271 | logger.error("Models load failed.") 272 | sys.exit(1) 273 | 274 | spk_emb_text.value = on_audio_seed_change(audio_seed_input.value) 275 | dvae_coef_text.value = chat.coef 276 | 277 | demo.launch( 278 | server_name=args.server_name, 279 | server_port=args.server_port, 280 | root_path=args.root_path, 281 | inbrowser=True, 282 | show_api=False, 283 | ) 284 | 285 | 286 | if __name__ == "__main__": 287 | main() 288 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<3.0.0 2 | numba 3 | torch>=2.1.0 4 | torchaudio 5 | tqdm 6 | vector_quantize_pytorch 7 | transformers>=4.41.1 8 | vocos 9 | IPython 10 | gradio 11 | pybase16384 12 | pynini==2.1.5; sys_platform == 'linux' 13 | WeTextProcessing; sys_platform == 'linux' 14 | nemo_text_processing; sys_platform == 'linux' 15 | av 16 | pydub 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup, find_packages 3 | 4 | version = "v0.0.0" 5 | 6 | setup( 7 | name="chattts", 8 | version=os.environ.get("CHTTS_VER", version).lstrip("v"), 9 | description="A generative speech model for daily dialogue", 10 | long_description=open("README.md", encoding="utf8").read(), 11 | long_description_content_type="text/markdown", 12 | author="2noise", 13 | author_email="open-source@2noise.com", 14 | maintainer="fumiama", 15 | url="https://github.com/2noise/ChatTTS", 16 | packages=find_packages(include=["ChatTTS", "ChatTTS.*"]), 17 | package_data={ 18 | "ChatTTS.res": ["homophones_map.json", "sha256_map.json"], 19 | }, 20 | license="AGPLv3+", 21 | install_requires=[ 22 | "numba", 23 | "numpy<3.0.0", 24 | "pybase16384", 25 | "torch>=2.1.0", 26 | "torchaudio", 27 | "tqdm", 28 | "transformers>=4.41.1", 29 | "vector_quantize_pytorch", 30 | "vocos", 31 | ], 32 | platforms="any", 33 | classifiers=[ 34 | "Programming Language :: Python :: 3", 35 | "Operating System :: OS Independent", 36 | "License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)", 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /tests/#511.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | import logging 10 | 11 | import ChatTTS 12 | 13 | from tools.logger import get_logger 14 | 15 | logger = get_logger("Test", lv=logging.WARN) 16 | 17 | chat = ChatTTS.Chat(logger) 18 | chat.load(compile=False, source="huggingface") # Set to True for better performance 19 | 20 | texts = [ 21 | "的 话 语 音 太 短 了 会 造 成 生 成 音 频 错 误 , 这 是 占 位 占 位 , 老 大 爷 觉 得 车 夫 的 想 法 很 有 道 理 [uv_break]", 22 | "的 话 评 分 只 是 衡 量 音 色 的 稳 定 性 , 不 代 表 音 色 的 好 坏 , 可 以 根 据 自 己 的 需 求 选 择 [uv_break] 合 适 的 音 色", 23 | "然 后 举 个 简 单 的 例 子 , 如 果 一 个 [uv_break] 沙 哑 且 结 巴 的 音 色 一 直 很 稳 定 , 那 么 它 的 评 分 就 会 很 高 。", 24 | "语 音 太 短 了 会 造 成 生 成 音 频 错 误 , 这 是 占 位 [uv_break] 占 位 。 我 使 用 seed id 去 生 成 音 频 , 但 是 生 成 的 音 频 不 稳 定", 25 | "在d id 只 是 一 个 参 考 id [uv_break] 不 同 的 环 境 下 音 色 不 一 定 一 致 。 还 是 推 荐 使 用 。 pt 文 件 载 入 音 色", 26 | "的 话 语 音 太 短 了 会 造 成 生 成 音 频 错 误 , 这 是 占 位 占 位 。 音 色 标 
的 男 女 [uv_break] 准 确 吗", 27 | ", 当 前 第 一 批 测 试 的 音 色 有 两 千 条 [uv_break] , 根 据 声 纹 相 似 性 简 单 打 标 , 准 确 度 不 高 , 特 别 是 特 征 一 项", 28 | "语 音 太 短 了 会 造 成 生 成 音 频 错 误 , 这 是 占 位 占 位 。 仅 供 参 考 。 如 果 大 家 有 更 好 的 标 注 方 法 , 欢 迎 pr [uv_break] 。", 29 | ] 30 | 31 | params_infer_code = ChatTTS.Chat.InferCodeParams( 32 | spk_emb=chat.sample_random_speaker(), 33 | temperature=0.3, 34 | top_P=0.005, 35 | top_K=1, 36 | show_tqdm=False, 37 | ) 38 | 39 | fail = False 40 | 41 | wavs = chat.infer( 42 | texts, 43 | skip_refine_text=True, 44 | split_text=False, 45 | params_infer_code=params_infer_code, 46 | ) 47 | 48 | for k, wav in enumerate(wavs): 49 | if wav is None: 50 | logger.warning("index", k, "is None") 51 | fail = True 52 | 53 | if fail: 54 | import sys 55 | 56 | sys.exit(1) 57 | -------------------------------------------------------------------------------- /tests/#588.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | import logging 10 | import re 11 | 12 | import ChatTTS 13 | 14 | from tools.logger import get_logger 15 | 16 | logger = get_logger("Test", lv=logging.WARN) 17 | 18 | chat = ChatTTS.Chat(logger) 19 | chat.load(compile=False, source="huggingface") # Set to True for better performance 20 | 21 | texts = [ 22 | "总结一下,AI Agent是大模型功能的扩展,让AI更接近于通用人工智能,也就是我们常说的AGI。", 23 | "你真是太聪明啦。", 24 | ] 25 | 26 | fail = False 27 | 28 | refined = chat.infer( 29 | texts, 30 | refine_text_only=True, 31 | stream=False, 32 | split_text=False, 33 | params_refine_text=ChatTTS.Chat.RefineTextParams(show_tqdm=False), 34 | ) 35 | 36 | trimre = re.compile("\\[[\w_]+\\]") 37 | 38 | 39 | def trim_tags(txt: str) -> str: 40 | global trimre 41 | return trimre.sub("", txt) 42 | 43 | 44 | for i, t in enumerate(refined): 45 | if len(trim_tags(t)) > 4 * len(texts[i]): 46 | fail = True 47 | logger.warning("in: %s, out: %s", texts[i], t) 48 | 49 | if fail: 50 | import sys 51 | 52 | sys.exit(1) 53 | -------------------------------------------------------------------------------- /tests/#655.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if sys.platform == "darwin": 4 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 5 | 6 | now_dir = os.getcwd() 7 | sys.path.append(now_dir) 8 | 9 | import logging 10 | 11 | import torch 12 | 13 | import ChatTTS 14 | 15 | from tools.logger import get_logger 16 | from tools.normalizer import normalizer_en_nemo_text 17 | 18 | logger = get_logger("Test", lv=logging.WARN) 19 | 20 | chat = ChatTTS.Chat(logger) 21 | chat.load(compile=False, source="huggingface") # Set to True for better performance 22 | try: 23 | chat.normalizer.register("en", normalizer_en_nemo_text()) 24 | except: 25 | logger.warning("Package nemo_text_processing not found!") 26 | 27 | rand_spk = chat.sample_random_speaker() 28 | 29 | 30 | text = ["What is [uv_break]your favorite english food?[laugh][lbreak]"] 31 | 32 | fail = False 33 | 34 | refined_text = chat.infer( 35 | text, 36 | refine_text_only=True, 37 | params_refine_text=ChatTTS.Chat.RefineTextParams( 38 | prompt="[oral_2][laugh_0][break_6]", 39 | manual_seed=12345, 40 | ), 41 | split_text=False, 42 | ) 43 | if refined_text[0] not in [ 44 | "what is [uv_break] your favorite english [uv_break] food [laugh] like [lbreak]", 45 | "like what is [uv_break] your favorite english food [laugh] [lbreak]", 46 | ]: 47 | 
fail = True 48 | logger.warning("refined text is '%s'", refined_text[0]) 49 | 50 | params = ChatTTS.Chat.InferCodeParams( 51 | spk_emb=rand_spk, # add sampled speaker 52 | temperature=0.3, # using custom temperature 53 | top_P=0.7, # top P decode 54 | top_K=20, # top K decode 55 | ) 56 | input_ids, attention_mask, text_mask = chat.tokenizer.encode( 57 | chat.speaker.decorate_code_prompts( 58 | text, 59 | params.prompt, 60 | params.txt_smp, 61 | params.spk_emb, 62 | ), 63 | chat.config.gpt.num_vq, 64 | prompt=( 65 | chat.speaker.decode_prompt(params.spk_smp) 66 | if params.spk_smp is not None 67 | else None 68 | ), 69 | device=chat.device_gpt, 70 | ) 71 | with torch.inference_mode(): 72 | start_idx, end_idx = 0, torch.zeros( 73 | input_ids.shape[0], device=input_ids.device, dtype=torch.long 74 | ).fill_(input_ids.shape[1]) 75 | 76 | recoded_text = chat.tokenizer.decode( 77 | chat.gpt._prepare_generation_outputs( 78 | input_ids, 79 | start_idx, 80 | end_idx, 81 | [], 82 | [], 83 | True, 84 | ).ids 85 | ) 86 | 87 | if ( 88 | recoded_text[0] 89 | != "[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]" 90 | ): 91 | fail = True 92 | logger.warning("recoded text is '%s'", refined_text) 93 | 94 | if fail: 95 | import sys 96 | 97 | sys.exit(1) 98 | -------------------------------------------------------------------------------- /tests/testall.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | exitcode=0 4 | 5 | for file in tests/*.py 6 | do 7 | echo "Testing $file..." 8 | python "$file" 9 | if [ $? -ne 0 ] 10 | then 11 | echo "Error: $file exited with a non-zero status." 12 | exitcode=1 13 | fi 14 | echo "Test $file success" 15 | done 16 | 17 | exit $exitcode 18 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/2noise/ChatTTS/1092c1ffcaa82f4bdef104c35aa5541227c3e1d7/tools/__init__.py -------------------------------------------------------------------------------- /tools/audio/__init__.py: -------------------------------------------------------------------------------- 1 | from .av import load_audio 2 | from .pcm import pcm_arr_to_mp3_view, pcm_arr_to_ogg_view, pcm_arr_to_wav_view 3 | from .ffmpeg import has_ffmpeg_installed 4 | from .np import float_to_int16 5 | -------------------------------------------------------------------------------- /tools/audio/av.py: -------------------------------------------------------------------------------- 1 | from io import BufferedWriter, BytesIO 2 | from pathlib import Path 3 | from typing import Dict, Tuple, Optional, Union, List 4 | 5 | import av 6 | from av.audio.frame import AudioFrame 7 | from av.audio.resampler import AudioResampler 8 | import numpy as np 9 | 10 | 11 | video_format_dict: Dict[str, str] = { 12 | "m4a": "mp4", 13 | } 14 | 15 | audio_format_dict: Dict[str, str] = { 16 | "ogg": "libvorbis", 17 | "mp4": "aac", 18 | } 19 | 20 | 21 | def wav2(i: BytesIO, o: BufferedWriter, format: str): 22 | """ 23 | https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L20 24 | """ 25 | inp = av.open(i, "r") 26 | format = video_format_dict.get(format, format) 27 | out = av.open(o, "w", format=format) 28 | format = audio_format_dict.get(format, format) 29 | 30 | ostream = out.add_stream(format) 31 | 32 | for frame in 
inp.decode(audio=0): 33 | for p in ostream.encode(frame): 34 | out.mux(p) 35 | 36 | for p in ostream.encode(None): 37 | out.mux(p) 38 | 39 | out.close() 40 | inp.close() 41 | 42 | 43 | def load_audio( 44 | file: Union[str, BytesIO, Path], 45 | sr: Optional[int] = None, 46 | format: Optional[str] = None, 47 | mono=True, 48 | ) -> Union[np.ndarray, Tuple[np.ndarray, int]]: 49 | """ 50 | https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39 51 | """ 52 | if (isinstance(file, str) and not Path(file).exists()) or ( 53 | isinstance(file, Path) and not file.exists() 54 | ): 55 | raise FileNotFoundError(f"File not found: {file}") 56 | rate = 0 57 | 58 | container = av.open(file, format=format) 59 | audio_stream = next(s for s in container.streams if s.type == "audio") 60 | channels = 1 if audio_stream.layout == "mono" else 2 61 | container.seek(0) 62 | resampler = ( 63 | AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr) 64 | if sr is not None 65 | else None 66 | ) 67 | 68 | # Estimated maximum total number of samples to pre-allocate the array 69 | # AV stores length in microseconds by default 70 | estimated_total_samples = ( 71 | int(container.duration * sr // 1_000_000) if sr is not None else 48000 72 | ) 73 | decoded_audio = np.zeros( 74 | ( 75 | estimated_total_samples + 1 76 | if channels == 1 77 | else (channels, estimated_total_samples + 1) 78 | ), 79 | dtype=np.float32, 80 | ) 81 | 82 | offset = 0 83 | 84 | def process_packet(packet: List[AudioFrame]): 85 | frames_data = [] 86 | rate = 0 87 | for frame in packet: 88 | # frame.pts = None  # clear the timestamp to avoid resampling issues 89 | resampled_frames = ( 90 | resampler.resample(frame) if resampler is not None else [frame] 91 | ) 92 | for resampled_frame in resampled_frames: 93 | frame_data = resampled_frame.to_ndarray() 94 | rate = resampled_frame.rate 95 | frames_data.append(frame_data) 96 | return (rate, frames_data) 97 | 98 | def frame_iter(container): 99 | for p in container.demux(container.streams.audio[0]): 100 | yield p.decode() 101 | 102 | for r, frames_data in map(process_packet, frame_iter(container)): 103 | if not rate: 104 | rate = r 105 | for frame_data in frames_data: 106 | end_index = offset + len(frame_data[0]) 107 | 108 | # Check whether decoded_audio has enough space and resize it if necessary 109 | if end_index > decoded_audio.shape[1]: 110 | decoded_audio = np.resize( 111 | decoded_audio, (decoded_audio.shape[0], end_index * 4) 112 | ) 113 | 114 | np.copyto(decoded_audio[..., offset:end_index], frame_data) 115 | offset += len(frame_data[0]) 116 | 117 | container.close() 118 | 119 | # Truncate the array to the actual size 120 | decoded_audio = decoded_audio[..., :offset] 121 | 122 | if mono and decoded_audio.shape[0] > 1: 123 | decoded_audio = decoded_audio.mean(0) 124 | 125 | if sr is not None: 126 | return decoded_audio 127 | return decoded_audio, rate 128 | -------------------------------------------------------------------------------- /tools/audio/ffmpeg.py: -------------------------------------------------------------------------------- 1 | from pydub.utils import which 2 | 3 | 4 | def has_ffmpeg_installed() -> bool: 5 | return which("ffmpeg") is not None and which("ffprobe") is not None 6 | -------------------------------------------------------------------------------- /tools/audio/np.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | from numba import jit 5 | 6 | 7 | @jit(nopython=True) 8 | def float_to_int16(audio:
np.ndarray) -> np.ndarray: 9 | am = int(math.ceil(float(np.abs(audio).max())) * 32768) 10 | am = 32767 * 32768 // am 11 | return np.multiply(audio, am).astype(np.int16) 12 | -------------------------------------------------------------------------------- /tools/audio/pcm.py: -------------------------------------------------------------------------------- 1 | import wave 2 | from io import BytesIO 3 | import numpy as np 4 | from .np import float_to_int16 5 | from .av import wav2 6 | 7 | 8 | def _pcm_to_wav_buffer(wav: np.ndarray, sample_rate: int = 24000) -> BytesIO: 9 | """ 10 | Convert PCM audio data to a WAV format byte stream (internal utility function). 11 | 12 | :param wav: PCM data, NumPy array, typically in float32 format. 13 | :param sample_rate: Sample rate (in Hz), defaults to 24000. 14 | :return: WAV format byte stream, stored in a BytesIO object. 15 | """ 16 | # Create an in-memory byte stream buffer 17 | buf = BytesIO() 18 | 19 | # Open a WAV file stream in write mode 20 | with wave.open(buf, "wb") as wf: 21 | # Set number of channels to 1 (mono) 22 | wf.setnchannels(1) 23 | # Set sample width to 2 bytes (16-bit) 24 | wf.setsampwidth(2) 25 | # Set sample rate 26 | wf.setframerate(sample_rate) 27 | # Convert PCM to 16-bit integer and write 28 | wf.writeframes(float_to_int16(wav)) 29 | 30 | # Reset buffer pointer to the beginning 31 | buf.seek(0, 0) 32 | return buf 33 | 34 | 35 | def pcm_arr_to_mp3_view(wav: np.ndarray, sample_rate: int = 24000) -> memoryview: 36 | """ 37 | Convert PCM audio data to MP3 format. 38 | 39 | :param wav: PCM data, NumPy array, typically in float32 format. 40 | :param sample_rate: Sample rate (in Hz), defaults to 24000. 41 | :return: MP3 format byte data, returned as a memoryview. 42 | """ 43 | # Get WAV format byte stream 44 | buf = _pcm_to_wav_buffer(wav, sample_rate) 45 | 46 | # Create output buffer 47 | buf2 = BytesIO() 48 | # Convert WAV data to MP3 49 | wav2(buf, buf2, "mp3") 50 | # Return MP3 data 51 | return buf2.getbuffer() 52 | 53 | 54 | def pcm_arr_to_ogg_view(wav: np.ndarray, sample_rate: int = 24000) -> memoryview: 55 | """ 56 | Convert PCM audio data to OGG format (using Vorbis encoding). 57 | 58 | :param wav: PCM data, NumPy array, typically in float32 format. 59 | :param sample_rate: Sample rate (in Hz), defaults to 24000. 60 | :return: OGG format byte data, returned as a memoryview. 61 | """ 62 | # Get WAV format byte stream 63 | buf = _pcm_to_wav_buffer(wav, sample_rate) 64 | 65 | # Create output buffer 66 | buf2 = BytesIO() 67 | # Convert WAV data to OGG 68 | wav2(buf, buf2, "ogg") 69 | # Return OGG data 70 | return buf2.getbuffer() 71 | 72 | 73 | def pcm_arr_to_wav_view( 74 | wav: np.ndarray, sample_rate: int = 24000, include_header: bool = True 75 | ) -> memoryview: 76 | """ 77 | Convert PCM audio data to WAV format, with an option to include header. 78 | 79 | :param wav: PCM data, NumPy array, typically in float32 format. 80 | :param sample_rate: Sample rate (in Hz), defaults to 24000. 81 | :param include_header: Whether to include WAV header, defaults to True. 82 | :return: WAV format or raw PCM byte data, returned as a memoryview. 
83 | """ 84 | if include_header: 85 | # Get complete WAV byte stream 86 | buf = _pcm_to_wav_buffer(wav, sample_rate) 87 | return buf.getbuffer() 88 | else: 89 | # Return only converted 16-bit PCM data 90 | pcm_data = float_to_int16(wav) 91 | return memoryview(pcm_data.tobytes()) 92 | -------------------------------------------------------------------------------- /tools/checksum/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "crypto/sha256" 5 | "encoding/hex" 6 | "fmt" 7 | "io" 8 | "os" 9 | ) 10 | 11 | func main() { 12 | var buf [32]byte 13 | h := sha256.New() 14 | lst := make([]any, 0, 64) 15 | for _, fname := range files { 16 | f, err := os.Open(fname) 17 | if err != nil { 18 | panic(err) 19 | } 20 | _, err = io.Copy(h, f) 21 | if err != nil { 22 | panic(err) 23 | } 24 | s := hex.EncodeToString(h.Sum(buf[:0])) 25 | fmt.Println("sha256 of", fname, "=", s) 26 | lst = append(lst, s) 27 | h.Reset() 28 | f.Close() 29 | } 30 | f, err := os.Create("ChatTTS/res/sha256_map.json") 31 | if err != nil { 32 | panic(err) 33 | } 34 | _, err = fmt.Fprintf(f, jsontmpl, lst...) 35 | if err != nil { 36 | panic(err) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /tools/checksum/tmpl.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | var files = [...]string{ 4 | "asset/Decoder.safetensors", 5 | "asset/DVAE.safetensors", 6 | "asset/Embed.safetensors", 7 | "asset/Vocos.safetensors", 8 | 9 | "asset/gpt/config.json", 10 | "asset/gpt/model.safetensors", 11 | 12 | "asset/tokenizer/special_tokens_map.json", 13 | "asset/tokenizer/tokenizer_config.json", 14 | "asset/tokenizer/tokenizer.json", 15 | } 16 | 17 | const jsontmpl = `{ 18 | "sha256_asset_Decoder_safetensors": "%s", 19 | "sha256_asset_DVAE_safetensors" : "%s", 20 | "sha256_asset_Embed_safetensors" : "%s", 21 | "sha256_asset_Vocos_safetensors" : "%s", 22 | 23 | "sha256_asset_gpt_config_json" : "%s", 24 | "sha256_asset_gpt_model_safetensors" : "%s", 25 | 26 | "sha256_asset_tokenizer_special_tokens_map_json": "%s", 27 | "sha256_asset_tokenizer_tokenizer_config_json" : "%s", 28 | "sha256_asset_tokenizer_tokenizer_json" : "%s" 29 | } 30 | ` 31 | -------------------------------------------------------------------------------- /tools/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .llm import ChatOpenAI 2 | -------------------------------------------------------------------------------- /tools/llm/llm.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | prompt_dict = { 4 | "kimi": [ 5 | { 6 | "role": "system", 7 | "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。", 8 | }, 9 | { 10 | "role": "user", 11 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 12 | }, 13 | { 14 | "role": "assistant", 15 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 16 | }, 17 | ], 18 | "deepseek": [ 19 | {"role": "system", "content": "You are a helpful assistant"}, 20 | { 21 | "role": "user", 22 | "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。", 23 | }, 24 | { 25 | "role": "assistant", 26 | "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。", 27 | }, 28 | ], 29 | 
"deepseek_TN": [ 30 | {"role": "system", "content": "You are a helpful assistant"}, 31 | { 32 | "role": "user", 33 | "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号", 34 | }, 35 | { 36 | "role": "assistant", 37 | "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入", 38 | }, 39 | {"role": "user", "content": "We paid $123 for this desk."}, 40 | { 41 | "role": "assistant", 42 | "content": "We paid one hundred and twenty three dollars for this desk.", 43 | }, 44 | {"role": "user", "content": "详询请拨打010-724654"}, 45 | {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"}, 46 | {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"}, 47 | { 48 | "role": "assistant", 49 | "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。", 50 | }, 51 | ], 52 | } 53 | 54 | 55 | class ChatOpenAI: 56 | def __init__(self, api_key, base_url, model): 57 | self.client = OpenAI( 58 | api_key=api_key, 59 | base_url=base_url, 60 | ) 61 | self.model = model 62 | 63 | def call(self, user_question, temperature=0.3, prompt_version="kimi", **kwargs): 64 | 65 | completion = self.client.chat.completions.create( 66 | model=self.model, 67 | messages=prompt_dict[prompt_version] 68 | + [ 69 | {"role": "user", "content": user_question}, 70 | ], 71 | temperature=temperature, 72 | **kwargs 73 | ) 74 | return completion.choices[0].message.content 75 | -------------------------------------------------------------------------------- /tools/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .log import get_logger 2 | -------------------------------------------------------------------------------- /tools/logger/log.py: -------------------------------------------------------------------------------- 1 | import platform, sys 2 | import logging 3 | from datetime import datetime, timezone 4 | 5 | logging.getLogger("numba").setLevel(logging.WARNING) 6 | logging.getLogger("httpx").setLevel(logging.WARNING) 7 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING) 8 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING) 9 | 10 | # from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96 11 | colorCodePanic = "\x1b[1;31m" 12 | colorCodeFatal = "\x1b[1;31m" 13 | colorCodeError = "\x1b[31m" 14 | colorCodeWarn = "\x1b[33m" 15 | colorCodeInfo = "\x1b[37m" 16 | colorCodeDebug = "\x1b[32m" 17 | colorCodeTrace = "\x1b[36m" 18 | colorReset = "\x1b[0m" 19 | 20 | log_level_color_code = { 21 | logging.DEBUG: colorCodeDebug, 22 | logging.INFO: colorCodeInfo, 23 | logging.WARN: colorCodeWarn, 24 | logging.ERROR: colorCodeError, 25 | logging.FATAL: colorCodeFatal, 26 | } 27 | 28 | log_level_msg_str = { 29 | logging.DEBUG: "DEBU", 30 | logging.INFO: "INFO", 31 | logging.WARN: "WARN", 32 | logging.ERROR: "ERRO", 33 | logging.FATAL: "FATL", 34 | } 35 | 36 | 37 | class Formatter(logging.Formatter): 38 | def __init__(self, color=platform.system().lower() != "windows"): 39 | # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone 40 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo 41 | self.color = color 42 | 43 | def format(self, record: logging.LogRecord): 44 | logstr = "[" + datetime.now(self.tz).strftime("%z %Y%m%d %H:%M:%S") + "] [" 45 | if self.color: 46 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo) 47 | logstr += log_level_msg_str.get(record.levelno, record.levelname) 48 | if self.color: 49 | logstr += colorReset 
50 | if sys.version_info >= (3, 9): 51 | fn = record.filename.removesuffix(".py") 52 | elif record.filename.endswith(".py"): 53 | fn = record.filename[:-3] 54 | logstr += f"] {str(record.name)} | {fn} | {str(record.msg)%record.args}" 55 | return logstr 56 | 57 | 58 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False): 59 | logger = logging.getLogger(name) 60 | logger.setLevel(lv) 61 | if remove_exist and logger.hasHandlers(): 62 | logger.handlers.clear() 63 | if not logger.hasHandlers(): 64 | syslog = logging.StreamHandler() 65 | syslog.setFormatter(Formatter()) 66 | logger.addHandler(syslog) 67 | else: 68 | for h in logger.handlers: 69 | h.setFormatter(Formatter()) 70 | if format_root: 71 | for h in logger.root.handlers: 72 | h.setFormatter(Formatter()) 73 | return logger 74 | -------------------------------------------------------------------------------- /tools/normalizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .en import normalizer_en_nemo_text 2 | from .zh import normalizer_zh_tn 3 | -------------------------------------------------------------------------------- /tools/normalizer/en.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from functools import partial 3 | 4 | 5 | def normalizer_en_nemo_text() -> Callable[[str], str]: 6 | from nemo_text_processing.text_normalization.normalize import Normalizer 7 | 8 | return partial( 9 | Normalizer(input_case="cased", lang="en").normalize, 10 | verbose=False, 11 | punct_post_process=True, 12 | ) 13 | -------------------------------------------------------------------------------- /tools/normalizer/zh.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | 4 | def normalizer_zh_tn() -> Callable[[str], str]: 5 | from tn.chinese.normalizer import Normalizer 6 | 7 | return Normalizer(remove_interjections=False).normalize 8 | -------------------------------------------------------------------------------- /tools/seeder/__init__.py: -------------------------------------------------------------------------------- 1 | from .ctx import TorchSeedContext 2 | -------------------------------------------------------------------------------- /tools/seeder/ctx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TorchSeedContext: 5 | def __init__(self, seed): 6 | self.seed = seed 7 | self.state = None 8 | 9 | def __enter__(self): 10 | self.state = torch.random.get_rng_state() 11 | torch.manual_seed(self.seed) 12 | 13 | def __exit__(self, type, value, traceback): 14 | torch.random.set_rng_state(self.state) 15 | --------------------------------------------------------------------------------
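Note: a brief usage sketch for TorchSeedContext follows (the variable names are illustrative). The context manager saves torch's RNG state on entry, seeds it, and restores the saved state on exit, so code after the block keeps its original randomness.

import torch
from tools.seeder import TorchSeedContext

with TorchSeedContext(seed=1234):
    a = torch.rand(3)  # reproducible: the global RNG was just seeded with 1234
with TorchSeedContext(seed=1234):
    b = torch.rand(3)  # same seed, same values
assert torch.equal(a, b)
# Outside the context the previously saved RNG state has been restored,
# so this draw behaves as if the seeded blocks never happened.
c = torch.rand(3)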