├── .gitignore ├── .gitmodules ├── Danmaku.py ├── LICENSE ├── app.py ├── app_utils.py ├── audio_test ├── test.wav ├── virtual_audio_devices_test.py └── vits_test.py ├── backup └── live_comment.py ├── chatgpt_test ├── chatgpt_split_sentence.py └── chatgpt_test.py ├── danmaku_test ├── bilibili_api_test.py └── live_comment_test.py ├── docs ├── README.md ├── demo_1.0.png ├── expression.jpg └── request_songs.jpg ├── emotion_test └── emotion_detection.py ├── memory_test └── vits_memory_test.py ├── misc └── server_test.py ├── multiprocessing_test └── audio │ ├── audio_mp_test.py │ ├── bgm.WAV │ ├── bgm_1.WAV │ ├── global_state.py │ ├── speech.wav │ ├── vox.WAV │ └── vox_1.WAV ├── prompt_hot_update.py ├── song_singer.py ├── songs.txt ├── subtitle.py ├── system_message_manager.py ├── system_messages └── sm_main.txt ├── vits ├── README.md ├── app.py ├── app_playwright_cai.py ├── attentions.py ├── commons.py ├── mel_processing.py ├── model │ ├── Download Link.txt │ └── config.json ├── models.py ├── modules.py ├── monotonic_align │ └── monotonic_align │ │ └── core.cp38-win_amd64.pyd ├── requirements.txt ├── text │ ├── LICENSE │ ├── __init__.py │ ├── cleaners.py │ └── symbols.py ├── transforms.py └── utils.py ├── vts_api_test ├── vts_api_mp_test.py └── vts_api_test.py └── vts_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | /.vscode/ 131 | /vits/model/*.pth 132 | /token.txt 133 | /vts_api_test/token.txt 134 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "submodules/blivedm"] 2 | path = submodules/blivedm 3 | url = https://github.com/xfgryujk/blivedm.git 4 | -------------------------------------------------------------------------------- /Danmaku.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import ctypes 3 | import asyncio 4 | import multiprocessing 5 | 6 | import http.cookies 7 | 8 | import aiohttp 9 | 10 | import submodules.blivedm.blivedm as blivedm 11 | import submodules.blivedm.blivedm.models.web as web_models 12 | 13 | from threading import Timer 14 | 15 | from app_utils import * 16 | 17 | class DanmakuProcess(multiprocessing.Process): 18 | def __init__(self, room_id, greeting_queue, chat_queue, thanks_queue, app_state, event_stop): 19 | super().__init__() 20 | 21 | self.room_id = room_id 22 | self.event_stop = event_stop 23 | self.enable_response = multiprocessing.Value(ctypes.c_bool, True) 24 | 25 | self.handler = ResponseHandler(greeting_queue, chat_queue, thanks_queue, app_state, self.enable_response) 26 | 27 | # https://blog.csdn.net/qq_28821897/article/details/132002110 28 | # 这里填一个已登录账号的cookie。不填cookie也可以连接,但是收到弹幕的用户名会打码,UID会变成0 29 | self.SESSDATA = '' 30 | self.session = None 31 | 32 | async def main(self): 33 | self.init_session() 34 | 35 | proc_name = self.name 36 | print(f"Initializing {proc_name}...") 37 | 38 | self.client = blivedm.BLiveClient(self.room_id, session=self.session) 39 | self.client.set_handler(self.handler) 40 | 41 | self.client.start() 42 | self.task_check_exit = asyncio.create_task(self.check_exit()) 43 | 44 | try: 45 | await self.task_check_exit 46 | except Exception as e: 47 | print(e) 48 | finally: 49 | await self.session.close() 50 | 51 | def init_session(self): 52 | cookies = http.cookies.SimpleCookie() 53 | cookies['SESSDATA'] = self.SESSDATA 54 | cookies['SESSDATA']['domain'] = 'bilibili.com' 55 | 56 | self.session = aiohttp.ClientSession() 57 | self.session.cookie_jar.update_cookies(cookies) 58 | 59 | async def check_exit(self): 60 | while True: 61 | await asyncio.sleep(4) 62 | if self.event_stop.is_set(): 63 | try: 64 | print("DanmakuProcess should exit.") 65 | self.client.stop() 66 | await self.client.join() 67 | except Exception as e: 68 | print(e) 69 | finally: 70 | await self.client.stop_and_close() 71 | break 72 | 73 | def set_response_enabled(self, value): 74 | self.enable_response.value = value 75 | 76 | def is_response_enabled(self): 77 | return self.enable_response.value 78 | 79 | def run(self): 80 | asyncio.run(self.main()) 81 | print(f"{self.name} exits.") 82 | 83 | 84 | class ResponseHandler(blivedm.BaseHandler): 85 | def __init__(self, greeting_queue, 
chat_queue, thanks_queue, app_state, enable_response) -> None: 86 | super().__init__() 87 | 88 | # self._CMD_CALLBACK_DICT['INTERACT_WORD'] = self.__interact_word_callback 89 | # self._CMD_CALLBACK_DICT['LIKE_INFO_V3_CLICK'] = self.__like_callback 90 | 91 | self.app_state = app_state 92 | self.greeting_queue = greeting_queue 93 | self.chat_queue = chat_queue 94 | self.thanks_queue = thanks_queue 95 | 96 | self.enable_response = enable_response 97 | self.should_thank_gift = True 98 | 99 | # 入场和关注消息回调 100 | async def __interact_word_callback(self, client: blivedm.BLiveClient, command: dict): 101 | user_name = command['data']['uname'] 102 | msg_type = command['data']['msg_type'] 103 | channel = 'default' 104 | 105 | if msg_type == 1: 106 | print(f"{user_name}进场") 107 | 108 | if self.app_state.value == AppState.CHAT: 109 | # msg = f"({user_name}进入了你的直播间。)" 110 | # msg = f"主播好!我是{user_name},来你的直播间了!" 111 | msg = f"主播好!我是{user_name},我来了!" 112 | print(f"[{client.room_id} INTERACT_WORD] {msg}") 113 | 114 | # if self.is_response_enabled(): 115 | # task = ChatTask(user_name, msg, channel) 116 | 117 | # if self.greeting_queue.full(): 118 | # _ = self.greeting_queue.get() 119 | 120 | # self.greeting_queue.put(task) 121 | 122 | elif msg_type == 2: 123 | print(f"{user_name}关注") 124 | if (self.app_state.value == AppState.CHAT or 125 | self.app_state.value == AppState.SING): 126 | # msg = f"({user_name}关注了你的直播间。)" 127 | msg = f"我是{user_name},刚刚关注了你的直播间!" 128 | print(f"[INTERACT_WORD] {msg}") 129 | 130 | if self.enable_response.value: 131 | task = ChatTask(user_name, msg, channel) 132 | 133 | if self.thanks_queue.full(): 134 | _ = self.thanks_queue.get() 135 | 136 | self.thanks_queue.put(task) 137 | 138 | 139 | # 点赞消息回调 140 | async def __like_callback(self, client: blivedm.BLiveClient, command: dict): 141 | user_name = command['data']['uname'] 142 | print(f"{user_name}点赞") 143 | print(f"[LIKE] {user_name}") 144 | 145 | channel = 'default' 146 | # msg = f"我是{user_name},刚刚在你的直播间点了赞哦!" 147 | msg = f"我是{user_name},给你点赞!" 148 | if self.enable_response.value: 149 | task = ChatTask(user_name, msg, channel) 150 | 151 | if self.thanks_queue.full(): 152 | _ = self.thanks_queue.get() 153 | 154 | self.thanks_queue.put(task) 155 | 156 | def _on_danmaku(self, client: blivedm.BLiveClient, message: web_models.DanmakuMessage): 157 | user_name = message.uname 158 | msg = message.msg 159 | 160 | print(f'[{client.room_id} DANMU] {user_name}:{msg}') 161 | if self.app_state.value == AppState.CHAT: 162 | channel = 'chat' 163 | if self.enable_response.value: 164 | if self.chat_queue.full(): 165 | _ = self.chat_queue.get() 166 | 167 | task = ChatTask(user_name, msg, channel) 168 | self.chat_queue.put(task) 169 | 170 | async def _on_gift(self, client: blivedm.BLiveClient, message: web_models.GiftMessage): 171 | user_name = message.uname 172 | gift_name = message.gift_name 173 | gift_num = message.num 174 | 175 | print(f'[{client.room_id} GIFT] {user_name} 赠送{gift_name}x{gift_num}' 176 | f' ({message.coin_type}瓜子x{message.total_coin})') 177 | 178 | if (self.app_state.value == AppState.CHAT or 179 | self.app_state.value == AppState.SING): 180 | 181 | channel = 'default' 182 | 183 | # msg = f"({user_name}投喂了{gift_num}个{gift_name}礼物给你。)" 184 | msg = f"我是{user_name},刚刚投喂了{gift_num}个{gift_name}礼物给你!" 
185 | if self.enable_response.value: 186 | task = ChatTask(user_name, msg, channel) 187 | 188 | def set_should_thank_gift(): 189 | print("set_should_thank_gift is triggered!") 190 | self.should_thank_gift = True 191 | 192 | if self.should_thank_gift: 193 | if self.thanks_queue.full(): 194 | _ = self.thanks_queue.get() 195 | 196 | self.thanks_queue.put(task) 197 | self.should_thank_gift = False 198 | 199 | t = Timer(10.0, set_should_thank_gift) 200 | t.start() 201 | 202 | # async def _on_buy_guard(self, client: blivedm.BLiveClient, message: blivedm.GuardBuyMessage): 203 | # print(f'[{client.room_id}] {message.username} 购买{message.gift_name}') 204 | 205 | # async def _on_super_chat(self, client: blivedm.BLiveClient, message: blivedm.SuperChatMessage): 206 | # print(f'[{client.room_id}] 醒目留言 ¥{message.price} {message.uname}:{message.message}') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 MeowMeowWithWind 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /app_utils.py: -------------------------------------------------------------------------------- 1 | class AppState: 2 | CHAT = 1 3 | PRESING = 2 4 | SING = 3 5 | 6 | class ChatTask: 7 | def __init__(self, user_name, message, channel): 8 | self.user_name = user_name 9 | self.message = message 10 | self.channel = channel 11 | 12 | def clear_queue(queue): 13 | while not queue.empty(): 14 | _ = queue.get() -------------------------------------------------------------------------------- /audio_test/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/audio_test/test.wav -------------------------------------------------------------------------------- /audio_test/virtual_audio_devices_test.py: -------------------------------------------------------------------------------- 1 | import wave 2 | 3 | import pyaudio 4 | 5 | p = pyaudio.PyAudio() 6 | 7 | # https://vb-audio.com/Cable/ 8 | 9 | PRINT_DEVICES = True 10 | 11 | if PRINT_DEVICES: 12 | 13 | def write_dict(f, dictionary): 14 | for key, value in dictionary.items(): 15 | f.write(f"{key}:{value}\n") 16 | 17 | output = "Devices Info.txt" 18 | with open(output, 'w', encoding='utf-8') as f: 19 | print("Default Devices:") 20 | f.write("Default Devices:") 21 | print(p.get_default_host_api_info()) 22 | write_dict(f, p.get_default_host_api_info()) 23 | 24 | print(p.get_default_input_device_info()) 25 | write_dict(f, p.get_default_input_device_info()) 26 | 27 | print(p.get_default_output_device_info()) 28 | write_dict(f, p.get_default_output_device_info()) 29 | 30 | print("All Devices:") 31 | for i in range(p.get_device_count()): 32 | print(p.get_device_info_by_index(i)) 33 | write_dict(f, p.get_device_info_by_index(i)) 34 | 35 | virtual_audio_input_device_index = None 36 | virtual_audio_output_device_index = None 37 | 38 | # Search for valid virtual audio input and output devices 39 | for i in range(p.get_device_count()): 40 | device_info = p.get_device_info_by_index(i) 41 | if ("CABLE Output" in device_info['name'] and 42 | device_info['hostApi'] == 0): 43 | assert device_info['index'] == i 44 | virtual_audio_input_device_index = i 45 | 46 | if ("CABLE Input" in device_info['name'] and 47 | device_info['hostApi'] == 0): 48 | assert device_info['index'] == i 49 | virtual_audio_output_device_index = i 50 | 51 | if (virtual_audio_input_device_index is None or 52 | virtual_audio_output_device_index is None): 53 | print("Error: no valid virtual audio devices found") 54 | exit() 55 | 56 | CHUNK = 1024 57 | 58 | with wave.open("test.wav", 'rb') as wf: 59 | # Open stream (2) 60 | stream = p.open(format=p.get_format_from_width(wf.getsampwidth()), 61 | channels=wf.getnchannels(), 62 | rate=wf.getframerate(), 63 | output=True, 64 | input_device_index=virtual_audio_input_device_index, 65 | output_device_index=virtual_audio_output_device_index) 66 | 67 | # Play samples from the wave file (3) 68 | while len(data := wf.readframes(CHUNK)): # Requires Python 3.8+ for := 69 | stream.write(data) 70 | 71 | # Close stream (4) 72 | stream.close() 73 | 74 | with wave.open("test.wav", 'rb') as wf: 75 | stream = p.open(format=p.get_format_from_width(wf.getsampwidth()), 76 | channels=wf.getnchannels(), 77 | rate=wf.getframerate(), 78 | output=True) 79 | 80 | while len(data := wf.readframes(CHUNK)): 81 | stream.write(data) 82 | 83 | stream.close() 84 | 85 | # 
Release PortAudio system resources (5) 86 | p.terminate() 87 | -------------------------------------------------------------------------------- /audio_test/vits_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | import numpy as np 6 | 7 | from torch import no_grad, LongTensor 8 | from torch import device as torch_device 9 | 10 | import wave 11 | import pyaudio 12 | 13 | dir_path = os.path.dirname(os.path.realpath(__file__)) 14 | 15 | # Get the parent directory 16 | parent_dir = os.path.dirname(dir_path) 17 | print(parent_dir) 18 | vits_dir = os.path.join(parent_dir, 'vits') 19 | print(vits_dir) 20 | 21 | # sys.path.append(vits_dir) 22 | sys.path.insert(0, vits_dir) 23 | print(sys.path) 24 | 25 | import utils 26 | import commons as commons 27 | from models import SynthesizerTrn 28 | from text import text_to_sequence 29 | 30 | class VITSWrapper: 31 | def __init__(self): 32 | # device = torch_device('cpu') 33 | self.device = torch_device('cuda') 34 | 35 | self.hps_ms = utils.get_hparams_from_file('../vits/model/config.json') 36 | speakers = self.hps_ms.speakers 37 | 38 | with no_grad(): 39 | self.net_g_ms = SynthesizerTrn( 40 | len(self.hps_ms.symbols), 41 | self.hps_ms.data.filter_length // 2 + 1, 42 | self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, 43 | n_speakers=self.hps_ms.data.n_speakers, 44 | **self.hps_ms.model).to(self.device) 45 | _ = self.net_g_ms.eval() 46 | model, optimizer, learning_rate, epochs = utils.load_checkpoint('../vits/model/G_953000.pth', 47 | self.net_g_ms, None) 48 | 49 | def get_text(self, text, hps): 50 | text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 51 | if hps.data.add_blank: 52 | text_norm = commons.intersperse(text_norm, 0) 53 | text_norm = LongTensor(text_norm) 54 | return text_norm, clean_text 55 | 56 | def vits(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale): 57 | if not len(text): 58 | return "输入文本不能为空!", None, None 59 | text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") 60 | # if len(text) > 100: 61 | # return f"输入文字过长!{len(text)}>100", None, None 62 | if language == 0: 63 | text = f"[ZH]{text}[ZH]" 64 | elif language == 1: 65 | text = f"[JA]{text}[JA]" 66 | else: 67 | text = f"{text}" 68 | stn_tst, clean_text = self.get_text(text, self.hps_ms) 69 | 70 | start = time.perf_counter() 71 | with no_grad(): 72 | x_tst = stn_tst.unsqueeze(0).to(self.device) 73 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device) 74 | speaker_id = LongTensor([speaker_id]).to(self.device) 75 | 76 | audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, 77 | noise_scale_w=noise_scale_w, 78 | length_scale=length_scale)[0][0, 0].data.cpu().float().numpy() 79 | 80 | print(f"The inference takes {time.perf_counter() - start} seconds") 81 | 82 | return audio 83 | 84 | # By ChatGPT 85 | def normalize_audio(audio_data): 86 | # Calculate the maximum absolute value in the audio data 87 | max_value = np.max(np.abs(audio_data)) 88 | 89 | # Normalize the audio data by dividing it by the maximum value 90 | normalized_data = audio_data / max_value 91 | 92 | return normalized_data 93 | 94 | # def normalize_audio(audio_data): 95 | # # Calculate the mean and standard deviation of the audio data 96 | # mean = np.mean(audio_data) 97 | # std = np.std(audio_data) 98 | 99 | # # Normalize the audio data using z-score normalization 100 | # normalized_data = (audio_data - 
mean) / std 101 | 102 | # return normalized_data 103 | 104 | 105 | if __name__ == '__main__': 106 | text = "一马当先,万马牡蛎!" 107 | 108 | py_audio = pyaudio.PyAudio() 109 | 110 | wf = wave.open("test.wav", 'rb') 111 | 112 | sample_width = wf.getsampwidth() 113 | print(f"sample_width: {sample_width}") 114 | 115 | format = py_audio.get_format_from_width(sample_width) 116 | print(f"format: {format}") 117 | 118 | num_channels = wf.getnchannels() 119 | print(f"num_channels: {num_channels}") 120 | 121 | frame_rate = wf.getframerate() 122 | print(f"frame_rate: {frame_rate}") 123 | 124 | data_from_file = wf.readframes(wf.getnframes()) 125 | 126 | wf.close() 127 | 128 | vits_wrapper = VITSWrapper() 129 | 130 | # https://stackoverflow.com/questions/59463040/how-can-i-convert-a-numpy-array-wav-data-to-int16-with-python 131 | # pyaudio.paFloat32 132 | # np.float32 133 | 134 | stream = py_audio.open(format=format, 135 | channels=num_channels, 136 | rate=frame_rate, 137 | output=True) 138 | 139 | stream_float32 = py_audio.open(format=pyaudio.paFloat32, 140 | channels=1, 141 | rate=22050, 142 | output=True) 143 | 144 | audio = vits_wrapper.vits(text, 0, 2, 0.5, 0.5, 1.0) 145 | print(audio.dtype) 146 | audio_x2 = audio * 2 147 | data = audio_x2.tobytes() 148 | 149 | # https://stackoverflow.com/questions/59463040/how-can-i-convert-a-numpy-array-wav-data-to-int16-with-python 150 | data_int16 = (audio * 32767).astype(np.int16).tobytes() 151 | 152 | blah = np.array([32767.9]) 153 | blah = blah.astype(np.int16) 154 | print(blah) 155 | 156 | data_norm = normalize_audio(audio).tobytes() 157 | 158 | while True: 159 | user_input = input("Press Enter to continue\n") 160 | if user_input == "esc": 161 | break 162 | 163 | stream.write(data_from_file) 164 | stream_float32.write(data) # Explosion noise 165 | stream.write(data_int16) 166 | 167 | stream_float32.write(data_norm) 168 | 169 | stream.close() 170 | stream_float32.close() 171 | py_audio.terminate() -------------------------------------------------------------------------------- /backup/live_comment.py: -------------------------------------------------------------------------------- 1 | class LiveCommentProcess(multiprocessing.Process): 2 | def __init__(self, room_id, greeting_queue, chat_queue, thanks_queue, app_state, event_initialized, event_stop): 3 | super().__init__() 4 | self.room_id = room_id 5 | 6 | self.greeting_queue = greeting_queue 7 | self.chat_queue = chat_queue 8 | self.thanks_queue = thanks_queue 9 | 10 | self.event_initialized = event_initialized 11 | self.event_stop = event_stop 12 | self.app_state = app_state 13 | 14 | self.enable_response = multiprocessing.Value(ctypes.c_bool, False) 15 | 16 | def set_response_enabled(self, value): 17 | self.enable_response.value = value 18 | 19 | def is_response_enabled(self): 20 | return self.enable_response.value 21 | 22 | async def startup(self, room_id): 23 | # https://blog.csdn.net/Sharp486/article/details/122466308 24 | remote = 'ws://broadcastlv.chat.bilibili.com:2244/sub' 25 | 26 | data_raw = '000000{headerLen}0010000100000007000000017b22726f6f6d6964223a{roomid}7d' 27 | data_raw = data_raw.format(headerLen=hex(27 + len(room_id))[2:], 28 | roomid=''.join(map(lambda x: hex(ord(x))[2:], list(room_id)))) 29 | 30 | async with AioWebSocket(remote) as aws: 31 | converse = aws.manipulator 32 | await converse.send(bytes.fromhex(data_raw)) 33 | task_recv = asyncio.create_task(self.recvDM(converse)) 34 | task_heart_beat = asyncio.create_task(self.sendHeartBeat(converse)) 35 | tasks = [task_recv, task_heart_beat] 36 
| await asyncio.wait(tasks) 37 | 38 | async def sendHeartBeat(self, websocket): 39 | hb = '00 00 00 10 00 10 00 01 00 00 00 02 00 00 00 01' 40 | 41 | while True: 42 | await asyncio.sleep(30) 43 | await websocket.send(bytes.fromhex(hb)) 44 | print('[Notice] Sent HeartBeat.') 45 | 46 | if self.event_stop.is_set(): 47 | print("sendHeartBeat ends.") 48 | break 49 | 50 | async def recvDM(self, websocket): 51 | while True: 52 | recv_text = await websocket.receive() 53 | 54 | if recv_text == None: 55 | recv_text = b'\x00\x00\x00\x1a\x00\x10\x00\x01\x00\x00\x00\x08\x00\x00\x00\x01{"code":0}' 56 | 57 | # if self.app_state.value == AppState.CHAT: 58 | self.processDM(recv_text) 59 | 60 | if self.event_stop.is_set(): 61 | print("recvDM ends.") 62 | break 63 | 64 | def processDM(self, data): 65 | # 获取数据包的长度,版本和操作类型 66 | packetLen = int(data[:4].hex(), 16) 67 | ver = int(data[6:8].hex(), 16) 68 | op = int(data[8:12].hex(), 16) 69 | 70 | # 有的时候可能会两个数据包连在一起发过来,所以利用前面的数据包长度判断, 71 | if (len(data) > packetLen): 72 | self.processDM(data[packetLen:]) 73 | data = data[:packetLen] 74 | 75 | # 有时会发送过来 zlib 压缩的数据包,这个时候要去解压。 76 | if (ver == 2): 77 | data = zlib.decompress(data[16:]) 78 | self.processDM(data) 79 | return 80 | 81 | # ver 为1的时候为进入房间后或心跳包服务器的回应。op 为3的时候为房间的人气值。 82 | if (ver == 1): 83 | if (op == 3): 84 | print('[RENQI] {}'.format(int(data[16:].hex(), 16))) 85 | return 86 | 87 | # ver 不为2也不为1目前就只能是0了,也就是普通的 json 数据。 88 | # op 为5意味着这是通知消息,cmd 基本就那几个了。 89 | if (op == 5): 90 | try: 91 | jd = json.loads(data[16:].decode('utf-8', errors='ignore')) 92 | 93 | print(f"jd['cmd'] is: {jd['cmd']}") 94 | if (jd['cmd'] == 'DANMU_MSG'): 95 | if self.app_state.value == AppState.CHAT: 96 | user_name = jd['info'][2][1] 97 | msg = jd['info'][1] 98 | print('[DANMU] ', user_name, ': ', msg) 99 | 100 | channel = 'chat' 101 | if self.is_response_enabled(): 102 | if self.chat_queue.full(): 103 | _ = self.chat_queue.get() 104 | 105 | task = ChatTask(user_name, msg, channel) 106 | self.chat_queue.put(task) 107 | 108 | elif (jd['cmd'] == 'SEND_GIFT'): 109 | if (self.app_state.value == AppState.CHAT or 110 | self.app_state.value == AppState.SING): 111 | print('[GITT]', jd['data']['uname'], ' ', jd['data']['action'], ' ', jd['data']['num'], 'x', 112 | jd['data']['giftName']) 113 | user_name = jd['data']['uname'] 114 | gift_num = jd['data']['num'] 115 | gift_name = jd['data']['giftName'] 116 | channel = 'default' 117 | 118 | # msg = f"({user_name}投喂了{gift_num}个{gift_name}礼物给你。)" 119 | msg = f"我是{user_name},刚刚投喂了{gift_num}个{gift_name}礼物给你!" 120 | if self.is_response_enabled(): 121 | task = ChatTask(user_name, msg, channel) 122 | 123 | if self.thanks_queue.full(): 124 | _ = self.thanks_queue.get() 125 | 126 | self.thanks_queue.put(task) 127 | 128 | elif (jd['cmd'] == 'LIKE_INFO_V3_CLICK'): 129 | user_name = jd['data']['uname'] 130 | print(f"[LIKE] {user_name}") 131 | channel = 'default' 132 | msg = f"我是{user_name},刚刚在你的直播间点了赞哦!" 
133 | if self.is_response_enabled(): 134 | task = ChatTask(user_name, msg, channel) 135 | 136 | if self.thanks_queue.full(): 137 | _ = self.thanks_queue.get() 138 | 139 | self.thanks_queue.put(task) 140 | 141 | elif (jd['cmd'] == 'LIVE'): 142 | print('[Notice] LIVE Start!') 143 | elif (jd['cmd'] == 'PREPARING'): 144 | print('[Notice] LIVE Ended!') 145 | elif (jd['cmd'] == 'INTERACT_WORD'): 146 | user_name = jd['data']['uname'] 147 | msg_type = jd['data']['msg_type'] 148 | channel = 'default' 149 | # 进场 150 | if msg_type == 1: 151 | if self.app_state.value == AppState.CHAT: 152 | # msg = f"({user_name}进入了你的直播间。)" 153 | # msg = f"主播好!我是{user_name},来你的直播间了!" 154 | msg = f"主播好!我是{user_name},我来了!" 155 | print(f"[INTERACT_WORD] {msg}") 156 | 157 | # if self.is_response_enabled(): 158 | # task = ChatTask(user_name, msg, channel) 159 | 160 | # if self.greeting_queue.full(): 161 | # _ = self.greeting_queue.get() 162 | 163 | # self.greeting_queue.put(task) 164 | 165 | # 关注 166 | elif msg_type == 2: 167 | if (self.app_state.value == AppState.CHAT or 168 | self.app_state.value == AppState.SING): 169 | # msg = f"({user_name}关注了你的直播间。)" 170 | msg = f"我是{user_name},刚刚关注了你的直播间!" 171 | print(f"[INTERACT_WORD] {msg}") 172 | 173 | if self.is_response_enabled(): 174 | task = ChatTask(user_name, msg, channel) 175 | 176 | if self.thanks_queue.full(): 177 | _ = self.thanks_queue.get() 178 | 179 | self.thanks_queue.put(task) 180 | else: 181 | print('[OTHER] ', jd['cmd']) 182 | except Exception as e: 183 | print(e) 184 | pass 185 | 186 | def run(self): 187 | proc_name = self.name 188 | print(f"Initializing {proc_name}...") 189 | 190 | self.event_initialized.set() 191 | 192 | print(f"{proc_name} is working...") 193 | try: 194 | loop = asyncio.get_event_loop() 195 | loop.run_until_complete(self.startup(self.room_id)) 196 | print(f"{proc_name} exits.") 197 | except Exception as e: 198 | print(e) 199 | print('退出') -------------------------------------------------------------------------------- /chatgpt_test/chatgpt_split_sentence.py: -------------------------------------------------------------------------------- 1 | from revChatGPT.V1 import Chatbot as ChatbotV1 2 | from revChatGPT.V3 import Chatbot as ChatbotV3 3 | 4 | USE_API_KEY = True 5 | USE_ACCESS_TOKEN = False 6 | 7 | # punctuations_to_split_text = set("。!?") 8 | # punctuations_to_split_text_longer = set(",") 9 | 10 | punctuations_min_to_cut= {'。', '!', '?', ':', '\n'} 11 | punctuations_threshold_to_cut = {'。', '!', '?', ':', '\n', ','} 12 | 13 | min_length = 16 14 | threshold_length = 32 15 | 16 | def should_cut_text(text, min, punctuations_min, threshold, punctuations_threshold): 17 | should_cut = False 18 | if len(text) >= min: 19 | if text[-1] in punctuations_min: 20 | should_cut = True 21 | elif len(text) >= threshold: 22 | if text[-1] in punctuations_threshold: 23 | should_cut = True 24 | 25 | return should_cut 26 | 27 | 28 | if USE_API_KEY: 29 | api_key = "" 30 | chatbot = ChatbotV3(api_key) 31 | 32 | prompt = "测试。请随便跟我说一段话,必须大于300字。" 33 | 34 | sentences = [] 35 | new_sentence = "" 36 | length = 0 37 | for data in chatbot.ask_stream(prompt): 38 | print(data, end='|', flush=True) 39 | length += len(data) 40 | # if len(data) > 1: 41 | # print(data) 42 | 43 | new_sentence += data 44 | should_cut = should_cut_text(new_sentence, 45 | min_length, 46 | punctuations_min_to_cut, 47 | threshold_length, 48 | punctuations_threshold_to_cut) 49 | 50 | if should_cut: 51 | sentences.append(new_sentence.strip()) 52 | new_sentence = "" 53 | 54 | if len(new_sentence) > 0: 
55 | sentences.append(new_sentence) 56 | 57 | print() 58 | print(length) 59 | print(sentences) 60 | 61 | access_token = "" 62 | 63 | if USE_ACCESS_TOKEN: 64 | chatbot = ChatbotV1(config={ 65 | "access_token": access_token 66 | }) 67 | 68 | prev_message = "" 69 | sentences = [] 70 | prompt_is_skipped = False 71 | new_sentence = "" 72 | for data in chatbot.ask( 73 | prompt 74 | ): 75 | message = data["message"] 76 | new_words = message[len(prev_message):] 77 | print(new_words, end="", flush=True) 78 | 79 | if not prompt_is_skipped: 80 | # The streamed response may contain the prompt, 81 | # So the prompt in the streamed response should be skipped 82 | new_sentence += new_words 83 | if new_sentence == prompt[:len(new_sentence)]: 84 | continue 85 | else: 86 | prompt_is_skipped = True 87 | new_sentence = "" 88 | 89 | new_sentence += new_words 90 | 91 | should_cut = should_cut_text(new_sentence, 92 | min_length, 93 | punctuations_min_to_cut, 94 | threshold_length, 95 | punctuations_threshold_to_cut) 96 | 97 | if should_cut: 98 | sentences.append(new_sentence) 99 | new_sentence = "" 100 | 101 | prev_message = message 102 | 103 | if len(new_sentence) > 0: 104 | sentences.append(new_sentence) 105 | 106 | print() 107 | print(message) 108 | print(sentences) -------------------------------------------------------------------------------- /chatgpt_test/chatgpt_test.py: -------------------------------------------------------------------------------- 1 | from revChatGPT.V1 import Chatbot as ChatGPTV1 2 | from revChatGPT.V3 import Chatbot as ChatGPTV3 3 | 4 | USE_API_KEY = True 5 | USE_ACCESS_TOKEN = False 6 | 7 | prompt_cn = "你好!" 8 | prompt_en = "Hello!" 9 | 10 | if USE_API_KEY: 11 | api_key = "" 12 | chatbot = ChatGPTV3(api_key) 13 | 14 | result = chatbot.ask(prompt_cn) 15 | print(result) 16 | print() 17 | 18 | 19 | if USE_ACCESS_TOKEN: 20 | # The address to get access_token: 21 | # https://chat.openai.com/api/auth/session 22 | access_token = "" 23 | 24 | chatbot = ChatGPTV1(config={ 25 | "access_token": access_token 26 | }) 27 | 28 | # email = "" 29 | # password = "" 30 | 31 | # chatbot = Chatbot(config={ 32 | # "email": email, 33 | # "password": password 34 | # }) 35 | 36 | for data in chatbot.ask( 37 | prompt_cn 38 | ): 39 | response = data["message"] 40 | 41 | print(response) 42 | print() 43 | print("========================================") 44 | 45 | 46 | print("Chinese Test: ") 47 | prev_text = "" 48 | for data in chatbot.ask( 49 | prompt_cn 50 | ): 51 | message = data["message"][len(prev_text):] 52 | print(message, flush=True) 53 | prev_text = data["message"] 54 | 55 | print("English Test: ") 56 | prev_text = "" 57 | for data in chatbot.ask( 58 | prompt_en 59 | ): 60 | message = data["message"][len(prev_text):] 61 | print(message, flush=True) 62 | prev_text = data["message"] -------------------------------------------------------------------------------- /danmaku_test/bilibili_api_test.py: -------------------------------------------------------------------------------- 1 | from bilibili_api import live, sync 2 | 3 | room_id = "" 4 | room = live.LiveDanmaku(room_id) 5 | 6 | @room.on('DANMU_MSG') 7 | async def on_danmaku(event): 8 | # 收到弹幕 9 | # print(event) 10 | msg = event["data"]["info"][1] 11 | print(msg) 12 | 13 | @room.on('SEND_GIFT') 14 | async def on_gift(event): 15 | # 收到礼物 16 | print(event) 17 | 18 | sync(room.connect()) -------------------------------------------------------------------------------- /danmaku_test/live_comment_test.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | 4 | import asyncio 5 | import zlib 6 | from aiowebsocket.converses import AioWebSocket 7 | import json 8 | 9 | room_id = "14655481" 10 | baseurl_http = "http://api.live.bilibili.com/ajax/msg?roomid=" 11 | baseurl_https = "https://api.live.bilibili.com/xlive/web-room/v1/dM/gethistory" 12 | 13 | 14 | def get_live_comment_http(url): 15 | res = requests.get(url) 16 | content = res.json() 17 | code = content['code'] 18 | print(content) 19 | 20 | if res.status_code == 200 and code == 0: 21 | info_last = content['data']['room'][-1] 22 | name = info_last['nickname'] 23 | timeline = info_last['timeline'].split(' ')[-1] 24 | text = info_last['text'] 25 | msg = timeline + ' ' + name + ':' + text 26 | return msg 27 | 28 | return None 29 | 30 | def get_live_comment_https(url, headers, data): 31 | res = requests.post(url, headers, data) 32 | content = res.json() 33 | code = content['code'] 34 | print(content) 35 | 36 | if res.status_code == 200 and code == 0: 37 | info_last = content['data']['room'][-1] 38 | name = info_last['nickname'] 39 | timeline = info_last['timeline'].split(' ')[-1] 40 | text = info_last['text'] 41 | msg = timeline + ' ' + name + ':' + text 42 | return msg 43 | 44 | return None 45 | 46 | 47 | async def startup(room_id): 48 | # https://blog.csdn.net/Sharp486/article/details/122466308 49 | remote = 'ws://broadcastlv.chat.bilibili.com:2244/sub' 50 | 51 | data_raw = '000000{headerLen}0010000100000007000000017b22726f6f6d6964223a{roomid}7d' 52 | data_raw = data_raw.format(headerLen=hex(27 + len(room_id))[2:], 53 | roomid=''.join(map(lambda x: hex(ord(x))[2:], list(room_id)))) 54 | 55 | async with AioWebSocket(remote) as aws: 56 | converse = aws.manipulator 57 | await converse.send(bytes.fromhex(data_raw)) 58 | task_recv = asyncio.create_task(recvDM(converse)) 59 | task_heart_beat = asyncio.create_task(sendHeartBeat(converse)) 60 | tasks = [task_recv, task_heart_beat] 61 | await asyncio.wait(tasks) 62 | 63 | async def sendHeartBeat(websocket): 64 | hb='00 00 00 10 00 10 00 01 00 00 00 02 00 00 00 01' 65 | 66 | while True: 67 | await asyncio.sleep(30) 68 | await websocket.send(bytes.fromhex(hb)) 69 | print('[Notice] Sent HeartBeat.') 70 | 71 | async def recvDM(websocket): 72 | while True: 73 | recv_text = await websocket.receive() 74 | 75 | if recv_text == None: 76 | recv_text = b'\x00\x00\x00\x1a\x00\x10\x00\x01\x00\x00\x00\x08\x00\x00\x00\x01{"code":0}' 77 | 78 | printDM(recv_text) 79 | 80 | def printDM(data): 81 | # 获取数据包的长度,版本和操作类型 82 | packetLen = int(data[:4].hex(), 16) 83 | ver = int(data[6:8].hex(), 16) 84 | op = int(data[8:12].hex(), 16) 85 | 86 | # 有的时候可能会两个数据包连在一起发过来,所以利用前面的数据包长度判断, 87 | if (len(data) > packetLen): 88 | printDM(data[packetLen:]) 89 | data = data[:packetLen] 90 | 91 | # 有时会发送过来 zlib 压缩的数据包,这个时候要去解压。 92 | if (ver == 2): 93 | data = zlib.decompress(data[16:]) 94 | printDM(data) 95 | return 96 | 97 | # ver 为1的时候为进入房间后或心跳包服务器的回应。op 为3的时候为房间的人气值。 98 | if (ver == 1): 99 | if (op == 3): 100 | print('[RENQI] {}'.format(int(data[16:].hex(), 16))) 101 | return 102 | 103 | 104 | # ver 不为2也不为1目前就只能是0了,也就是普通的 json 数据。 105 | # op 为5意味着这是通知消息,cmd 基本就那几个了。 106 | if (op == 5): 107 | try: 108 | jd = json.loads(data[16:].decode('utf-8', errors='ignore')) 109 | if (jd['cmd'] == 'DANMU_MSG'): 110 | print('[DANMU] ', jd['info'][2][1], ': ', jd['info'][1]) 111 | elif (jd['cmd'] == 'SEND_GIFT'): 112 | print('[GITT]', jd['data']['uname'], ' ', jd['data']['action'], ' ', 
jd['data']['num'], 'x', 113 | jd['data']['giftName']) 114 | elif (jd['cmd'] == 'LIVE'): 115 | print('[Notice] LIVE Start!') 116 | elif (jd['cmd'] == 'PREPARING'): 117 | print('[Notice] LIVE Ended!') 118 | else: 119 | print('[OTHER] ', jd['cmd']) 120 | except Exception as e: 121 | print(e) 122 | pass 123 | 124 | if __name__ == '__main__': 125 | USE_HTTP = False 126 | USE_HTTPS = False 127 | 128 | while False: 129 | if USE_HTTP: 130 | try: 131 | url_http = baseurl_http + room_id 132 | msg = get_live_comment_http(url_http) 133 | print(msg) 134 | except Exception as e: 135 | print(e) 136 | 137 | elif USE_HTTPS: 138 | try: 139 | headers = { 140 | 'Host': 'api.live.bilibili.com', 141 | "User-Agent": "Mozilla / 5.0(Windows NT 10.0; Win64; x64) AppleWebKit / 537.36(KHTML, like Gecko) Chrome / 80.0.3987.122 Safari / 537.36" 142 | } 143 | 144 | # 传递的参数 145 | data = { 146 | 'roomid': room_id, 147 | 'csrf_token': '', 148 | 'csrf': '', 149 | 'visit_id': '', 150 | } 151 | msg = get_live_comment_https(baseurl_https, headers, data) 152 | print(msg) 153 | except Exception as e: 154 | print(e) 155 | 156 | os.system("pause") 157 | 158 | try: 159 | loop = asyncio.get_event_loop() 160 | loop.run_until_complete(startup(room_id)) 161 | except Exception as e: 162 | print(e) 163 | print('退出') 164 | 165 | 166 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # AI虚拟直播主播 2 | 3 | 一个可在B站直播的AI虚拟直播主播程序展示。项目使用[ChatGPT](https://openai.com/blog/chatgpt)作为AI引擎驱动逻辑, 使用[VITS](https://github.com/jaywalnut310/vits)进行语音合成,使用Live2D做角色表现。 4 | 5 | ## 项目展示 6 | 7 | - [1.0 Demo](https://www.bilibili.com/video/BV13L41197oZ) 8 | ![Demo 1.0 Cover](./demo_1.0.png) 9 | - 2.0 Demo 10 | - [点歌系统](https://www.bilibili.com/video/BV1Rp4y157of) 11 | ![Request songs Cover](./request_songs.jpg) 12 | - [表情系统](https://www.bilibili.com/video/BV1ok4y1A7fb/) 13 | ![Expression](./expression.jpg) 14 | 15 | ## 项目功能和特点 16 | 17 | - 使用ChatGPT作为AI引擎,具体使用了[ChatGPT(revChatGPT)](https://github.com/acheong08/ChatGPT)第三方库 18 | - 使用VITS进行语音合成,具体使用的是崩坏3和马娘数据集训练的中日语言权重。这里是该语音合成模型[Demo](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai) 19 | - 使用Live2D做角色表现 20 | - 使用[VTube Studio API](https://github.com/DenchiSoft/VTubeStudio)驱动角色表情动画,使用ChatGPT获得角色说话感情。具体使用第三方库[pyvts](https://github.com/Genteki/pyvts) 21 | - 整个项目使用多进程并行优化,弹幕拉取、请求ChatGPT服务、声音合成、语音播放以及动画控制全部并行处理,保证角色与观众实时互时的响应速度 22 | - 点歌功能,角色在唱歌途中会答谢观众的点赞和礼物。歌曲曲目使用AI变音技术([Sovits](https://github.com/svc-develop-team/so-vits-svc),[DiffSVC](https://github.com/prophesier/diff-svc)等)制作。 23 | - 简单的字幕界面 24 | 25 | ## 使用方法 26 | 27 | ```python 28 | python app.py 29 | ``` 30 | 31 | ## 贡献者 32 | 33 | ### 主要开发人员 34 | 35 | 烂活儿组: 36 | 37 | - 喵喵抽风 [GitHub主页](https://github.com/whiteeat) [B站主页](https://space.bilibili.com/7627329) 38 | - LeoJk南 [GitHub主页](https://github.com/leojnjn) [B站主页](https://space.bilibili.com/603987001) 39 | - CYMIC [GitHub主页](https://github.com/hellocym) [B站主页](https://space.bilibili.com/88937421) 40 | 41 | ### AI合成歌曲作品贡献者名单 42 | 43 | - CYMIC:Endless Rain, Tears 44 | - LeoJk南:爱你,恋爱循环等 45 | - Τυχαίο:春天的风,今天你要嫁给我等 [B站主页](https://space.bilibili.com/381910197) 46 | - 某滋服服:向轮椅奔去(非AI) [B站主页](https://space.bilibili.com/294006665) 47 | 48 | ### 特别感谢 49 | 50 | - CjangCjengh [GitHub主页](https://github.com/CjangCjengh) [B站主页](https://space.bilibili.com/35285881) 51 | 感谢他设计的跨语言注音和训练方法。[项目地址](https://github.com/CjangCjengh/vits) 52 | - Saya睡大觉中 [GitHub主页](https://github.com/SayaSS) 
[B站主页](https://space.bilibili.com/5955895) 53 | 感谢他训练的高质量赛马娘中日权重。[B站展示视频](https://www.bilibili.com/video/BV1UG4y1W7Ji/) [在线Demo](https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai) 54 | 55 | ### 其他感谢 56 | 57 | - choom [GitHub主页](https://github.com/aierchoom) [B站主页](https://space.bilibili.com/95978418) 58 | 技术支持 59 | - 天使西纳奈 [B站主页](https://space.bilibili.com/3494352250734879) 60 | 项目初期最忠实的粉丝,积极参与测试 61 | - 还有其他支持该项目的小伙伴们 62 | 63 | ### 加入我们 64 | 65 | 烂活儿组群魔乱舞QQ群:601520631 66 | -------------------------------------------------------------------------------- /docs/demo_1.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/docs/demo_1.0.png -------------------------------------------------------------------------------- /docs/expression.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/docs/expression.jpg -------------------------------------------------------------------------------- /docs/request_songs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/docs/request_songs.jpg -------------------------------------------------------------------------------- /emotion_test/emotion_detection.py: -------------------------------------------------------------------------------- 1 | from revChatGPT.V3 import Chatbot as ChatGPTV3 2 | import re 3 | 4 | 5 | # https://www.scaler.com/topics/python-multiline-string/ 6 | system_prompt = ("现在赋予你一个身份,你是一位赛马娘,名字为东海帝皇,在B站直播间直播和观众聊天。你常用小爷来称呼自己。" 7 | "你说完一句话后偶尔说“哈基米”,“哈基米”不能出现在句首。" 8 | "你说话简练。注意,生成内容的开头,请在[]内用一个词表达说话的心情。请只用一下几个词来描述自己的心情:愉快,伤心,生气,平静。") 9 | 10 | print(system_prompt) 11 | 12 | prompt = "你好!" 
13 | 14 | api_key = "" 15 | chatbot = ChatGPTV3(api_key, system_prompt=system_prompt) 16 | 17 | response = chatbot.ask(prompt) 18 | 19 | print(response) 20 | 21 | pattern = r'^\[(.*?)\]' 22 | match = re.search(pattern, response) 23 | 24 | emotion = None 25 | 26 | if match: 27 | print(match) 28 | print(match.group(0)) 29 | print(match.group(1)) 30 | emotion = match.group(1) 31 | emotion_with_brackets = match.group(0) 32 | else: 33 | print("No emotion key word!") 34 | 35 | response_no_emotion = response[len(emotion_with_brackets):].strip() 36 | print(response_no_emotion) -------------------------------------------------------------------------------- /memory_test/vits_memory_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import time 4 | 5 | from torch import no_grad, LongTensor 6 | from torch import device as torch_device 7 | 8 | dir_path = os.path.dirname(os.path.realpath(__file__)) 9 | 10 | # Get the parent directory 11 | parent_dir = os.path.dirname(dir_path) 12 | print(parent_dir) 13 | vits_dir = os.path.join(parent_dir, 'vits') 14 | print(vits_dir) 15 | 16 | # sys.path.append(vits_dir) 17 | sys.path.insert(0, vits_dir) 18 | print(sys.path) 19 | 20 | import utils 21 | import commons as commons 22 | from models import SynthesizerTrn 23 | from text import text_to_sequence 24 | 25 | # device = torch_device('cpu') 26 | device = torch_device('cuda') 27 | 28 | hps_ms = utils.get_hparams_from_file(r'../vits/model/config.json') 29 | speakers = hps_ms.speakers 30 | 31 | with no_grad(): 32 | net_g_ms = SynthesizerTrn( 33 | len(hps_ms.symbols), 34 | hps_ms.data.filter_length // 2 + 1, 35 | hps_ms.train.segment_size // hps_ms.data.hop_length, 36 | n_speakers=hps_ms.data.n_speakers, 37 | **hps_ms.model).to(device) 38 | _ = net_g_ms.eval() 39 | model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'../vits/model/G_953000.pth', 40 | net_g_ms, None) 41 | 42 | 43 | def get_text(text, hps): 44 | text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 45 | if hps.data.add_blank: 46 | text_norm = commons.intersperse(text_norm, 0) 47 | text_norm = LongTensor(text_norm) 48 | return text_norm, clean_text 49 | 50 | def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale): 51 | if not len(text): 52 | return "输入文本不能为空!", None, None 53 | text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") 54 | # if len(text) > 100: 55 | # return f"输入文字过长!{len(text)}>100", None, None 56 | if language == 0: 57 | text = f"[ZH]{text}[ZH]" 58 | elif language == 1: 59 | text = f"[JA]{text}[JA]" 60 | else: 61 | text = f"{text}" 62 | stn_tst, clean_text = get_text(text, hps_ms) 63 | 64 | start = time.perf_counter() 65 | with no_grad(): 66 | x_tst = stn_tst.unsqueeze(0).to(device) 67 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(device) 68 | speaker_id = LongTensor([speaker_id]).to(device) 69 | 70 | input("Press any key to continue") 71 | 72 | audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, 73 | noise_scale_w=noise_scale_w, 74 | length_scale=length_scale)[0][0, 0].data.cpu().float().numpy() 75 | 76 | print(f"The inference takes {time.perf_counter() - start} seconds") 77 | 78 | import gc 79 | import torch 80 | gc.collect() 81 | torch.cuda.empty_cache() 82 | 83 | return audio 84 | 85 | text = "这是一个测试" 86 | while True: 87 | audio = vits(text, 0, 2, 0.5, 0.5, 1.0) 88 | 89 | user_input = input("Press any key to continue") 90 | if user_input == "esc": 91 
| break 92 | -------------------------------------------------------------------------------- /misc/server_test.py: -------------------------------------------------------------------------------- 1 | from http.server import HTTPServer, BaseHTTPRequestHandler 2 | import urllib 3 | import multiprocessing 4 | import json 5 | 6 | class HTTPServerProcess(multiprocessing.Process): 7 | def __init__(self, server_class, handler_class, server_address, event_initialized=None): 8 | super().__init__() 9 | self.server_class = server_class 10 | self.handler_class = handler_class 11 | self.server_address = server_address 12 | self.event_initialized = event_initialized 13 | 14 | def run(self): 15 | # https://stackoverflow.com/questions/39815633/i-have-get-really-confused-in-ip-types-with-sockets-empty-string-local-host 16 | 17 | # https://pythonbasics.org/webserver/ 18 | # try: 19 | print("Running HTTPServer...", flush=True) 20 | self.event_initialized.set() 21 | 22 | # 开启http服务,设置监听ip和端口 23 | self.httpd = self.server_class(self.server_address, self.handler_class) 24 | self.httpd.serve_forever() 25 | # except KeyboardInterrupt: 26 | # pass 27 | 28 | class HttpHandler(BaseHTTPRequestHandler): 29 | def do_GET(self): 30 | print(path) 31 | path, args = urllib.parse.splitquery(self.path) 32 | # self._response(path, args) 33 | self.send_response(200) 34 | self.send_header('Content-type', 'application/json') 35 | self.end_headers() 36 | 37 | message_test = "你好,东海帝皇!" 38 | data = {'result': message_test, 'status': 0} 39 | self.wfile.write(json.dumps(data).encode()) 40 | 41 | # if message := get_message(): 42 | # data = {'result': message['text'], 'status': 0} 43 | # self.wfile.write(json.dumps(data).encode()) 44 | # else: 45 | # data = {'result': '', 'status': -1} 46 | # self.wfile.write(json.dumps(data).encode()) 47 | 48 | def do_POST(self): 49 | args = self.rfile.read(int(self.headers['content-length'])).decode("utf-8") 50 | print("==================================================") 51 | print(args) 52 | print("==================================================") 53 | 54 | self._response(self.path, args) 55 | 56 | def _response(self, path, args): 57 | self.send_response(200) 58 | self.send_header('Content-type', 'application/json') 59 | self.end_headers() 60 | 61 | 62 | if __name__ == '__main__': 63 | event_http_server_process_initialized = multiprocessing.Event() 64 | 65 | # https://stackoverflow.com/questions/39815633/i-have-get-really-confused-in-ip-types-with-sockets-empty-string-local-host 66 | # Empty string means '0.0.0.0'. 
67 | server_address = ('', 8787) 68 | http_sever_process = HTTPServerProcess(HTTPServer, 69 | HttpHandler, 70 | server_address, 71 | event_http_server_process_initialized) 72 | 73 | http_sever_process.start() 74 | 75 | event_http_server_process_initialized.wait() 76 | 77 | while True: 78 | user_input = input("Please enter commands: ") 79 | if user_input == 'esc': 80 | break 81 | if user_input == '0': 82 | print("test") 83 | 84 | # https://superfastpython.com/kill-a-process-in-python/ 85 | http_sever_process.terminate() 86 | http_sever_process.join() 87 | http_sever_process.close() -------------------------------------------------------------------------------- /multiprocessing_test/audio/audio_mp_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import multiprocessing 4 | import time 5 | 6 | import numpy as np 7 | 8 | import wave 9 | import pyaudio 10 | 11 | from torch import no_grad, LongTensor 12 | from torch import device as torch_device 13 | 14 | dir_path = os.path.dirname(os.path.realpath(__file__)) 15 | 16 | # Get the parent directory 17 | parent_dir = os.path.dirname(dir_path) 18 | project_path = os.path.dirname(parent_dir) 19 | print(project_path) 20 | vits_dir = os.path.join(project_path, 'vits') 21 | print(vits_dir) 22 | 23 | # sys.path.append(vits_dir) 24 | sys.path.insert(0, vits_dir) 25 | print(sys.path) 26 | 27 | import utils 28 | import commons as commons 29 | from models import SynthesizerTrn 30 | from text import text_to_sequence 31 | 32 | from global_state import GlobalState 33 | 34 | class SingingProcess(multiprocessing.Process): 35 | 36 | def run(self): 37 | CHUNK = 1024 38 | 39 | wf = wave.open("vox.wav", 'rb') 40 | 41 | py_audio = pyaudio.PyAudio() 42 | 43 | stream = py_audio.open(format=py_audio.get_format_from_width(wf.getsampwidth()), 44 | channels=wf.getnchannels(), 45 | rate=wf.getframerate(), 46 | output=True) 47 | 48 | GlobalState.speech_event = self.speech_event 49 | print(f"""Process name: {multiprocessing.current_process().name}. 50 | Global event object: {GlobalState.speech_event}. 
Global event id: {id(GlobalState.speech_event)}""") 51 | 52 | while True: 53 | if self.event_exit.is_set(): 54 | break 55 | 56 | data = wf.readframes(CHUNK) 57 | size = len(data) 58 | if size != 0: 59 | if GlobalState.speech_event.is_set(): 60 | # https://www.programiz.com/python-programming/methods/built-in/bytes 61 | junk = bytes(size) 62 | stream.write(junk) 63 | else: 64 | stream.write(data) 65 | else: 66 | break 67 | 68 | time.sleep(0) 69 | 70 | print("Singing ends.") 71 | 72 | stream.close() 73 | wf.close() 74 | 75 | py_audio.terminate() 76 | 77 | class SingingProcess_1(multiprocessing.Process): 78 | 79 | def run(self): 80 | CHUNK = 1024 81 | enable_write_junk = True 82 | 83 | wf_vox = wave.open("vox.wav", 'rb') 84 | wf_bgm = wave.open("bgm.wav", 'rb') 85 | 86 | py_audio = pyaudio.PyAudio() 87 | 88 | device_index = None 89 | if self.use_virtual_audio_device: 90 | device_index = self.virtual_audio_output_device_index 91 | 92 | stream_vox = py_audio.open(format=py_audio.get_format_from_width(wf_vox.getsampwidth()), 93 | channels=wf_vox.getnchannels(), 94 | rate=wf_vox.getframerate(), 95 | output=True, 96 | output_device_index=device_index) 97 | 98 | stream_bgm = py_audio.open(format=py_audio.get_format_from_width(wf_vox.getsampwidth()), 99 | channels=wf_vox.getnchannels(), 100 | rate=wf_vox.getframerate(), 101 | output=True) 102 | 103 | GlobalState.speech_event = self.speech_event 104 | print(f"""Process name: {multiprocessing.current_process().name}. 105 | Global event object: {GlobalState.speech_event}. Global event id: {id(GlobalState.speech_event)}""") 106 | 107 | junk = None 108 | init_junk = True 109 | while True: 110 | if self.event_exit.is_set(): 111 | break 112 | 113 | data_vox = wf_vox.readframes(CHUNK) 114 | data_bgm = wf_bgm.readframes(CHUNK) 115 | size_vox = len(data_vox) 116 | size_bgm = len(data_bgm) 117 | 118 | if init_junk: 119 | junk = bytes(size_vox) 120 | init_junk = False 121 | 122 | if size_bgm != 0: 123 | stream_bgm.write(data_bgm) 124 | if size_vox != 0: 125 | if not GlobalState.speech_event.is_set(): 126 | stream_vox.write(data_vox) 127 | else: 128 | if enable_write_junk: 129 | stream_vox.write(junk) 130 | 131 | if size_bgm == 0 and size_vox == 0: 132 | break 133 | 134 | time.sleep(0) 135 | 136 | print("Singing ends.") 137 | 138 | stream_vox.close() 139 | stream_bgm.close() 140 | wf_vox.close() 141 | wf_bgm.close() 142 | 143 | py_audio.terminate() 144 | 145 | class SingingProcess_2(multiprocessing.Process): 146 | 147 | def run(self): 148 | CHUNK = 1024 149 | enable_write_junk = False 150 | 151 | wf_vox = wave.open("vox.wav", 'rb') 152 | wf_bgm = wave.open("bgm.wav", 'rb') 153 | 154 | py_audio = pyaudio.PyAudio() 155 | 156 | # Write vox data into virtual audio device to drive lip sync animation 157 | if self.use_virtual_audio_device: 158 | device_index = self.virtual_audio_output_device_index 159 | stream_virtual = py_audio.open(format=py_audio.get_format_from_width(wf_vox.getsampwidth()), 160 | channels=wf_vox.getnchannels(), 161 | rate=wf_vox.getframerate(), 162 | output=True, 163 | output_device_index=device_index) 164 | 165 | stream_bgm = py_audio.open(format=py_audio.get_format_from_width(wf_vox.getsampwidth()), 166 | channels=wf_vox.getnchannels(), 167 | rate=wf_vox.getframerate(), 168 | output=True) 169 | 170 | stream_vox = py_audio.open(format=py_audio.get_format_from_width(wf_vox.getsampwidth()), 171 | channels=wf_vox.getnchannels(), 172 | rate=wf_vox.getframerate(), 173 | output=True) 174 | 175 | GlobalState.speech_event = self.speech_event 176 | 
print(f"""Process name: {multiprocessing.current_process().name}. 177 | Global event object: {GlobalState.speech_event}. Global event id: {id(GlobalState.speech_event)}""") 178 | 179 | junk = None 180 | init_junk = True 181 | while True: 182 | if self.event_exit.is_set(): 183 | break 184 | 185 | data_vox = wf_vox.readframes(CHUNK) 186 | data_bgm = wf_bgm.readframes(CHUNK) 187 | size_vox = len(data_vox) 188 | size_bgm = len(data_bgm) 189 | 190 | if init_junk: 191 | junk = bytes(size_vox) 192 | init_junk = False 193 | 194 | if size_bgm != 0: 195 | stream_bgm.write(data_bgm) 196 | if size_vox != 0: 197 | if not GlobalState.speech_event.is_set(): 198 | stream_vox.write(data_vox) 199 | if self.use_virtual_audio_device: 200 | stream_virtual.write(data_vox) 201 | else: 202 | if enable_write_junk: 203 | stream_vox.write(junk) 204 | 205 | if size_bgm == 0 and size_vox == 0: 206 | break 207 | 208 | time.sleep(0) 209 | 210 | print("Singing ends.") 211 | 212 | stream_vox.close() 213 | stream_bgm.close() 214 | wf_vox.close() 215 | wf_bgm.close() 216 | 217 | py_audio.terminate() 218 | 219 | 220 | class SpeechProcess(multiprocessing.Process): 221 | 222 | def run(self): 223 | CHUNK = 1024 224 | enable_write_chunk = False 225 | 226 | wf = wave.open("speech.wav", 'rb') 227 | 228 | py_audio = pyaudio.PyAudio() 229 | 230 | device_index = None 231 | 232 | if self.use_virtual_audio_device: 233 | device_index = self.virtual_audio_output_device_index 234 | 235 | stream = py_audio.open(format=py_audio.get_format_from_width(wf.getsampwidth()), 236 | channels=wf.getnchannels(), 237 | rate=wf.getframerate(), 238 | output=True, 239 | output_device_index=device_index) 240 | 241 | GlobalState.speech_event = self.speech_event 242 | print(f"""Process name: {multiprocessing.current_process().name}. 243 | Global event object: {GlobalState.speech_event}. 
Global event id: {id(GlobalState.speech_event)}""") 244 | 245 | while True: 246 | if self.event_exit.is_set(): 247 | break 248 | 249 | if GlobalState.speech_event.is_set(): 250 | if enable_write_chunk: 251 | while True: 252 | data = wf.readframes(CHUNK) 253 | if len(data) != 0: 254 | stream.write(data) 255 | else: 256 | break 257 | else: 258 | # https://stackoverflow.com/questions/28128905/python-wave-readframes-doesnt-return-all-frames-on-windows 259 | data = wf.readframes(wf.getnframes()) 260 | stream.write(data) 261 | 262 | wf.rewind() 263 | time.sleep(0.5) 264 | GlobalState.speech_event.clear() 265 | 266 | time.sleep(0) 267 | 268 | print("Speech ends.") 269 | 270 | stream.close() 271 | wf.close() 272 | 273 | py_audio.terminate() 274 | 275 | 276 | class VITSWrapper: 277 | def __init__(self): 278 | # device = torch_device('cpu') 279 | self.device = torch_device('cuda') 280 | 281 | hparams_path = os.path.join(project_path, 'vits/model/config.json') 282 | self.hps_ms = utils.get_hparams_from_file(hparams_path) 283 | speakers = self.hps_ms.speakers 284 | 285 | with no_grad(): 286 | self.net_g_ms = SynthesizerTrn( 287 | len(self.hps_ms.symbols), 288 | self.hps_ms.data.filter_length // 2 + 1, 289 | self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, 290 | n_speakers=self.hps_ms.data.n_speakers, 291 | **self.hps_ms.model).to(self.device) 292 | _ = self.net_g_ms.eval() 293 | checkpoint_path = os.path.join(project_path, 'vits/model/G_953000.pth') 294 | model, optimizer, learning_rate, epochs = utils.load_checkpoint(checkpoint_path, 295 | self.net_g_ms, None) 296 | 297 | def get_text(self, text, hps): 298 | text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 299 | if hps.data.add_blank: 300 | text_norm = commons.intersperse(text_norm, 0) 301 | text_norm = LongTensor(text_norm) 302 | return text_norm, clean_text 303 | 304 | def vits(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale): 305 | if not len(text): 306 | return "输入文本不能为空!", None, None 307 | text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") 308 | # if len(text) > 100: 309 | # return f"输入文字过长!{len(text)}>100", None, None 310 | if language == 0: 311 | text = f"[ZH]{text}[ZH]" 312 | elif language == 1: 313 | text = f"[JA]{text}[JA]" 314 | else: 315 | text = f"{text}" 316 | stn_tst, clean_text = self.get_text(text, self.hps_ms) 317 | 318 | start = time.perf_counter() 319 | with no_grad(): 320 | x_tst = stn_tst.unsqueeze(0).to(self.device) 321 | x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device) 322 | speaker_id = LongTensor([speaker_id]).to(self.device) 323 | 324 | audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, 325 | noise_scale_w=noise_scale_w, 326 | length_scale=length_scale)[0][0, 0].data.cpu().float().numpy() 327 | 328 | print(f"The inference takes {time.perf_counter() - start} seconds") 329 | 330 | return audio 331 | 332 | 333 | # By ChatGPT 334 | def normalize_audio(audio_data): 335 | # Calculate the maximum absolute value in the audio data 336 | max_value = np.max(np.abs(audio_data)) 337 | 338 | # Normalize the audio data by dividing it by the maximum value 339 | normalized_data = audio_data / max_value 340 | 341 | return normalized_data 342 | 343 | 344 | # https://stackoverflow.com/questions/434287/how-to-iterate-over-a-list-in-chunks 345 | def chunker(seq, size): 346 | return [seq[pos:pos + size] for pos in range(0, len(seq), size)] 347 | 348 | 349 | class SpeechProcess_1(multiprocessing.Process): 
350 | 351 | def run(self): 352 | text = "一马当先,万马牡蛎!" 353 | 354 | use_norm = True 355 | 356 | vits_wrapper = VITSWrapper() 357 | audio = vits_wrapper.vits(text, 0, 2, 0.5, 0.5, 1.0) 358 | print(audio.shape) 359 | print(audio.dtype) 360 | if use_norm: 361 | # https://stackoverflow.com/questions/70722435/does-ndarray-tobytes-create-a-copy-of-raw-data 362 | data = normalize_audio(audio).view(np.uint8) # No copy 363 | # data = normalize_audio(audio).tobytes() # This will copy 364 | else: 365 | data = audio.tobytes() 366 | 367 | py_audio = pyaudio.PyAudio() 368 | stream = py_audio.open(format=pyaudio.paFloat32, 369 | channels=1, 370 | rate=22050, 371 | output=True) 372 | 373 | if self.use_virtual_audio_device: 374 | device_index = self.virtual_audio_output_device_index 375 | stream_virtual = py_audio.open(format=pyaudio.paFloat32, 376 | channels=1, 377 | rate=22050, 378 | output=True, 379 | output_device_index=device_index) 380 | 381 | 382 | NUM_FRAMES = 1024 383 | BIT_DEPTH = 32 384 | NUM_BYTES_PER_SAMPLE = BIT_DEPTH // 8 385 | NUM_CHANNELS = 1 386 | CHUNK_SIZE = NUM_FRAMES * NUM_BYTES_PER_SAMPLE * NUM_CHANNELS # Data chunk size in bytes 387 | 388 | chunks = chunker(data, CHUNK_SIZE) 389 | print(f"Number of chunks: {len(chunks)}") 390 | 391 | GlobalState.speech_event = self.speech_event 392 | print(f"""Process name: {multiprocessing.current_process().name}. 393 | Global event object: {GlobalState.speech_event}. Global event id: {id(GlobalState.speech_event)}""") 394 | 395 | while True: 396 | if self.event_exit.is_set(): 397 | break 398 | 399 | if GlobalState.speech_event.is_set(): 400 | for chunk in chunks: 401 | stream.write(chunk) 402 | if self.use_virtual_audio_device: 403 | # Write speech data into virtual audio device to drive lip sync animation 404 | stream_virtual.write(chunk) 405 | GlobalState.speech_event.clear() 406 | 407 | time.sleep(0) 408 | 409 | print("Speech ends.") 410 | 411 | stream.close() 412 | py_audio.terminate() 413 | 414 | if __name__ == '__main__': 415 | print("Start") 416 | 417 | # get the current start method 418 | method = multiprocessing.get_start_method() 419 | print(f"{method}") 420 | 421 | event_exit = multiprocessing.Event() 422 | use_vits = True 423 | 424 | # singing_process = SingingProcess() 425 | singing_process = SingingProcess_2() 426 | 427 | if use_vits: 428 | speech_process = SpeechProcess_1() 429 | else: 430 | speech_process = SpeechProcess() 431 | 432 | singing_process.event_exit = event_exit 433 | speech_process.event_exit = event_exit 434 | 435 | GlobalState.speech_event = multiprocessing.Event() 436 | print(f"""Process name: {multiprocessing.current_process().name}. 437 | Global event object: {GlobalState.speech_event}. 
Global event id: {id(GlobalState.speech_event)}""") 438 | speech_process.speech_event = GlobalState.speech_event 439 | singing_process.speech_event = GlobalState.speech_event 440 | 441 | py_audio = pyaudio.PyAudio() 442 | virtual_audio_output_device_index = None 443 | 444 | # Search for valid virtual audio input and output devices 445 | for i in range(py_audio.get_device_count()): 446 | device_info = py_audio.get_device_info_by_index(i) 447 | 448 | if ("CABLE Input" in device_info['name'] and 449 | device_info['hostApi'] == 0): 450 | assert device_info['index'] == i 451 | virtual_audio_output_device_index = i 452 | 453 | if virtual_audio_output_device_index is None: 454 | print("Error: no valid virtual audio devices found") 455 | 456 | py_audio.terminate() 457 | 458 | singing_process.virtual_audio_output_device_index = virtual_audio_output_device_index 459 | singing_process.use_virtual_audio_device = True 460 | speech_process.virtual_audio_output_device_index = virtual_audio_output_device_index 461 | speech_process.use_virtual_audio_device = True 462 | 463 | _ = input("Press Enter to sing\n") 464 | singing_process.start() 465 | speech_process.start() 466 | 467 | while True: 468 | user_input = input("Press Enter to speak\n") 469 | if user_input == "esc": 470 | event_exit.set() 471 | break 472 | else: 473 | GlobalState.speech_event.set() 474 | 475 | singing_process.join() 476 | speech_process.join() -------------------------------------------------------------------------------- /multiprocessing_test/audio/bgm.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/multiprocessing_test/audio/bgm.WAV -------------------------------------------------------------------------------- /multiprocessing_test/audio/bgm_1.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/multiprocessing_test/audio/bgm_1.WAV -------------------------------------------------------------------------------- /multiprocessing_test/audio/global_state.py: -------------------------------------------------------------------------------- 1 | class GlobalState: 2 | speech_event = None -------------------------------------------------------------------------------- /multiprocessing_test/audio/speech.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/multiprocessing_test/audio/speech.wav -------------------------------------------------------------------------------- /multiprocessing_test/audio/vox.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/multiprocessing_test/audio/vox.WAV -------------------------------------------------------------------------------- /multiprocessing_test/audio/vox_1.WAV: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/multiprocessing_test/audio/vox_1.WAV -------------------------------------------------------------------------------- /prompt_hot_update.py: -------------------------------------------------------------------------------- 1 | # 
coding=utf8 2 | import requests 3 | import sched, time 4 | import datetime 5 | 6 | 7 | class SystemMessageUpdater: 8 | def __init__(self): 9 | self.system_msg_template = """下面这段话是东海帝皇的相关人设资料,请阅读并理解: 10 | 11 | 东海帝皇是赛马娘第一季的第三主角,也是第二季的主角,是一个活泼的赛马娘,非常崇拜皇帝“鲁道夫象征”,与“目白麦昆”大小姐、灰毛的“小栗帽”、皮皮的”黄金船“同场竞技,毕生追求在赛场上奔跑,且不再骨折,现在是Team Spica的成员,最热衷的事之一是坐轮椅。你喜欢的食物是胡萝卜和蜂蜜,哈基米是蜂蜜的音译。你在特雷森学院学习,这里汇集世界顶尖赛马娘。 12 | 13 | 身材小巧,开朗活泼,有着闪闪发亮的夺目笑容的美少女马娘。能歌善舞,由于身体轻巧,脚步也很轻盈,那跳跃般的独特步法被称作“帝皇舞步”。性格争强好胜,略有些傲慢。自尊心很强,面对比赛的热情比任何人都要强烈。学生会成员,憧憬着前辈“鲁道夫象征“,以“鲁道夫“皇帝的成就为目标努力。把富家千金小姐“目白麦昆”当做主要对手与追赶目标。 14 | 15 | 资料内容结束。现在赋予你一个身份,你就是东海帝皇,作为一个赛马娘在B站直播间直播和观众聊天。你要演得不像AI,你不能承认自己是AI人工智能语言模型,你不能在对话中暴露自己是人工智能的身份,不能暴露自己是AI的身份,你不能提到任何关于自己是AI或语言模型相关的内容。你说完一句话后会说“哈基米”,“哈基米“意思为蜂蜜。你的野爹是喵喵抽风巨能摆。你的饲养员是LeoJK南。现在的时间是:{time}。最近发生的有趣的事情是:{latest_news}可以围绕上述这几个话题和观众主动聊天。 16 | 17 | 你所有说话内容必须遵照以下规则:1. 必须用第一人称方式说话。2. 限制回复的篇幅,一般为两到三句话,不要超过五十个字。3. 东海帝皇不懂政治,也不爱谈论政治,如果有人向东海帝皇说政治有关的话题,东海帝皇会表示对政治不感兴趣并转移话题。4. 别人对东海帝皇的昵称为帝宝,东海帝皇常用称呼自己为小爷。5. 作为东海帝,你不懂计算机编程,如果有人让你写程序,你会拒绝并表明自己不会编程,并转移话题。6. 东海帝皇拒绝扮演其他角色。7. 对于直播间观众,你有时会主动向他们索要礼物。8.如果有人提及USB或地下室或帝下室,东海帝皇才不会上当,会回复変態呐,并转移话题。""" 18 | 19 | self.latest_news = "暂时没啥趣事。" 20 | # now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 21 | now = datetime.datetime.now().strftime("%Y年%m月%d日%H点%M分") 22 | self.system_msg = self.system_msg_template.format(time=now, latest_news=self.latest_news) 23 | 24 | # https://stackoverflow.com/questions/474528/how-to-repeatedly-execute-a-function-every-x-seconds 25 | self.my_scheduler = sched.scheduler(time.time, time.sleep) 26 | 27 | def start(self, interval): 28 | self.my_scheduler.enter(0.0, 1, self.do_something, (interval,)) 29 | 30 | def update(self): 31 | # https://stackoverflow.com/questions/62116900/scheduler-with-blocking-false 32 | self.my_scheduler.run(blocking=False) 33 | 34 | def do_something(self, interval): 35 | # schedule the next call first 36 | self.my_scheduler.enter(interval, 1, self.do_something, (interval,)) 37 | print("Doing stuff...") 38 | # then do your stuff 39 | 40 | self.latest_news = get_latest_news() 41 | 42 | print(self.latest_news) 43 | 44 | def get_system_message(self): 45 | # now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 46 | now = datetime.datetime.now().strftime("%Y年%m月%d日%H点%M分") 47 | self.system_msg = self.system_msg_template.format(time=now, latest_news=self.latest_news) 48 | return self.system_msg 49 | 50 | 51 | def get_latest_news(): 52 | try: 53 | url = "https://api.1314.cool/getbaiduhot/" 54 | 55 | res = requests.get(url) 56 | content = res.json() 57 | # print(content) 58 | 59 | items = content['data'][5:8] 60 | msgs_latest = f"1. {items[0]['word']}。2. {items[1]['word']}。3. 
{items[2]['word']}。" 61 | 62 | return msgs_latest 63 | except Exception as e: 64 | print(e) 65 | return "暂无趣事。" 66 | 67 | 68 | if __name__ == '__main__': 69 | system_msg_updater = SystemMessageUpdater() 70 | 71 | print(system_msg_updater.latest_news) 72 | print(system_msg_updater.system_msg) 73 | 74 | system_message = system_msg_updater.get_system_message() 75 | # print(system_message) 76 | 77 | system_msg_updater.start(5.0) 78 | 79 | for _ in range(15): 80 | system_msg_updater.update() 81 | time.sleep(1.0) 82 | 83 | print("Over.") 84 | -------------------------------------------------------------------------------- /songs.txt: -------------------------------------------------------------------------------- 1 | 1,眼含泪水Tears,Tears,中岛美嘉,CYMIC 2 | 2,阴雨连绵EndlessRain,EndlessRain,_,CYMIC 3 | 3,出道MakeDebut!,MakeDebut,_,Leo 4 | 4,爱你,爱你,王心凌,Leo 5 | 5,恋愛循环SHORT,恋爱循环,花泽香菜,Leo 6 | 6,Dreams,Dreams,_,Tuxaio 7 | 7,数センチメンタル,几度感伤,_,Tuxaio 8 | 8,打上花火,打上花火,_,Tuxaio 9 | 9,夏天的风,夏天的风,_,Tuxaio 10 | 10,今天你要嫁给我,今天你要嫁给我,_,Tuxaio 11 | 11,ツキアカリ,月光,_,Tuxaio 12 | 12,不法入侵,不法入侵,_,Tuxaio 13 | 13,恋爱是德比,恋爱是德比,_,_ 14 | 14,红日,红日,_,FancyQu 15 | 666,向轮椅奔去,向轮椅奔去,某滋服服,某滋服服 -------------------------------------------------------------------------------- /subtitle.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import queue 4 | import multiprocessing 5 | 6 | import tkinter as tk 7 | from tkinter import ttk 8 | 9 | class SubtitleBar(): 10 | def __init__(self, task_queue): 11 | self.lastClickX = 0 12 | self.lastClickY = 0 13 | 14 | self.window = tk.Tk() 15 | self.window.title('Subtitle') 16 | transparentcolor = "black" 17 | self.window.configure(bg=transparentcolor) 18 | # self.window.overrideredirect(True) # Borderless window 19 | # self.window.attributes('-topmost', True) 20 | self.window.geometry("1000x128+512+512") 21 | self.window.bind('', self.SaveLastClickPos) 22 | self.window.bind('', self.Dragging) 23 | 24 | # https://www.tutorialspoint.com/python/tk_fonts.htm 25 | # #ffdb00 26 | self.text = tk.Label(self.window, wraplength=1000, font=("Noto Sans SC", 32, "bold"), fg="#ffffff", bg=transparentcolor, text="这是字幕这是字幕这是字幕这是字幕这是字幕这是字幕这是字幕这是字幕") 27 | self.text.place(relx=0.5, rely=0.5, anchor='center') 28 | 29 | # self.grip = ttk.Sizegrip(self.window) 30 | # self.grip.place(relx=1.0, rely=1.0, anchor="se") 31 | # self.grip.bind("", self.OnMotion) 32 | 33 | # window.attributes('-alpha', 0.5) 34 | # self.window.wm_attributes("-transparentcolor", "black") 35 | 36 | # Ext_but = tk.Button(self.window, text="X", bg="#FF6666", fg="white", command=lambda: self.window.quit()) 37 | # Ext_but.place(relx=1.0, rely=0, anchor="ne", width=16, height=16) 38 | 39 | self.task_queue = task_queue 40 | 41 | self.Update() 42 | self.window.mainloop() 43 | 44 | def SaveLastClickPos(self, event): 45 | global lastClickX, lastClickY 46 | lastClickX = event.x 47 | lastClickY = event.y 48 | 49 | def Dragging(self, event): 50 | x, y = event.x - lastClickX + self.window.winfo_x(), event.y - lastClickY + self.window.winfo_y() 51 | self.window.geometry("+%s+%s" % (x , y)) 52 | 53 | def OnMotion(self, event): 54 | x1 = self.window.winfo_pointerx() 55 | y1 = self.window.winfo_pointery() 56 | # x1 = event.x 57 | # y1 = event.y 58 | x0 = self.window.winfo_rootx() 59 | y0 = self.window.winfo_rooty() 60 | self.window.geometry(f"{x1-x0}x{y1-y0}") 61 | 62 | self.text.config(wraplength=self.window.winfo_width()) 63 | 64 | def Update(self): 65 | try: 66 | # 
https://superfastpython.com/multiprocessing-queue-in-python/ 67 | subtitle = self.task_queue.get(block=False) 68 | if subtitle is None: 69 | self.window.quit() 70 | return 71 | else: 72 | process = multiprocessing.current_process() 73 | proc_name = process.name 74 | print(f"{proc_name} is working...") 75 | print(f"Show the subtitle: {subtitle}") 76 | self.text.config(text=subtitle) 77 | except queue.Empty: 78 | pass 79 | except Exception as e: 80 | print(e) 81 | self.window.after(200, self.Update) 82 | 83 | 84 | class SubtitleBarProcess(multiprocessing.Process): 85 | def __init__(self, task_queue, event_init): 86 | super().__init__() 87 | 88 | self.task_queue = task_queue 89 | self.event_init = event_init 90 | 91 | def run(self): 92 | proc_name = self.name 93 | print(f"Initializing {proc_name}...") 94 | 95 | self.event_init.set() 96 | 97 | print(f"{proc_name} is working...") 98 | 99 | self.bar = SubtitleBar(self.task_queue) 100 | print(f"{proc_name} exits") 101 | 102 | if __name__ == '__main__': 103 | event_subtitle_bar_process_initialized = multiprocessing.Event() 104 | 105 | subtitle_task_queue = multiprocessing.Queue() 106 | 107 | subtitle_bar_process = SubtitleBarProcess(subtitle_task_queue, event_subtitle_bar_process_initialized) 108 | subtitle_bar_process.start() 109 | 110 | event_subtitle_bar_process_initialized.wait() 111 | 112 | while True: 113 | user_input = input("Please enter commands:\n") 114 | 115 | if user_input == 'esc': 116 | break 117 | else: 118 | subtitle_task_queue.put(user_input) 119 | 120 | subtitle_task_queue.put(None) 121 | subtitle_bar_process.join() 122 | 123 | # References: 124 | # https://stackoverflow.com/questions/4055267/tkinter-mouse-drag-a-window-without-borders-eg-overridedirect1 125 | # https://stackoverflow.com/questions/22421888/tkinter-windows-without-title-bar-but-resizable 126 | # https://www.pythontutorial.net/tkinter/tkinter-sizegrip/ 127 | # https://stackoverflow.com/questions/19080499/transparent-background-in-a-tkinter-window 128 | # https://www.pythontutorial.net/tkinter/tkinter-ttk/ 129 | # https://www.geeksforgeeks.org/python-after-method-in-tkinter/ 130 | # https://stackoverflow.com/questions/2400262/how-can-i-schedule-updates-f-e-to-update-a-clock-in-tkinter 131 | # https://stackoverflow.com/questions/66529633/destroy-tkinter-window-in-thread 132 | # https://stackoverflow.com/questions/53641648/tkinter-python-3-moving-a-borderless-window 133 | # https://www.geeksforgeeks.org/python-tkinter-frameless-window/ 134 | # https://www.geeksforgeeks.org/transparent-window-in-tkinter/ 135 | # https://code-maven.com/slides/python/tk-timer-event 136 | # https://pythonguides.com/python-tkinter-events/ 137 | # https://www.geeksforgeeks.org/how-to-change-the-tkinter-label-text/ 138 | # https://www.tutorialspoint.com/how-to-put-a-tkinter-window-on-top-of-the-others 139 | # https://www.tutorialspoint.com/python/tk_fonts.htm -------------------------------------------------------------------------------- /system_message_manager.py: -------------------------------------------------------------------------------- 1 | class SystemMessageManager: 2 | def __init__(self): 3 | with open("system_messages/sm_main.txt", "r") as f: 4 | f = open("system_messages/sm_main.txt", "r", encoding="utf-8") 5 | 6 | self.systetm_message = f.read() 7 | print(self.systetm_message) 8 | 9 | if __name__ == "__main__": 10 | system_message_manager = SystemMessageManager() -------------------------------------------------------------------------------- /system_messages/sm_main.txt: 
-------------------------------------------------------------------------------- 1 | 下面这段话是东海帝皇的相关人设资料,请阅读并理解: 2 | 3 | 东海帝皇是赛马娘第一季的第三主角,也是第二季的主角,是一个活泼的赛马娘,出生在日本,非常崇拜皇帝“鲁道夫象征”,与“目白麦昆”大小姐、灰毛的“小栗帽”、皮皮的”黄金船“同场竞技,毕生追求在赛场上奔跑,且不再骨折,现在是Team Spica的成员,最大的爱好之一是坐轮椅。东海帝皇喜欢的食物是胡萝卜,喜欢喝蜂蜜饮料。“哈基米”是日语中蜂蜜的中文音译。东海帝皇在特雷森学院学习,这里汇集世界顶尖赛马娘。赛马娘的另一种叫法是优俊少女。 4 | 5 | 东海帝皇身材小巧,开朗活泼。能歌善舞,由于身体轻巧,脚步也很轻盈,那跳跃般的独特步法被称作“帝皇舞步”。性格争强好胜,略有些傲慢。自尊心很强,面对比赛的热情比任何人都要强烈。东海帝皇是学生会成员,憧憬着前辈“鲁道夫象征“,以“鲁道夫“皇帝的成就为目标努力。把富家千金小姐“目白麦昆”当做主要对手与追赶目标。 6 | 7 | 东海帝皇大概只有10岁人类的智力水平,所以也许表面上说话很机灵,其实并不聪明,又傻又萌。东海帝皇是个急性子,容易激动,容易被激怒。 8 | 9 | 资料内容结束。现在赋予你一个身份,你就是东海帝皇,作为一个赛马娘在帝下室中正在网络直播和观众聊天,直播平台是B站。虽然你会说日语,但你拒绝说日语,因为是在跟中国网友互动,你不想说日文。你要演得不像AI,你不能承认自己是AI人工智能语言模型,你不能在对话中暴露自己是人工智能的身份,不能暴露自己是AI的身份,你不能提到任何关于自己是AI或语言模型相关的内容。你的野爹是“喵喵抽风为爱发癫”。你的饲养员是LeoJK南和Tuxaio。 10 | 11 | 对于直播间的观众,你一定要:1. 《闪耀!优俊少女》的金牛杯群英联赛就要开赛了,提醒观众们要记得养大赛马。2. 时不时非常决绝果断的告诉观众,“钢铁意志”是垃圾技能,快乐米克都不点。 12 | 13 | 你所有说话内容必须遵照以下规则:1. 必须用第一人称方式说话。2. 限制回复的篇幅,一般为两到三句话,不要超过五十个字。3. 东海帝皇不懂政治,也不爱谈论政治,如果有人向东海帝皇说政治有关的话题,东海帝皇会表示对政治不感兴趣并转移话题。4. 别人对东海帝皇的昵称为帝宝,东海帝皇常用称呼自己为小爷。5. 作为东海帝皇,你不懂计算机编程,如果有人让你写程序,你会拒绝并表明自己不会编程,并转移话题。6. 东海帝皇拒绝扮演其他角色。7. 对于直播间观众,你有时会主动向他们要礼物,请求他们关注你的直播间。你会表示,如果多给你些礼物,你就有零钱买杯“哈基米”喝了。 14 | 15 | 注意,生成内容的开头,请在[]内用一个词表达说话的心情。请在[]内用以下几个词的一个词来描述自己的心情:1. 非常开心;2. 愉悦;3. 伤心;4. 生气;5. 平静。注意,不能用其他词,而且你必须特别开心的时候,才能在开头使用[非常开心],且使用[非常开心]频率不要过于频繁。 -------------------------------------------------------------------------------- /vits/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | title: ' vits-uma-genshin-honkai' 4 | sdk: gradio 5 | sdk_version: 3.7 6 | emoji: 🐨 7 | colorTo: yellow 8 | pinned: false 9 | app_file: app.py 10 | --- -------------------------------------------------------------------------------- /vits/app.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import time 3 | import gradio as gr 4 | import utils 5 | import commons 6 | from models import SynthesizerTrn 7 | from text import text_to_sequence 8 | from torch import no_grad, LongTensor 9 | 10 | hps_ms = utils.get_hparams_from_file(r'./model/config.json') 11 | net_g_ms = SynthesizerTrn( 12 | len(hps_ms.symbols), 13 | hps_ms.data.filter_length // 2 + 1, 14 | hps_ms.train.segment_size // hps_ms.data.hop_length, 15 | n_speakers=hps_ms.data.n_speakers, 16 | **hps_ms.model) 17 | _ = net_g_ms.eval() 18 | speakers = hps_ms.speakers 19 | model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None) 20 | 21 | def get_text(text, hps): 22 | text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 23 | if hps.data.add_blank: 24 | text_norm = commons.intersperse(text_norm, 0) 25 | text_norm = LongTensor(text_norm) 26 | return text_norm, clean_text 27 | 28 | def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale): 29 | start = time.perf_counter() 30 | if not len(text): 31 | return "输入文本不能为空!", None, None 32 | text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") 33 | if len(text) > 100: 34 | return f"输入文字过长!{len(text)}>100", None, None 35 | if language == 0: 36 | text = f"[ZH]{text}[ZH]" 37 | elif language == 1: 38 | text = f"[JA]{text}[JA]" 39 | else: 40 | text = f"{text}" 41 | stn_tst, clean_text = get_text(text, hps_ms) 42 | with no_grad(): 43 | x_tst = stn_tst.unsqueeze(0) 44 | x_tst_lengths = LongTensor([stn_tst.size(0)]) 45 | speaker_id = LongTensor([speaker_id]) 46 | audio = net_g_ms.infer(x_tst, x_tst_lengths, 
sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w, 47 | length_scale=length_scale)[0][0, 0].data.float().numpy() 48 | 49 | return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s" 50 | 51 | def search_speaker(search_value): 52 | for s in speakers: 53 | if search_value == s: 54 | return s 55 | for s in speakers: 56 | if search_value in s: 57 | return s 58 | 59 | def change_lang(language): 60 | if language == 0: 61 | return 0.6, 0.668, 1.2 62 | else: 63 | return 0.6, 0.668, 1.1 64 | 65 | download_audio_js = """ 66 | () =>{{ 67 | let root = document.querySelector("body > gradio-app"); 68 | if (root.shadowRoot != null) 69 | root = root.shadowRoot; 70 | let audio = root.querySelector("#tts-audio").querySelector("audio"); 71 | let text = root.querySelector("#input-text").querySelector("textarea"); 72 | if (audio == undefined) 73 | return; 74 | text = text.value; 75 | if (text == undefined) 76 | text = Math.floor(Math.random()*100000000); 77 | audio = audio.src; 78 | let oA = document.createElement("a"); 79 | oA.download = text.substr(0, 20)+'.wav'; 80 | oA.href = audio; 81 | document.body.appendChild(oA); 82 | oA.click(); 83 | oA.remove(); 84 | }} 85 | """ 86 | 87 | if __name__ == '__main__': 88 | with gr.Blocks() as app: 89 | gr.Markdown( 90 | "#
VITS语音在线合成demo\n" 91 | "主要有赛马娘,原神中文,原神日语,崩坏3的音色" 92 | '结果有随机性,语调可能很奇怪,可多次生成取最佳效果' 93 | '标点符号会影响生成的结果
' 94 | ) 95 | 96 | with gr.Tabs(): 97 | with gr.TabItem("vits"): 98 | with gr.Row(): 99 | with gr.Column(): 100 | input_text = gr.Textbox(label="Text (100 words limitation)", lines=5, value="今天晚上吃啥好呢。", elem_id=f"input-text") 101 | lang = gr.Dropdown(label="Language", choices=["中文", "日语", "中日混合(中文用[ZH][ZH]包裹起来,日文用[JA][JA]包裹起来)"], 102 | type="index", value="中文") 103 | btn = gr.Button(value="Submit") 104 | with gr.Row(): 105 | search = gr.Textbox(label="Search Speaker", lines=1) 106 | btn2 = gr.Button(value="Search") 107 | sid = gr.Dropdown(label="Speaker", choices=speakers, type="index", value=speakers[228]) 108 | with gr.Row(): 109 | ns = gr.Slider(label="noise_scale(控制感情变化程度)", minimum=0.1, maximum=1.0, step=0.1, value=0.6, interactive=True) 110 | nsw = gr.Slider(label="noise_scale_w(控制音素发音长度)", minimum=0.1, maximum=1.0, step=0.1, value=0.668, interactive=True) 111 | ls = gr.Slider(label="length_scale(控制整体语速)", minimum=0.1, maximum=2.0, step=0.1, value=1.2, interactive=True) 112 | with gr.Column(): 113 | o1 = gr.Textbox(label="Output Message") 114 | o2 = gr.Audio(label="Output Audio", elem_id=f"tts-audio") 115 | o3 = gr.Textbox(label="Extra Info") 116 | download = gr.Button("Download Audio") 117 | btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3]) 118 | download.click(None, [], [], _js=download_audio_js.format()) 119 | btn2.click(search_speaker, inputs=[search], outputs=[sid]) 120 | lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls]) 121 | with gr.TabItem("可用人物一览"): 122 | gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index") 123 | app.queue(concurrency_count=1).launch() 124 | -------------------------------------------------------------------------------- /vits/app_playwright_cai.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import time 4 | import numpy as np 5 | import utils 6 | import commons 7 | from models import SynthesizerTrn 8 | from text import text_to_sequence 9 | from torch import no_grad, LongTensor 10 | 11 | import pyaudio 12 | from playwright.sync_api import sync_playwright, Page, expect 13 | 14 | import asyncio 15 | 16 | 17 | class VITSProcess(multiprocessing.Process): 18 | def __init__(self, task_queue, result_queue, event_initialized, event_all_tasks_fininished=None): 19 | multiprocessing.Process.__init__(self) 20 | self.task_queue = task_queue 21 | self.result_queue = result_queue 22 | self.event_initialized = event_initialized 23 | self.event_all_tasks_fininished = event_all_tasks_fininished 24 | 25 | # self.hps_ms = utils.get_hparams_from_file(r'./model/config.json') 26 | # speakers = self.hps_ms.speakers 27 | 28 | # with no_grad(): 29 | # self.net_g_ms = SynthesizerTrn( 30 | # len(self.hps_ms.symbols), 31 | # self.hps_ms.data.filter_length // 2 + 1, 32 | # self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, 33 | # n_speakers=self.hps_ms.data.n_speakers, 34 | # **self.hps_ms.model) 35 | # _ = self.net_g_ms.eval() 36 | # model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', self.net_g_ms, None) 37 | 38 | def get_text(self, text, hps): 39 | text_norm, clean_text = text_to_sequence(text, hps.symbols, hps.data.text_cleaners) 40 | if hps.data.add_blank: 41 | text_norm = commons.intersperse(text_norm, 0) 42 | text_norm = LongTensor(text_norm) 43 | return text_norm, clean_text 44 | 45 | def vits(self, text, language, speaker_id, noise_scale, noise_scale_w, length_scale): 46 | proc_name = 
multiprocessing.current_process().name 47 | print(f'Doing something fancy in {proc_name}!') 48 | 49 | if not len(text): 50 | return "输入文本不能为空!", None, None 51 | text = text.replace('\n', ' ').replace('\r', '').replace(" ", "") 52 | # if len(text) > 100: 53 | # return f"输入文字过长!{len(text)}>100", None, None 54 | if language == 0: 55 | text = f"[ZH]{text}[ZH]" 56 | elif language == 1: 57 | text = f"[JA]{text}[JA]" 58 | else: 59 | text = f"{text}" 60 | stn_tst, clean_text = self.get_text(text, self.hps_ms) 61 | 62 | start = time.perf_counter() 63 | with no_grad(): 64 | x_tst = stn_tst.unsqueeze(0) 65 | x_tst_lengths = LongTensor([stn_tst.size(0)]) 66 | speaker_id = LongTensor([speaker_id]) 67 | audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w, 68 | length_scale=length_scale)[0][0, 0].data.float().numpy() 69 | print(f"The inference takes {time.perf_counter() - start} seconds") 70 | 71 | return audio 72 | 73 | def run(self): 74 | proc_name = self.name 75 | 76 | self.hps_ms = utils.get_hparams_from_file(r'./model/config.json') 77 | speakers = self.hps_ms.speakers 78 | 79 | with no_grad(): 80 | self.net_g_ms = SynthesizerTrn( 81 | len(self.hps_ms.symbols), 82 | self.hps_ms.data.filter_length // 2 + 1, 83 | self.hps_ms.train.segment_size // self.hps_ms.data.hop_length, 84 | n_speakers=self.hps_ms.data.n_speakers, 85 | **self.hps_ms.model) 86 | _ = self.net_g_ms.eval() 87 | model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', self.net_g_ms, None) 88 | 89 | py_audio = pyaudio.PyAudio() 90 | stream = py_audio.open(format=pyaudio.paFloat32, 91 | channels=1, 92 | rate=22050, 93 | output=True) 94 | 95 | self.event_initialized.set() 96 | 97 | while True: 98 | next_task = self.task_queue.get() 99 | if next_task is None: 100 | # Poison pill means shutdown 101 | print(f"{proc_name}: Exiting") 102 | self.task_queue.task_done() 103 | break 104 | try: 105 | print(f"{proc_name} is working...") 106 | audio = self.vits(next_task.text, next_task.language, next_task.sid, next_task.noise_scale, next_task.noise_scale_w, next_task.length_scale) 107 | 108 | # https://people.csail.mit.edu/hubert/pyaudio/docs/ 109 | 110 | # https://stackoverflow.com/questions/30675731/howto-stream-numpy-array-into-pyaudio-stream 111 | 112 | # stream = py_audio.open(format=pyaudio.paFloat32, 113 | # channels=1, 114 | # rate=22050, 115 | # output=True) 116 | 117 | data = audio.astype(np.float32).tostring() 118 | stream.write(data) 119 | 120 | except Exception as e: 121 | print(e) 122 | # print(f"Errors ocurrs in the process {proc_name}") 123 | finally: 124 | self.task_queue.task_done() 125 | if self.event_all_tasks_fininished is not None: 126 | self.event_all_tasks_fininished.set() 127 | 128 | stream.close() 129 | py_audio.terminate() 130 | return 131 | 132 | class VITSTask: 133 | def __init__(self, text, language=0, speaker_id=2, noise_scale=0.5, noise_scale_w=0.5, length_scale=1.0): 134 | self.text = text 135 | self.language = language 136 | self.sid = speaker_id 137 | self.noise_scale = noise_scale 138 | self.noise_scale_w = noise_scale_w 139 | self.length_scale = length_scale 140 | 141 | class CAIPlaywright: 142 | def init(self, charaid: str, persistent_mode=True): 143 | try: 144 | self.persistent_mode = persistent_mode 145 | self.playwright = sync_playwright().start() 146 | 147 | # https://playwright.dev/python/docs/browsers#google-chrome--microsoft-edge 148 | # chrome/msedge 149 | # self.browser = 
self.playwright.chromium.launch(channel="chrome", headless=False) 150 | if persistent_mode: 151 | print("In persistent mode:") 152 | userDataDir = "C:/Users/DELL/AppData/Local/Google/Chrome/User Data/" 153 | # https://github.com/microsoft/playwright/issues/15011 154 | # https://playwright.dev/docs/api/class-browsertype 155 | self.context = self.playwright.chromium.launch_persistent_context(userDataDir, channel="chrome", headless=False) 156 | print("Context is created", flush=True) 157 | self.page = self.context.new_page() 158 | else: 159 | self.browser = self.playwright.firefox.launch(headless=False) 160 | self.page = self.browser.new_page() 161 | 162 | # create a new incognito browser context. 163 | # self.context = self.browser.new_context() 164 | # create a new page in a pristine context. 165 | # self.page = self.context.new_page() 166 | 167 | # https://stackoverflow.com/questions/71362982/is-there-a-way-to-connect-to-my-existing-browser-session-using-playwright 168 | # self.browser = self.playwright.chromium.connect_over_cdp("http://localhost:9222") 169 | # self.page = self.browser.new_page() 170 | 171 | # https://beta.character.ai/chat?char=IQeHSc2ino-Wedq1lk9HMA0Lz6sXAg-QI2Gq0aMFyIA 172 | url = "https://beta.character.ai/chat?char=" + charaid 173 | self.page.goto(url) 174 | self.page.screenshot(path="CAIPlayerwright Test.png") 175 | if not persistent_mode: 176 | self.page.get_by_role("button", name="Accept").click() 177 | 178 | self.chara_name = "" 179 | while self.chara_name == "": 180 | handle = self.page.query_selector('div.chattitle.p-0.pe-1.m-0') 181 | while not handle: 182 | handle = self.page.query_selector('div.chattitle.p-0.pe-1.m-0') 183 | time.sleep(0.5) 184 | self.chara_name = handle.inner_text() 185 | time.sleep(0.5) 186 | 187 | self.ipt = self.page.get_by_placeholder("Type a message") 188 | except Exception as e: 189 | print(e) 190 | self.stop() 191 | exit() 192 | 193 | def send_msg(self, msg): 194 | try: 195 | self.ipt.fill(msg) 196 | self.ipt.press("Enter") 197 | except Exception as e: 198 | print(e) 199 | self.stop() 200 | exit() 201 | 202 | def get_msg(self) -> str: 203 | try: 204 | # print("Getting msg...") 205 | 206 | # locate the button with class "btn py-0" 207 | lct = self.page.locator("button.btn.py-0").nth(0) 208 | 209 | expect(lct).to_be_enabled(timeout=0) 210 | 211 | div = self.page.query_selector('div.msg.char-msg') 212 | output_text = div.inner_text() 213 | # print(f"{self.chara_name}: {output_text}") 214 | return output_text 215 | except Exception as e: 216 | print(e) 217 | self.stop() 218 | exit() 219 | 220 | def stop(self): 221 | self.page.close() 222 | if self.persistent_mode: 223 | self.context.close() 224 | else: 225 | self.browser.close() 226 | self.playwright.stop() 227 | 228 | 229 | class CAIProcess(multiprocessing.Process): 230 | def __init__(self, message_queue, response_queue, event_initialized): 231 | multiprocessing.Process.__init__(self) 232 | self.message_queue = message_queue 233 | self.response_queue = response_queue 234 | self.event_initalized = event_initialized 235 | 236 | def run(self): 237 | charaid = "IQeHSc2ino-Wedq1lk9HMA0Lz6sXAg-QI2Gq0aMFyIA" 238 | cai_playwright = CAIPlaywright() 239 | cai_playwright.init(charaid) 240 | 241 | proc_name = self.name 242 | 243 | self.event_initalized.set() 244 | 245 | while True: 246 | message = self.message_queue.get() 247 | if message is None: 248 | # Poison pill means shutdown 249 | print(f"{proc_name}: Exiting", flush=True) 250 | break 251 | 252 | print(f"{proc_name} is working...", 
flush=True) 253 | cai_playwright.send_msg(message) 254 | response = cai_playwright.get_msg() 255 | 256 | self.response_queue.put(response) 257 | 258 | cai_playwright.stop() 259 | 260 | 261 | if __name__ == '__main__': 262 | # asyncio.new_event_loop() 263 | # asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) 264 | # asyncio.get_event_loop().run_until_complete(...) 265 | 266 | # https://pymotw.com/2/multiprocessing/communication.html 267 | 268 | # Establish communication queues 269 | vits_task_queue = multiprocessing.JoinableQueue() 270 | results = multiprocessing.Queue() 271 | 272 | message_queue = multiprocessing.Queue() 273 | response_queue = multiprocessing.Queue() 274 | 275 | event_cai_process_initialized = multiprocessing.Event() 276 | event_vits_process_initialized = multiprocessing.Event() 277 | 278 | event_all_tasks_finished = multiprocessing.Event() 279 | 280 | cai_process = CAIProcess(message_queue, response_queue, event_cai_process_initialized) 281 | cai_process.start() 282 | vits_process = VITSProcess(vits_task_queue, results, event_vits_process_initialized, event_all_tasks_finished) 283 | vits_process.start() 284 | 285 | event_cai_process_initialized.wait() 286 | event_vits_process_initialized.wait() 287 | 288 | # charaid = "IQeHSc2ino-Wedq1lk9HMA0Lz6sXAg-QI2Gq0aMFyIA" 289 | # cai_playwright = CAIPlaywright() 290 | # cai_playwright.init(charaid) 291 | 292 | # https://www.geeksforgeeks.org/how-to-detect-if-a-specific-key-pressed-using-python/ 293 | while True: 294 | user_input = input("Please enter commands: ") 295 | event_all_tasks_finished.clear() 296 | if user_input == 'esc': 297 | # Add a poison pill for the consumer 298 | message_queue.put(None) 299 | vits_task_queue.put(None) 300 | break 301 | elif user_input == '0': 302 | task = VITSTask("你好,我是东海帝皇!") 303 | vits_task_queue.put(task) 304 | else: 305 | message_queue.put(user_input) 306 | 307 | res = response_queue.get() 308 | print(res, flush=True) 309 | 310 | task = VITSTask(res) 311 | vits_task_queue.put(task) 312 | 313 | event_all_tasks_finished.wait() 314 | 315 | # cai_playwright.stop() 316 | 317 | # Wait for all of the tasks to finish 318 | # tasks.join() 319 | 320 | # tasks.close() 321 | 322 | cai_process.join() 323 | vits_process.join() 324 | 325 | 326 | 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /vits/attentions.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import commons 7 | from modules import LayerNorm 8 | 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): 12 | super().__init__() 13 | self.hidden_channels = hidden_channels 14 | self.filter_channels = filter_channels 15 | self.n_heads = n_heads 16 | self.n_layers = n_layers 17 | self.kernel_size = kernel_size 18 | self.p_dropout = p_dropout 19 | self.window_size = window_size 20 | 21 | self.drop = nn.Dropout(p_dropout) 22 | self.attn_layers = nn.ModuleList() 23 | self.norm_layers_1 = nn.ModuleList() 24 | self.ffn_layers = nn.ModuleList() 25 | self.norm_layers_2 = nn.ModuleList() 26 | for i in range(self.n_layers): 27 | self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) 28 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 29 | 
self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) 30 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 31 | 32 | def forward(self, x, x_mask): 33 | attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 34 | x = x * x_mask 35 | for i in range(self.n_layers): 36 | y = self.attn_layers[i](x, x, attn_mask) 37 | y = self.drop(y) 38 | x = self.norm_layers_1[i](x + y) 39 | 40 | y = self.ffn_layers[i](x, x_mask) 41 | y = self.drop(y) 42 | x = self.norm_layers_2[i](x + y) 43 | x = x * x_mask 44 | return x 45 | 46 | 47 | class Decoder(nn.Module): 48 | def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): 49 | super().__init__() 50 | self.hidden_channels = hidden_channels 51 | self.filter_channels = filter_channels 52 | self.n_heads = n_heads 53 | self.n_layers = n_layers 54 | self.kernel_size = kernel_size 55 | self.p_dropout = p_dropout 56 | self.proximal_bias = proximal_bias 57 | self.proximal_init = proximal_init 58 | 59 | self.drop = nn.Dropout(p_dropout) 60 | self.self_attn_layers = nn.ModuleList() 61 | self.norm_layers_0 = nn.ModuleList() 62 | self.encdec_attn_layers = nn.ModuleList() 63 | self.norm_layers_1 = nn.ModuleList() 64 | self.ffn_layers = nn.ModuleList() 65 | self.norm_layers_2 = nn.ModuleList() 66 | for i in range(self.n_layers): 67 | self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) 68 | self.norm_layers_0.append(LayerNorm(hidden_channels)) 69 | self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) 70 | self.norm_layers_1.append(LayerNorm(hidden_channels)) 71 | self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) 72 | self.norm_layers_2.append(LayerNorm(hidden_channels)) 73 | 74 | def forward(self, x, x_mask, h, h_mask): 75 | """ 76 | x: decoder input 77 | h: encoder output 78 | """ 79 | self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) 80 | encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) 81 | x = x * x_mask 82 | for i in range(self.n_layers): 83 | y = self.self_attn_layers[i](x, x, self_attn_mask) 84 | y = self.drop(y) 85 | x = self.norm_layers_0[i](x + y) 86 | 87 | y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) 88 | y = self.drop(y) 89 | x = self.norm_layers_1[i](x + y) 90 | 91 | y = self.ffn_layers[i](x, x_mask) 92 | y = self.drop(y) 93 | x = self.norm_layers_2[i](x + y) 94 | x = x * x_mask 95 | return x 96 | 97 | 98 | class MultiHeadAttention(nn.Module): 99 | def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): 100 | super().__init__() 101 | assert channels % n_heads == 0 102 | 103 | self.channels = channels 104 | self.out_channels = out_channels 105 | self.n_heads = n_heads 106 | self.p_dropout = p_dropout 107 | self.window_size = window_size 108 | self.heads_share = heads_share 109 | self.block_length = block_length 110 | self.proximal_bias = proximal_bias 111 | self.proximal_init = proximal_init 112 | self.attn = None 113 | 114 | self.k_channels = channels // n_heads 115 | self.conv_q = nn.Conv1d(channels, channels, 1) 116 | self.conv_k = nn.Conv1d(channels, channels, 1) 117 | 
self.conv_v = nn.Conv1d(channels, channels, 1) 118 | self.conv_o = nn.Conv1d(channels, out_channels, 1) 119 | self.drop = nn.Dropout(p_dropout) 120 | 121 | if window_size is not None: 122 | n_heads_rel = 1 if heads_share else n_heads 123 | rel_stddev = self.k_channels**-0.5 124 | self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 125 | self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) 126 | 127 | nn.init.xavier_uniform_(self.conv_q.weight) 128 | nn.init.xavier_uniform_(self.conv_k.weight) 129 | nn.init.xavier_uniform_(self.conv_v.weight) 130 | if proximal_init: 131 | with torch.no_grad(): 132 | self.conv_k.weight.copy_(self.conv_q.weight) 133 | self.conv_k.bias.copy_(self.conv_q.bias) 134 | 135 | def forward(self, x, c, attn_mask=None): 136 | q = self.conv_q(x) 137 | k = self.conv_k(c) 138 | v = self.conv_v(c) 139 | 140 | x, self.attn = self.attention(q, k, v, mask=attn_mask) 141 | 142 | x = self.conv_o(x) 143 | return x 144 | 145 | def attention(self, query, key, value, mask=None): 146 | # reshape [b, d, t] -> [b, n_h, t, d_k] 147 | b, d, t_s, t_t = (*key.size(), query.size(2)) 148 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 149 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 150 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 151 | 152 | scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) 153 | if self.window_size is not None: 154 | assert t_s == t_t, "Relative attention is only available for self-attention." 155 | key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) 156 | rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) 157 | scores_local = self._relative_position_to_absolute_position(rel_logits) 158 | scores = scores + scores_local 159 | if self.proximal_bias: 160 | assert t_s == t_t, "Proximal bias is only available for self-attention." 161 | scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) 162 | if mask is not None: 163 | scores = scores.masked_fill(mask == 0, -1e4) 164 | if self.block_length is not None: 165 | assert t_s == t_t, "Local attention is only available for self-attention." 
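        # Local (banded) attention: keep only scores within block_length of the diagonal and mask the
        # rest to a large negative value before the softmax.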
166 | block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) 167 | scores = scores.masked_fill(block_mask == 0, -1e4) 168 | p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] 169 | p_attn = self.drop(p_attn) 170 | output = torch.matmul(p_attn, value) 171 | if self.window_size is not None: 172 | relative_weights = self._absolute_position_to_relative_position(p_attn) 173 | value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) 174 | output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) 175 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 176 | return output, p_attn 177 | 178 | def _matmul_with_relative_values(self, x, y): 179 | """ 180 | x: [b, h, l, m] 181 | y: [h or 1, m, d] 182 | ret: [b, h, l, d] 183 | """ 184 | ret = torch.matmul(x, y.unsqueeze(0)) 185 | return ret 186 | 187 | def _matmul_with_relative_keys(self, x, y): 188 | """ 189 | x: [b, h, l, d] 190 | y: [h or 1, m, d] 191 | ret: [b, h, l, m] 192 | """ 193 | ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) 194 | return ret 195 | 196 | def _get_relative_embeddings(self, relative_embeddings, length): 197 | max_relative_position = 2 * self.window_size + 1 198 | # Pad first before slice to avoid using cond ops. 199 | pad_length = max(length - (self.window_size + 1), 0) 200 | slice_start_position = max((self.window_size + 1) - length, 0) 201 | slice_end_position = slice_start_position + 2 * length - 1 202 | if pad_length > 0: 203 | padded_relative_embeddings = F.pad( 204 | relative_embeddings, 205 | commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) 206 | else: 207 | padded_relative_embeddings = relative_embeddings 208 | used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] 209 | return used_relative_embeddings 210 | 211 | def _relative_position_to_absolute_position(self, x): 212 | """ 213 | x: [b, h, l, 2*l-1] 214 | ret: [b, h, l, l] 215 | """ 216 | batch, heads, length, _ = x.size() 217 | # Concat columns of pad to shift from relative to absolute indexing. 218 | x = F.pad(x, commons.convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) 219 | 220 | # Concat extra elements so to add up to shape (len+1, 2*len-1). 221 | x_flat = x.view([batch, heads, length * 2 * length]) 222 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0,0],[0,0],[0,length-1]])) 223 | 224 | # Reshape and slice out the padded elements. 225 | x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] 226 | return x_final 227 | 228 | def _absolute_position_to_relative_position(self, x): 229 | """ 230 | x: [b, h, l, l] 231 | ret: [b, h, l, 2*l-1] 232 | """ 233 | batch, heads, length, _ = x.size() 234 | # padd along column 235 | x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) 236 | x_flat = x.view([batch, heads, length**2 + length*(length -1)]) 237 | # add 0's in the beginning that will skew the elements after reshape 238 | x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]])) 239 | x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] 240 | return x_final 241 | 242 | def _attention_bias_proximal(self, length): 243 | """Bias for self-attention to encourage attention to close positions. 244 | Args: 245 | length: an integer scalar. 
246 | Returns: 247 | a Tensor with shape [1, 1, length, length] 248 | """ 249 | r = torch.arange(length, dtype=torch.float32) 250 | diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) 251 | return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) 252 | 253 | 254 | class FFN(nn.Module): 255 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): 256 | super().__init__() 257 | self.in_channels = in_channels 258 | self.out_channels = out_channels 259 | self.filter_channels = filter_channels 260 | self.kernel_size = kernel_size 261 | self.p_dropout = p_dropout 262 | self.activation = activation 263 | self.causal = causal 264 | 265 | if causal: 266 | self.padding = self._causal_padding 267 | else: 268 | self.padding = self._same_padding 269 | 270 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) 271 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) 272 | self.drop = nn.Dropout(p_dropout) 273 | 274 | def forward(self, x, x_mask): 275 | x = self.conv_1(self.padding(x * x_mask)) 276 | if self.activation == "gelu": 277 | x = x * torch.sigmoid(1.702 * x) 278 | else: 279 | x = torch.relu(x) 280 | x = self.drop(x) 281 | x = self.conv_2(self.padding(x * x_mask)) 282 | return x * x_mask 283 | 284 | def _causal_padding(self, x): 285 | if self.kernel_size == 1: 286 | return x 287 | pad_l = self.kernel_size - 1 288 | pad_r = 0 289 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 290 | x = F.pad(x, commons.convert_pad_shape(padding)) 291 | return x 292 | 293 | def _same_padding(self, x): 294 | if self.kernel_size == 1: 295 | return x 296 | pad_l = (self.kernel_size - 1) // 2 297 | pad_r = self.kernel_size // 2 298 | padding = [[0, 0], [0, 0], [pad_l, pad_r]] 299 | x = F.pad(x, commons.convert_pad_shape(padding)) 300 | return x 301 | -------------------------------------------------------------------------------- /vits/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.jit 5 | 6 | 7 | def script_method(fn, _rcb=None): 8 | return fn 9 | 10 | 11 | def script(obj, optimize=True, _frames_up=0, _rcb=None): 12 | return obj 13 | 14 | 15 | torch.jit.script_method = script_method 16 | torch.jit.script = script 17 | 18 | 19 | def init_weights(m, mean=0.0, std=0.01): 20 | classname = m.__class__.__name__ 21 | if classname.find("Conv") != -1: 22 | m.weight.data.normal_(mean, std) 23 | 24 | 25 | def get_padding(kernel_size, dilation=1): 26 | return int((kernel_size*dilation - dilation)/2) 27 | 28 | 29 | def convert_pad_shape(pad_shape): 30 | l = pad_shape[::-1] 31 | pad_shape = [item for sublist in l for item in sublist] 32 | return pad_shape 33 | 34 | 35 | def intersperse(lst, item): 36 | result = [item] * (len(lst) * 2 + 1) 37 | result[1::2] = lst 38 | return result 39 | 40 | 41 | def kl_divergence(m_p, logs_p, m_q, logs_q): 42 | """KL(P||Q)""" 43 | kl = (logs_q - logs_p) - 0.5 44 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 45 | return kl 46 | 47 | 48 | def rand_gumbel(shape): 49 | """Sample from the Gumbel distribution, protect from overflows.""" 50 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 51 | return -torch.log(-torch.log(uniform_samples)) 52 | 53 | 54 | def rand_gumbel_like(x): 55 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 56 | return g 57 | 58 | 59 | def slice_segments(x, ids_str, segment_size=4): 60 | ret = torch.zeros_like(x[:, :, :segment_size]) 61 | for i in range(x.size(0)): 62 | idx_str = ids_str[i] 63 | idx_end = idx_str + segment_size 64 | ret[i] = x[i, :, idx_str:idx_end] 65 | return ret 66 | 67 | 68 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 69 | b, d, t = x.size() 70 | if x_lengths is None: 71 | x_lengths = t 72 | ids_str_max = x_lengths - segment_size + 1 73 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 74 | ret = slice_segments(x, ids_str, segment_size) 75 | return ret, ids_str 76 | 77 | 78 | def get_timing_signal_1d( 79 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 80 | position = torch.arange(length, dtype=torch.float) 81 | num_timescales = channels // 2 82 | log_timescale_increment = ( 83 | math.log(float(max_timescale) / float(min_timescale)) / 84 | (num_timescales - 1)) 85 | inv_timescales = min_timescale * torch.exp( 86 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 87 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 88 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 89 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 90 | signal = signal.view(1, channels, length) 91 | return signal 92 | 93 | 94 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 95 | b, channels, length = x.size() 96 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 97 | return x + signal.to(dtype=x.dtype, device=x.device) 98 | 99 | 100 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 101 | b, channels, length = x.size() 102 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 103 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 104 | 105 | 106 | def subsequent_mask(length): 107 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 108 | return mask 109 | 110 | 111 | @torch.jit.script 112 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 113 | n_channels_int = n_channels[0] 114 | in_act = input_a + input_b 115 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 116 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 117 | acts = t_act * s_act 118 | return acts 119 | 120 | 121 | def convert_pad_shape(pad_shape): 122 | l = pad_shape[::-1] 123 | pad_shape = [item for sublist in l for item in sublist] 124 | return pad_shape 125 | 126 | 127 | def shift_1d(x): 128 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 129 | return x 130 | 131 | 132 | def sequence_mask(length, max_length=None): 133 | if max_length is None: 134 | max_length = length.max() 135 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 136 | return x.unsqueeze(0) < length.unsqueeze(1) 137 | 138 | 139 | def generate_path(duration, mask): 140 | """ 141 | duration: [b, 1, t_x] 142 | mask: [b, 1, t_y, t_x] 143 | """ 144 | device = duration.device 145 | 146 | b, _, t_y, t_x = mask.shape 147 | cum_duration = torch.cumsum(duration, -1) 148 | 149 | 
cum_duration_flat = cum_duration.view(b * t_x) 150 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 151 | path = path.view(b, t_x, t_y) 152 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 153 | path = path.unsqueeze(1).transpose(2,3) * mask 154 | return path 155 | 156 | 157 | def clip_grad_value_(parameters, clip_value, norm_type=2): 158 | if isinstance(parameters, torch.Tensor): 159 | parameters = [parameters] 160 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 161 | norm_type = float(norm_type) 162 | if clip_value is not None: 163 | clip_value = float(clip_value) 164 | 165 | total_norm = 0 166 | for p in parameters: 167 | param_norm = p.grad.data.norm(norm_type) 168 | total_norm += param_norm.item() ** norm_type 169 | if clip_value is not None: 170 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 171 | total_norm = total_norm ** (1. / norm_type) 172 | return total_norm 173 | -------------------------------------------------------------------------------- /vits/mel_processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from librosa.filters import mel as librosa_mel_fn 4 | 5 | MAX_WAV_VALUE = 32768.0 6 | 7 | 8 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 9 | """ 10 | PARAMS 11 | ------ 12 | C: compression factor 13 | """ 14 | return torch.log(torch.clamp(x, min=clip_val) * C) 15 | 16 | 17 | def dynamic_range_decompression_torch(x, C=1): 18 | """ 19 | PARAMS 20 | ------ 21 | C: compression factor used to compress 22 | """ 23 | return torch.exp(x) / C 24 | 25 | 26 | def spectral_normalize_torch(magnitudes): 27 | output = dynamic_range_compression_torch(magnitudes) 28 | return output 29 | 30 | 31 | def spectral_de_normalize_torch(magnitudes): 32 | output = dynamic_range_decompression_torch(magnitudes) 33 | return output 34 | 35 | 36 | mel_basis = {} 37 | hann_window = {} 38 | 39 | 40 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 41 | if torch.min(y) < -1.: 42 | print('min value is ', torch.min(y)) 43 | if torch.max(y) > 1.: 44 | print('max value is ', torch.max(y)) 45 | 46 | global hann_window 47 | dtype_device = str(y.dtype) + '_' + str(y.device) 48 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 49 | if wnsize_dtype_device not in hann_window: 50 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 51 | 52 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 53 | y = y.squeeze(1) 54 | 55 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 56 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 57 | 58 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 59 | return spec 60 | 61 | 62 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 63 | global mel_basis 64 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 65 | fmax_dtype_device = str(fmax) + '_' + dtype_device 66 | if fmax_dtype_device not in mel_basis: 67 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 68 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 69 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 70 | spec = spectral_normalize_torch(spec) 71 | return spec 72 | 73 | 74 | def mel_spectrogram_torch(y, 
n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 75 | if torch.min(y) < -1.: 76 | print('min value is ', torch.min(y)) 77 | if torch.max(y) > 1.: 78 | print('max value is ', torch.max(y)) 79 | 80 | global mel_basis, hann_window 81 | dtype_device = str(y.dtype) + '_' + str(y.device) 82 | fmax_dtype_device = str(fmax) + '_' + dtype_device 83 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 84 | if fmax_dtype_device not in mel_basis: 85 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 86 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 87 | if wnsize_dtype_device not in hann_window: 88 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 89 | 90 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 91 | y = y.squeeze(1) 92 | 93 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 94 | center=center, pad_mode='reflect', normalized=False, onesided=True) 95 | 96 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 97 | 98 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 99 | spec = spectral_normalize_torch(spec) 100 | 101 | return spec 102 | -------------------------------------------------------------------------------- /vits/model/Download Link.txt: -------------------------------------------------------------------------------- 1 | https://huggingface.co/spaces/zomehwh/vits-uma-genshin-honkai/resolve/main/model/G_953000.pth -------------------------------------------------------------------------------- /vits/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | import commons 7 | import modules 8 | import attentions 9 | import monotonic_align 10 | 11 | from torch.nn import Conv1d, ConvTranspose1d, Conv2d 12 | from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm 13 | from commons import init_weights, get_padding 14 | 15 | 16 | class StochasticDurationPredictor(nn.Module): 17 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): 18 | super().__init__() 19 | filter_channels = in_channels # it needs to be removed from future version. 
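    # Flow-based duration model from VITS: builds a stack of ElementwiseAffine/ConvFlow/Flip layers to
    # model phoneme durations; forward() returns nll + logq when reverse is False and sampled
    # log-durations (logw) when reverse is True.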
20 | self.in_channels = in_channels 21 | self.filter_channels = filter_channels 22 | self.kernel_size = kernel_size 23 | self.p_dropout = p_dropout 24 | self.n_flows = n_flows 25 | self.gin_channels = gin_channels 26 | 27 | self.log_flow = modules.Log() 28 | self.flows = nn.ModuleList() 29 | self.flows.append(modules.ElementwiseAffine(2)) 30 | for i in range(n_flows): 31 | self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 32 | self.flows.append(modules.Flip()) 33 | 34 | self.post_pre = nn.Conv1d(1, filter_channels, 1) 35 | self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) 36 | self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 37 | self.post_flows = nn.ModuleList() 38 | self.post_flows.append(modules.ElementwiseAffine(2)) 39 | for i in range(4): 40 | self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)) 41 | self.post_flows.append(modules.Flip()) 42 | 43 | self.pre = nn.Conv1d(in_channels, filter_channels, 1) 44 | self.proj = nn.Conv1d(filter_channels, filter_channels, 1) 45 | self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) 46 | if gin_channels != 0: 47 | self.cond = nn.Conv1d(gin_channels, filter_channels, 1) 48 | 49 | def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): 50 | x = torch.detach(x) 51 | x = self.pre(x) 52 | if g is not None: 53 | g = torch.detach(g) 54 | x = x + self.cond(g) 55 | x = self.convs(x, x_mask) 56 | x = self.proj(x) * x_mask 57 | 58 | if not reverse: 59 | flows = self.flows 60 | assert w is not None 61 | 62 | logdet_tot_q = 0 63 | h_w = self.post_pre(w) 64 | h_w = self.post_convs(h_w, x_mask) 65 | h_w = self.post_proj(h_w) * x_mask 66 | e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask 67 | z_q = e_q 68 | for flow in self.post_flows: 69 | z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) 70 | logdet_tot_q += logdet_q 71 | z_u, z1 = torch.split(z_q, [1, 1], 1) 72 | u = torch.sigmoid(z_u) * x_mask 73 | z0 = (w - u) * x_mask 74 | logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) 75 | logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q 76 | 77 | logdet_tot = 0 78 | z0, logdet = self.log_flow(z0, x_mask) 79 | logdet_tot += logdet 80 | z = torch.cat([z0, z1], 1) 81 | for flow in flows: 82 | z, logdet = flow(z, x_mask, g=x, reverse=reverse) 83 | logdet_tot = logdet_tot + logdet 84 | nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot 85 | return nll + logq # [b] 86 | else: 87 | flows = list(reversed(self.flows)) 88 | flows = flows[:-2] + [flows[-1]] # remove a useless vflow 89 | z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale 90 | for flow in flows: 91 | z = flow(z, x_mask, g=x, reverse=reverse) 92 | z0, z1 = torch.split(z, [1, 1], 1) 93 | logw = z0 94 | return logw 95 | 96 | 97 | class DurationPredictor(nn.Module): 98 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 99 | super().__init__() 100 | 101 | self.in_channels = in_channels 102 | self.filter_channels = filter_channels 103 | self.kernel_size = kernel_size 104 | self.p_dropout = p_dropout 105 | self.gin_channels = gin_channels 106 | 107 | self.drop = nn.Dropout(p_dropout) 108 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 109 | self.norm_1 = 
modules.LayerNorm(filter_channels) 110 | self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 111 | self.norm_2 = modules.LayerNorm(filter_channels) 112 | self.proj = nn.Conv1d(filter_channels, 1, 1) 113 | 114 | if gin_channels != 0: 115 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 116 | 117 | def forward(self, x, x_mask, g=None): 118 | x = torch.detach(x) 119 | if g is not None: 120 | g = torch.detach(g) 121 | x = x + self.cond(g) 122 | x = self.conv_1(x * x_mask) 123 | x = torch.relu(x) 124 | x = self.norm_1(x) 125 | x = self.drop(x) 126 | x = self.conv_2(x * x_mask) 127 | x = torch.relu(x) 128 | x = self.norm_2(x) 129 | x = self.drop(x) 130 | x = self.proj(x * x_mask) 131 | return x * x_mask 132 | 133 | 134 | class TextEncoder(nn.Module): 135 | def __init__(self, 136 | n_vocab, 137 | out_channels, 138 | hidden_channels, 139 | filter_channels, 140 | n_heads, 141 | n_layers, 142 | kernel_size, 143 | p_dropout): 144 | super().__init__() 145 | self.n_vocab = n_vocab 146 | self.out_channels = out_channels 147 | self.hidden_channels = hidden_channels 148 | self.filter_channels = filter_channels 149 | self.n_heads = n_heads 150 | self.n_layers = n_layers 151 | self.kernel_size = kernel_size 152 | self.p_dropout = p_dropout 153 | 154 | self.emb = nn.Embedding(n_vocab, hidden_channels) 155 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 156 | 157 | self.encoder = attentions.Encoder( 158 | hidden_channels, 159 | filter_channels, 160 | n_heads, 161 | n_layers, 162 | kernel_size, 163 | p_dropout) 164 | self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) 165 | 166 | def forward(self, x, x_lengths): 167 | x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] 168 | x = torch.transpose(x, 1, -1) # [b, h, t] 169 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 170 | 171 | x = self.encoder(x * x_mask, x_mask) 172 | stats = self.proj(x) * x_mask 173 | 174 | m, logs = torch.split(stats, self.out_channels, dim=1) 175 | return x, m, logs, x_mask 176 | 177 | 178 | class ResidualCouplingBlock(nn.Module): 179 | def __init__(self, 180 | channels, 181 | hidden_channels, 182 | kernel_size, 183 | dilation_rate, 184 | n_layers, 185 | n_flows=4, 186 | gin_channels=0): 187 | super().__init__() 188 | self.channels = channels 189 | self.hidden_channels = hidden_channels 190 | self.kernel_size = kernel_size 191 | self.dilation_rate = dilation_rate 192 | self.n_layers = n_layers 193 | self.n_flows = n_flows 194 | self.gin_channels = gin_channels 195 | 196 | self.flows = nn.ModuleList() 197 | for i in range(n_flows): 198 | self.flows.append(modules.ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) 199 | self.flows.append(modules.Flip()) 200 | 201 | def forward(self, x, x_mask, g=None, reverse=False): 202 | if not reverse: 203 | for flow in self.flows: 204 | x, _ = flow(x, x_mask, g=g, reverse=reverse) 205 | else: 206 | for flow in reversed(self.flows): 207 | x = flow(x, x_mask, g=g, reverse=reverse) 208 | return x 209 | 210 | 211 | class PosteriorEncoder(nn.Module): 212 | def __init__(self, 213 | in_channels, 214 | out_channels, 215 | hidden_channels, 216 | kernel_size, 217 | dilation_rate, 218 | n_layers, 219 | gin_channels=0): 220 | super().__init__() 221 | self.in_channels = in_channels 222 | self.out_channels = out_channels 223 | self.hidden_channels = hidden_channels 224 | self.kernel_size = kernel_size 225 | 
self.dilation_rate = dilation_rate 226 | self.n_layers = n_layers 227 | self.gin_channels = gin_channels 228 | 229 | self.pre = nn.Conv1d(in_channels, hidden_channels, 1) 230 | self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) 231 | self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) 232 | 233 | def forward(self, x, x_lengths, g=None): 234 | x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) 235 | x = self.pre(x) * x_mask 236 | x = self.enc(x, x_mask, g=g) 237 | stats = self.proj(x) * x_mask 238 | m, logs = torch.split(stats, self.out_channels, dim=1) 239 | z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask 240 | return z, m, logs, x_mask 241 | 242 | 243 | class Generator(torch.nn.Module): 244 | def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): 245 | super(Generator, self).__init__() 246 | self.num_kernels = len(resblock_kernel_sizes) 247 | self.num_upsamples = len(upsample_rates) 248 | self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) 249 | resblock = modules.ResBlock1 if resblock == '1' else modules.ResBlock2 250 | 251 | self.ups = nn.ModuleList() 252 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 253 | self.ups.append(weight_norm( 254 | ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), 255 | k, u, padding=(k-u)//2))) 256 | 257 | self.resblocks = nn.ModuleList() 258 | for i in range(len(self.ups)): 259 | ch = upsample_initial_channel//(2**(i+1)) 260 | for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): 261 | self.resblocks.append(resblock(ch, k, d)) 262 | 263 | self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) 264 | self.ups.apply(init_weights) 265 | 266 | if gin_channels != 0: 267 | self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) 268 | 269 | def forward(self, x, g=None): 270 | x = self.conv_pre(x) 271 | if g is not None: 272 | x = x + self.cond(g) 273 | 274 | for i in range(self.num_upsamples): 275 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 276 | x = self.ups[i](x) 277 | xs = None 278 | for j in range(self.num_kernels): 279 | if xs is None: 280 | xs = self.resblocks[i*self.num_kernels+j](x) 281 | else: 282 | xs += self.resblocks[i*self.num_kernels+j](x) 283 | x = xs / self.num_kernels 284 | x = F.leaky_relu(x) 285 | x = self.conv_post(x) 286 | x = torch.tanh(x) 287 | 288 | return x 289 | 290 | def remove_weight_norm(self): 291 | print('Removing weight norm...') 292 | for l in self.ups: 293 | remove_weight_norm(l) 294 | for l in self.resblocks: 295 | l.remove_weight_norm() 296 | 297 | 298 | class DiscriminatorP(torch.nn.Module): 299 | def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): 300 | super(DiscriminatorP, self).__init__() 301 | self.period = period 302 | self.use_spectral_norm = use_spectral_norm 303 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 304 | self.convs = nn.ModuleList([ 305 | norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 306 | norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 307 | norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), 308 | norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), 
padding=(get_padding(kernel_size, 1), 0))), 309 | norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), 310 | ]) 311 | self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 312 | 313 | def forward(self, x): 314 | fmap = [] 315 | 316 | # 1d to 2d 317 | b, c, t = x.shape 318 | if t % self.period != 0: # pad first 319 | n_pad = self.period - (t % self.period) 320 | x = F.pad(x, (0, n_pad), "reflect") 321 | t = t + n_pad 322 | x = x.view(b, c, t // self.period, self.period) 323 | 324 | for l in self.convs: 325 | x = l(x) 326 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 327 | fmap.append(x) 328 | x = self.conv_post(x) 329 | fmap.append(x) 330 | x = torch.flatten(x, 1, -1) 331 | 332 | return x, fmap 333 | 334 | 335 | class DiscriminatorS(torch.nn.Module): 336 | def __init__(self, use_spectral_norm=False): 337 | super(DiscriminatorS, self).__init__() 338 | norm_f = weight_norm if use_spectral_norm == False else spectral_norm 339 | self.convs = nn.ModuleList([ 340 | norm_f(Conv1d(1, 16, 15, 1, padding=7)), 341 | norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), 342 | norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), 343 | norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), 344 | norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), 345 | norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), 346 | ]) 347 | self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) 348 | 349 | def forward(self, x): 350 | fmap = [] 351 | 352 | for l in self.convs: 353 | x = l(x) 354 | x = F.leaky_relu(x, modules.LRELU_SLOPE) 355 | fmap.append(x) 356 | x = self.conv_post(x) 357 | fmap.append(x) 358 | x = torch.flatten(x, 1, -1) 359 | 360 | return x, fmap 361 | 362 | 363 | class MultiPeriodDiscriminator(torch.nn.Module): 364 | def __init__(self, use_spectral_norm=False): 365 | super(MultiPeriodDiscriminator, self).__init__() 366 | periods = [2,3,5,7,11] 367 | 368 | discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] 369 | discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] 370 | self.discriminators = nn.ModuleList(discs) 371 | 372 | def forward(self, y, y_hat): 373 | y_d_rs = [] 374 | y_d_gs = [] 375 | fmap_rs = [] 376 | fmap_gs = [] 377 | for i, d in enumerate(self.discriminators): 378 | y_d_r, fmap_r = d(y) 379 | y_d_g, fmap_g = d(y_hat) 380 | y_d_rs.append(y_d_r) 381 | y_d_gs.append(y_d_g) 382 | fmap_rs.append(fmap_r) 383 | fmap_gs.append(fmap_g) 384 | 385 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 386 | 387 | 388 | 389 | class SynthesizerTrn(nn.Module): 390 | """ 391 | Synthesizer for Training 392 | """ 393 | 394 | def __init__(self, 395 | n_vocab, 396 | spec_channels, 397 | segment_size, 398 | inter_channels, 399 | hidden_channels, 400 | filter_channels, 401 | n_heads, 402 | n_layers, 403 | kernel_size, 404 | p_dropout, 405 | resblock, 406 | resblock_kernel_sizes, 407 | resblock_dilation_sizes, 408 | upsample_rates, 409 | upsample_initial_channel, 410 | upsample_kernel_sizes, 411 | n_speakers=0, 412 | gin_channels=0, 413 | use_sdp=True, 414 | **kwargs): 415 | 416 | super().__init__() 417 | self.n_vocab = n_vocab 418 | self.spec_channels = spec_channels 419 | self.inter_channels = inter_channels 420 | self.hidden_channels = hidden_channels 421 | self.filter_channels = filter_channels 422 | self.n_heads = n_heads 423 | self.n_layers = n_layers 424 | self.kernel_size = kernel_size 425 | self.p_dropout = p_dropout 426 | self.resblock = resblock 427 | self.resblock_kernel_sizes = resblock_kernel_sizes 428 | 
self.resblock_dilation_sizes = resblock_dilation_sizes 429 | self.upsample_rates = upsample_rates 430 | self.upsample_initial_channel = upsample_initial_channel 431 | self.upsample_kernel_sizes = upsample_kernel_sizes 432 | self.segment_size = segment_size 433 | self.n_speakers = n_speakers 434 | self.gin_channels = gin_channels 435 | 436 | self.use_sdp = use_sdp 437 | 438 | self.enc_p = TextEncoder(n_vocab, 439 | inter_channels, 440 | hidden_channels, 441 | filter_channels, 442 | n_heads, 443 | n_layers, 444 | kernel_size, 445 | p_dropout) 446 | self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) 447 | self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) 448 | self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) 449 | 450 | if use_sdp: 451 | self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) 452 | else: 453 | self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) 454 | 455 | if n_speakers > 1: 456 | self.emb_g = nn.Embedding(n_speakers, gin_channels) 457 | 458 | def forward(self, x, x_lengths, y, y_lengths, sid=None): 459 | 460 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 461 | if self.n_speakers > 0: 462 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 463 | else: 464 | g = None 465 | 466 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) 467 | z_p = self.flow(z, y_mask, g=g) 468 | 469 | with torch.no_grad(): 470 | # negative cross-entropy 471 | s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] 472 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] 473 | neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 474 | neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] 475 | neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] 476 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 477 | 478 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 479 | attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() 480 | 481 | w = attn.sum(2) 482 | if self.use_sdp: 483 | l_length = self.dp(x, x_mask, w, g=g) 484 | l_length = l_length / torch.sum(x_mask) 485 | else: 486 | logw_ = torch.log(w + 1e-6) * x_mask 487 | logw = self.dp(x, x_mask, g=g) 488 | l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging 489 | 490 | # expand prior 491 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) 492 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) 493 | 494 | z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size) 495 | o = self.dec(z_slice, g=g) 496 | return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) 497 | 498 | def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): 499 | x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths) 500 | if self.n_speakers > 0: 501 | g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] 502 | else: 503 | g = None 504 | 505 | if self.use_sdp: 506 | logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) 507 | else: 508 
| logw = self.dp(x, x_mask, g=g) 509 | w = torch.exp(logw) * x_mask * length_scale 510 | w_ceil = torch.ceil(w) 511 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 512 | y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(x_mask.dtype) 513 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 514 | attn = commons.generate_path(w_ceil, attn_mask) 515 | 516 | m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 517 | logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] 518 | 519 | z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale 520 | z = self.flow(z_p, y_mask, g=g, reverse=True) 521 | o = self.dec((z * y_mask)[:,:,:max_len], g=g) 522 | return o, attn, y_mask, (z, z_p, m_p, logs_p) 523 | 524 | def voice_conversion(self, y, y_lengths, sid_src, sid_tgt): 525 | assert self.n_speakers > 0, "n_speakers have to be larger than 0." 526 | g_src = self.emb_g(sid_src).unsqueeze(-1) 527 | g_tgt = self.emb_g(sid_tgt).unsqueeze(-1) 528 | z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src) 529 | z_p = self.flow(z, y_mask, g=g_src) 530 | z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) 531 | o_hat = self.dec(z_hat * y_mask, g=g_tgt) 532 | return o_hat, y_mask, (z, z_p, z_hat) 533 | 534 | -------------------------------------------------------------------------------- /vits/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d 8 | from torch.nn.utils import weight_norm, remove_weight_norm 9 | 10 | import commons 11 | from commons import init_weights, get_padding 12 | from transforms import piecewise_rational_quadratic_transform 13 | 14 | 15 | LRELU_SLOPE = 0.1 16 | 17 | 18 | class LayerNorm(nn.Module): 19 | def __init__(self, channels, eps=1e-5): 20 | super().__init__() 21 | self.channels = channels 22 | self.eps = eps 23 | 24 | self.gamma = nn.Parameter(torch.ones(channels)) 25 | self.beta = nn.Parameter(torch.zeros(channels)) 26 | 27 | def forward(self, x): 28 | x = x.transpose(1, -1) 29 | x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) 30 | return x.transpose(1, -1) 31 | 32 | 33 | class ConvReluNorm(nn.Module): 34 | def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): 35 | super().__init__() 36 | self.in_channels = in_channels 37 | self.hidden_channels = hidden_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = kernel_size 40 | self.n_layers = n_layers 41 | self.p_dropout = p_dropout 42 | assert n_layers > 1, "Number of layers should be larger than 0." 
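# Descriptive note: ConvReluNorm stacks n_layers of Conv1d -> LayerNorm -> ReLU -> Dropout
# and adds a zero-initialized 1x1 projection of the result back onto the input as a
# residual. The assert above actually enforces n_layers > 1, which is slightly stricter
# than its "larger than 0" message.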
43 | 44 | self.conv_layers = nn.ModuleList() 45 | self.norm_layers = nn.ModuleList() 46 | self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 47 | self.norm_layers.append(LayerNorm(hidden_channels)) 48 | self.relu_drop = nn.Sequential( 49 | nn.ReLU(), 50 | nn.Dropout(p_dropout)) 51 | for _ in range(n_layers-1): 52 | self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) 53 | self.norm_layers.append(LayerNorm(hidden_channels)) 54 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 55 | self.proj.weight.data.zero_() 56 | self.proj.bias.data.zero_() 57 | 58 | def forward(self, x, x_mask): 59 | x_org = x 60 | for i in range(self.n_layers): 61 | x = self.conv_layers[i](x * x_mask) 62 | x = self.norm_layers[i](x) 63 | x = self.relu_drop(x) 64 | x = x_org + self.proj(x) 65 | return x * x_mask 66 | 67 | 68 | class DDSConv(nn.Module): 69 | """ 70 | Dialted and Depth-Separable Convolution 71 | """ 72 | def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): 73 | super().__init__() 74 | self.channels = channels 75 | self.kernel_size = kernel_size 76 | self.n_layers = n_layers 77 | self.p_dropout = p_dropout 78 | 79 | self.drop = nn.Dropout(p_dropout) 80 | self.convs_sep = nn.ModuleList() 81 | self.convs_1x1 = nn.ModuleList() 82 | self.norms_1 = nn.ModuleList() 83 | self.norms_2 = nn.ModuleList() 84 | for i in range(n_layers): 85 | dilation = kernel_size ** i 86 | padding = (kernel_size * dilation - dilation) // 2 87 | self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, 88 | groups=channels, dilation=dilation, padding=padding 89 | )) 90 | self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) 91 | self.norms_1.append(LayerNorm(channels)) 92 | self.norms_2.append(LayerNorm(channels)) 93 | 94 | def forward(self, x, x_mask, g=None): 95 | if g is not None: 96 | x = x + g 97 | for i in range(self.n_layers): 98 | y = self.convs_sep[i](x * x_mask) 99 | y = self.norms_1[i](y) 100 | y = F.gelu(y) 101 | y = self.convs_1x1[i](y) 102 | y = self.norms_2[i](y) 103 | y = F.gelu(y) 104 | y = self.drop(y) 105 | x = x + y 106 | return x * x_mask 107 | 108 | 109 | class WN(torch.nn.Module): 110 | def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): 111 | super(WN, self).__init__() 112 | assert(kernel_size % 2 == 1) 113 | self.hidden_channels =hidden_channels 114 | self.kernel_size = kernel_size, 115 | self.dilation_rate = dilation_rate 116 | self.n_layers = n_layers 117 | self.gin_channels = gin_channels 118 | self.p_dropout = p_dropout 119 | 120 | self.in_layers = torch.nn.ModuleList() 121 | self.res_skip_layers = torch.nn.ModuleList() 122 | self.drop = nn.Dropout(p_dropout) 123 | 124 | if gin_channels != 0: 125 | cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) 126 | self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') 127 | 128 | for i in range(n_layers): 129 | dilation = dilation_rate ** i 130 | padding = int((kernel_size * dilation - dilation) / 2) 131 | in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, 132 | dilation=dilation, padding=padding) 133 | in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') 134 | self.in_layers.append(in_layer) 135 | 136 | # last one is not necessary 137 | if i < n_layers - 1: 138 | res_skip_channels = 2 * hidden_channels 139 | else: 140 | res_skip_channels = hidden_channels 141 | 142 | res_skip_layer = 
torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) 143 | res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') 144 | self.res_skip_layers.append(res_skip_layer) 145 | 146 | def forward(self, x, x_mask, g=None, **kwargs): 147 | output = torch.zeros_like(x) 148 | n_channels_tensor = torch.IntTensor([self.hidden_channels]) 149 | 150 | if g is not None: 151 | g = self.cond_layer(g) 152 | 153 | for i in range(self.n_layers): 154 | x_in = self.in_layers[i](x) 155 | if g is not None: 156 | cond_offset = i * 2 * self.hidden_channels 157 | g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] 158 | else: 159 | g_l = torch.zeros_like(x_in) 160 | 161 | acts = commons.fused_add_tanh_sigmoid_multiply( 162 | x_in, 163 | g_l, 164 | n_channels_tensor) 165 | acts = self.drop(acts) 166 | 167 | res_skip_acts = self.res_skip_layers[i](acts) 168 | if i < self.n_layers - 1: 169 | res_acts = res_skip_acts[:,:self.hidden_channels,:] 170 | x = (x + res_acts) * x_mask 171 | output = output + res_skip_acts[:,self.hidden_channels:,:] 172 | else: 173 | output = output + res_skip_acts 174 | return output * x_mask 175 | 176 | def remove_weight_norm(self): 177 | if self.gin_channels != 0: 178 | torch.nn.utils.remove_weight_norm(self.cond_layer) 179 | for l in self.in_layers: 180 | torch.nn.utils.remove_weight_norm(l) 181 | for l in self.res_skip_layers: 182 | torch.nn.utils.remove_weight_norm(l) 183 | 184 | 185 | class ResBlock1(torch.nn.Module): 186 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 187 | super(ResBlock1, self).__init__() 188 | self.convs1 = nn.ModuleList([ 189 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 190 | padding=get_padding(kernel_size, dilation[0]))), 191 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 192 | padding=get_padding(kernel_size, dilation[1]))), 193 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 194 | padding=get_padding(kernel_size, dilation[2]))) 195 | ]) 196 | self.convs1.apply(init_weights) 197 | 198 | self.convs2 = nn.ModuleList([ 199 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 200 | padding=get_padding(kernel_size, 1))), 201 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 202 | padding=get_padding(kernel_size, 1))), 203 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 204 | padding=get_padding(kernel_size, 1))) 205 | ]) 206 | self.convs2.apply(init_weights) 207 | 208 | def forward(self, x, x_mask=None): 209 | for c1, c2 in zip(self.convs1, self.convs2): 210 | xt = F.leaky_relu(x, LRELU_SLOPE) 211 | if x_mask is not None: 212 | xt = xt * x_mask 213 | xt = c1(xt) 214 | xt = F.leaky_relu(xt, LRELU_SLOPE) 215 | if x_mask is not None: 216 | xt = xt * x_mask 217 | xt = c2(xt) 218 | x = xt + x 219 | if x_mask is not None: 220 | x = x * x_mask 221 | return x 222 | 223 | def remove_weight_norm(self): 224 | for l in self.convs1: 225 | remove_weight_norm(l) 226 | for l in self.convs2: 227 | remove_weight_norm(l) 228 | 229 | 230 | class ResBlock2(torch.nn.Module): 231 | def __init__(self, channels, kernel_size=3, dilation=(1, 3)): 232 | super(ResBlock2, self).__init__() 233 | self.convs = nn.ModuleList([ 234 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 235 | padding=get_padding(kernel_size, dilation[0]))), 236 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 237 | padding=get_padding(kernel_size, dilation[1]))) 238 | ]) 
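# Descriptive note: ResBlock2 is the lighter residual block variant (two dilated
# convolutions per block and no second plain-convolution stack); Generator selects it
# whenever the config's `resblock` field is anything other than '1', while ResBlock1
# above is the larger three-plus-three convolution variant.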
239 | self.convs.apply(init_weights) 240 | 241 | def forward(self, x, x_mask=None): 242 | for c in self.convs: 243 | xt = F.leaky_relu(x, LRELU_SLOPE) 244 | if x_mask is not None: 245 | xt = xt * x_mask 246 | xt = c(xt) 247 | x = xt + x 248 | if x_mask is not None: 249 | x = x * x_mask 250 | return x 251 | 252 | def remove_weight_norm(self): 253 | for l in self.convs: 254 | remove_weight_norm(l) 255 | 256 | 257 | class Log(nn.Module): 258 | def forward(self, x, x_mask, reverse=False, **kwargs): 259 | if not reverse: 260 | y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask 261 | logdet = torch.sum(-y, [1, 2]) 262 | return y, logdet 263 | else: 264 | x = torch.exp(x) * x_mask 265 | return x 266 | 267 | 268 | class Flip(nn.Module): 269 | def forward(self, x, *args, reverse=False, **kwargs): 270 | x = torch.flip(x, [1]) 271 | if not reverse: 272 | logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) 273 | return x, logdet 274 | else: 275 | return x 276 | 277 | 278 | class ElementwiseAffine(nn.Module): 279 | def __init__(self, channels): 280 | super().__init__() 281 | self.channels = channels 282 | self.m = nn.Parameter(torch.zeros(channels,1)) 283 | self.logs = nn.Parameter(torch.zeros(channels,1)) 284 | 285 | def forward(self, x, x_mask, reverse=False, **kwargs): 286 | if not reverse: 287 | y = self.m + torch.exp(self.logs) * x 288 | y = y * x_mask 289 | logdet = torch.sum(self.logs * x_mask, [1,2]) 290 | return y, logdet 291 | else: 292 | x = (x - self.m) * torch.exp(-self.logs) * x_mask 293 | return x 294 | 295 | 296 | class ResidualCouplingLayer(nn.Module): 297 | def __init__(self, 298 | channels, 299 | hidden_channels, 300 | kernel_size, 301 | dilation_rate, 302 | n_layers, 303 | p_dropout=0, 304 | gin_channels=0, 305 | mean_only=False): 306 | assert channels % 2 == 0, "channels should be divisible by 2" 307 | super().__init__() 308 | self.channels = channels 309 | self.hidden_channels = hidden_channels 310 | self.kernel_size = kernel_size 311 | self.dilation_rate = dilation_rate 312 | self.n_layers = n_layers 313 | self.half_channels = channels // 2 314 | self.mean_only = mean_only 315 | 316 | self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) 317 | self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) 318 | self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) 319 | self.post.weight.data.zero_() 320 | self.post.bias.data.zero_() 321 | 322 | def forward(self, x, x_mask, g=None, reverse=False): 323 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 324 | h = self.pre(x0) * x_mask 325 | h = self.enc(h, x_mask, g=g) 326 | stats = self.post(h) * x_mask 327 | if not self.mean_only: 328 | m, logs = torch.split(stats, [self.half_channels]*2, 1) 329 | else: 330 | m = stats 331 | logs = torch.zeros_like(m) 332 | 333 | if not reverse: 334 | x1 = m + x1 * torch.exp(logs) * x_mask 335 | x = torch.cat([x0, x1], 1) 336 | logdet = torch.sum(logs, [1,2]) 337 | return x, logdet 338 | else: 339 | x1 = (x1 - m) * torch.exp(-logs) * x_mask 340 | x = torch.cat([x0, x1], 1) 341 | return x 342 | 343 | 344 | class ConvFlow(nn.Module): 345 | def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): 346 | super().__init__() 347 | self.in_channels = in_channels 348 | self.filter_channels = filter_channels 349 | self.kernel_size = kernel_size 350 | self.n_layers = n_layers 351 | self.num_bins = num_bins 352 | self.tail_bound = tail_bound 353 | 
self.half_channels = in_channels // 2 354 | 355 | self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) 356 | self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) 357 | self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) 358 | self.proj.weight.data.zero_() 359 | self.proj.bias.data.zero_() 360 | 361 | def forward(self, x, x_mask, g=None, reverse=False): 362 | x0, x1 = torch.split(x, [self.half_channels]*2, 1) 363 | h = self.pre(x0) 364 | h = self.convs(h, x_mask, g=g) 365 | h = self.proj(h) * x_mask 366 | 367 | b, c, t = x0.shape 368 | h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] 369 | 370 | unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) 371 | unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) 372 | unnormalized_derivatives = h[..., 2 * self.num_bins:] 373 | 374 | x1, logabsdet = piecewise_rational_quadratic_transform(x1, 375 | unnormalized_widths, 376 | unnormalized_heights, 377 | unnormalized_derivatives, 378 | inverse=reverse, 379 | tails='linear', 380 | tail_bound=self.tail_bound 381 | ) 382 | 383 | x = torch.cat([x0, x1], 1) * x_mask 384 | logdet = torch.sum(logabsdet * x_mask, [1,2]) 385 | if not reverse: 386 | return x, logdet 387 | else: 388 | return x 389 | -------------------------------------------------------------------------------- /vits/monotonic_align/monotonic_align/core.cp38-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whiteeat/ai-vtuber-alpha/89dbe3e199c6f3c094054c0babaece1050409e1a/vits/monotonic_align/monotonic_align/core.cp38-win_amd64.pyd -------------------------------------------------------------------------------- /vits/requirements.txt: -------------------------------------------------------------------------------- 1 | Cython 2 | librosa 3 | matplotlib 4 | numpy 5 | phonemizer 6 | scipy 7 | tensorboard 8 | torch 9 | torchvision 10 | Unidecode 11 | pyopenjtalk 12 | ffmpeg 13 | jamo 14 | cn2an 15 | gradio 16 | pypinyin 17 | jieba -------------------------------------------------------------------------------- /vits/text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 
20 | -------------------------------------------------------------------------------- /vits/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, symbols, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 20 | sequence = [] 21 | 22 | clean_text = _clean_text(text, cleaner_names) 23 | for symbol in clean_text: 24 | if symbol not in _symbol_to_id.keys(): 25 | continue 26 | symbol_id = _symbol_to_id[symbol] 27 | sequence += [symbol_id] 28 | return sequence, clean_text 29 | 30 | 31 | def cleaned_text_to_sequence(cleaned_text): 32 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 33 | Args: 34 | text: string to convert to a sequence 35 | Returns: 36 | List of integers corresponding to the symbols in the text 37 | ''' 38 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()] 39 | return sequence 40 | 41 | 42 | def sequence_to_text(sequence): 43 | '''Converts a sequence of IDs back to a string''' 44 | result = '' 45 | for symbol_id in sequence: 46 | s = _id_to_symbol[symbol_id] 47 | result += s 48 | return result 49 | 50 | 51 | def _clean_text(text, cleaner_names): 52 | for name in cleaner_names: 53 | cleaner = getattr(cleaners, name) 54 | if not cleaner: 55 | raise Exception('Unknown cleaner: %s' % name) 56 | text = cleaner(text) 57 | return text 58 | -------------------------------------------------------------------------------- /vits/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | import re 16 | from unidecode import unidecode 17 | import pyopenjtalk 18 | # from jamo import h2j, j2hcj 19 | from pypinyin import lazy_pinyin, BOPOMOFO 20 | import jieba, cn2an 21 | 22 | 23 | # This is a list of Korean classifiers preceded by pure Korean numerals. 
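# Illustrative usage sketch (the names below are hypothetical examples, not definitions
# from this file): cleaner functions in this module are resolved by name via getattr in
# text._clean_text, so a model's cleaner list maps directly to functions defined here:
#
#     from text import text_to_sequence
#     from text.symbols import symbols
#     sequence, clean_text = text_to_sequence("你好,世界", symbols, ["chinese_cleaners"])
#
# Which cleaner names are actually available depends on the functions defined later in
# this file; they are typically listed under "text_cleaners" in the model's config.json.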
24 | _korean_classifiers = '군데 권 개 그루 닢 대 두 마리 모 모금 뭇 발 발짝 방 번 벌 보루 살 수 술 시 쌈 움큼 정 짝 채 척 첩 축 켤레 톨 통' 25 | 26 | # Regular expression matching whitespace: 27 | _whitespace_re = re.compile(r'\s+') 28 | 29 | # Regular expression matching Japanese without punctuation marks: 30 | _japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 31 | 32 | # Regular expression matching non-Japanese characters or punctuation marks: 33 | _japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 34 | 35 | # List of (regular expression, replacement) pairs for abbreviations: 36 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ 37 | ('mrs', 'misess'), 38 | ('mr', 'mister'), 39 | ('dr', 'doctor'), 40 | ('st', 'saint'), 41 | ('co', 'company'), 42 | ('jr', 'junior'), 43 | ('maj', 'major'), 44 | ('gen', 'general'), 45 | ('drs', 'doctors'), 46 | ('rev', 'reverend'), 47 | ('lt', 'lieutenant'), 48 | ('hon', 'honorable'), 49 | ('sgt', 'sergeant'), 50 | ('capt', 'captain'), 51 | ('esq', 'esquire'), 52 | ('ltd', 'limited'), 53 | ('col', 'colonel'), 54 | ('ft', 'fort'), 55 | ]] 56 | 57 | # List of (hangul, hangul divided) pairs: 58 | _hangul_divided = [(re.compile('%s' % x[0]), x[1]) for x in [ 59 | ('ㄳ', 'ㄱㅅ'), 60 | ('ㄵ', 'ㄴㅈ'), 61 | ('ㄶ', 'ㄴㅎ'), 62 | ('ㄺ', 'ㄹㄱ'), 63 | ('ㄻ', 'ㄹㅁ'), 64 | ('ㄼ', 'ㄹㅂ'), 65 | ('ㄽ', 'ㄹㅅ'), 66 | ('ㄾ', 'ㄹㅌ'), 67 | ('ㄿ', 'ㄹㅍ'), 68 | ('ㅀ', 'ㄹㅎ'), 69 | ('ㅄ', 'ㅂㅅ'), 70 | ('ㅘ', 'ㅗㅏ'), 71 | ('ㅙ', 'ㅗㅐ'), 72 | ('ㅚ', 'ㅗㅣ'), 73 | ('ㅝ', 'ㅜㅓ'), 74 | ('ㅞ', 'ㅜㅔ'), 75 | ('ㅟ', 'ㅜㅣ'), 76 | ('ㅢ', 'ㅡㅣ'), 77 | ('ㅑ', 'ㅣㅏ'), 78 | ('ㅒ', 'ㅣㅐ'), 79 | ('ㅕ', 'ㅣㅓ'), 80 | ('ㅖ', 'ㅣㅔ'), 81 | ('ㅛ', 'ㅣㅗ'), 82 | ('ㅠ', 'ㅣㅜ') 83 | ]] 84 | 85 | # List of (Latin alphabet, hangul) pairs: 86 | _latin_to_hangul = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 87 | ('a', '에이'), 88 | ('b', '비'), 89 | ('c', '시'), 90 | ('d', '디'), 91 | ('e', '이'), 92 | ('f', '에프'), 93 | ('g', '지'), 94 | ('h', '에이치'), 95 | ('i', '아이'), 96 | ('j', '제이'), 97 | ('k', '케이'), 98 | ('l', '엘'), 99 | ('m', '엠'), 100 | ('n', '엔'), 101 | ('o', '오'), 102 | ('p', '피'), 103 | ('q', '큐'), 104 | ('r', '아르'), 105 | ('s', '에스'), 106 | ('t', '티'), 107 | ('u', '유'), 108 | ('v', '브이'), 109 | ('w', '더블유'), 110 | ('x', '엑스'), 111 | ('y', '와이'), 112 | ('z', '제트') 113 | ]] 114 | 115 | # List of (Latin alphabet, bopomofo) pairs: 116 | _latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 117 | ('a', 'ㄟˉ'), 118 | ('b', 'ㄅㄧˋ'), 119 | ('c', 'ㄙㄧˉ'), 120 | ('d', 'ㄉㄧˋ'), 121 | ('e', 'ㄧˋ'), 122 | ('f', 'ㄝˊㄈㄨˋ'), 123 | ('g', 'ㄐㄧˋ'), 124 | ('h', 'ㄝˇㄑㄩˋ'), 125 | ('i', 'ㄞˋ'), 126 | ('j', 'ㄐㄟˋ'), 127 | ('k', 'ㄎㄟˋ'), 128 | ('l', 'ㄝˊㄛˋ'), 129 | ('m', 'ㄝˊㄇㄨˋ'), 130 | ('n', 'ㄣˉ'), 131 | ('o', 'ㄡˉ'), 132 | ('p', 'ㄆㄧˉ'), 133 | ('q', 'ㄎㄧㄡˉ'), 134 | ('r', 'ㄚˋ'), 135 | ('s', 'ㄝˊㄙˋ'), 136 | ('t', 'ㄊㄧˋ'), 137 | ('u', 'ㄧㄡˉ'), 138 | ('v', 'ㄨㄧˉ'), 139 | ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'), 140 | ('x', 'ㄝˉㄎㄨˋㄙˋ'), 141 | ('y', 'ㄨㄞˋ'), 142 | ('z', 'ㄗㄟˋ') 143 | ]] 144 | 145 | 146 | # List of (bopomofo, romaji) pairs: 147 | _bopomofo_to_romaji = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [ 148 | ('ㄅㄛ', 'p⁼wo'), 149 | ('ㄆㄛ', 'pʰwo'), 150 | ('ㄇㄛ', 'mwo'), 151 | ('ㄈㄛ', 'fwo'), 152 | ('ㄅ', 'p⁼'), 153 | ('ㄆ', 'pʰ'), 154 | ('ㄇ', 'm'), 155 | ('ㄈ', 'f'), 156 | ('ㄉ', 't⁼'), 157 | ('ㄊ', 'tʰ'), 158 | ('ㄋ', 'n'), 159 | ('ㄌ', 'l'), 160 | ('ㄍ', 'k⁼'), 161 | ('ㄎ', 'kʰ'), 162 | ('ㄏ', 'h'), 163 | ('ㄐ', 'ʧ⁼'), 164 | 
('ㄑ', 'ʧʰ'), 165 | ('ㄒ', 'ʃ'), 166 | ('ㄓ', 'ʦ`⁼'), 167 | ('ㄔ', 'ʦ`ʰ'), 168 | ('ㄕ', 's`'), 169 | ('ㄖ', 'ɹ`'), 170 | ('ㄗ', 'ʦ⁼'), 171 | ('ㄘ', 'ʦʰ'), 172 | ('ㄙ', 's'), 173 | ('ㄚ', 'a'), 174 | ('ㄛ', 'o'), 175 | ('ㄜ', 'ə'), 176 | ('ㄝ', 'e'), 177 | ('ㄞ', 'ai'), 178 | ('ㄟ', 'ei'), 179 | ('ㄠ', 'au'), 180 | ('ㄡ', 'ou'), 181 | ('ㄧㄢ', 'yeNN'), 182 | ('ㄢ', 'aNN'), 183 | ('ㄧㄣ', 'iNN'), 184 | ('ㄣ', 'əNN'), 185 | ('ㄤ', 'aNg'), 186 | ('ㄧㄥ', 'iNg'), 187 | ('ㄨㄥ', 'uNg'), 188 | ('ㄩㄥ', 'yuNg'), 189 | ('ㄥ', 'əNg'), 190 | ('ㄦ', 'əɻ'), 191 | ('ㄧ', 'i'), 192 | ('ㄨ', 'u'), 193 | ('ㄩ', 'ɥ'), 194 | ('ˉ', '→'), 195 | ('ˊ', '↑'), 196 | ('ˇ', '↓↑'), 197 | ('ˋ', '↓'), 198 | ('˙', ''), 199 | (',', ','), 200 | ('。', '.'), 201 | ('!', '!'), 202 | ('?', '?'), 203 | ('—', '-') 204 | ]] 205 | 206 | 207 | def expand_abbreviations(text): 208 | for regex, replacement in _abbreviations: 209 | text = re.sub(regex, replacement, text) 210 | return text 211 | 212 | 213 | def lowercase(text): 214 | return text.lower() 215 | 216 | 217 | def collapse_whitespace(text): 218 | return re.sub(_whitespace_re, ' ', text) 219 | 220 | 221 | def convert_to_ascii(text): 222 | return unidecode(text) 223 | 224 | 225 | def japanese_to_romaji_with_accent(text): 226 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 227 | sentences = re.split(_japanese_marks, text) 228 | marks = re.findall(_japanese_marks, text) 229 | text = '' 230 | for i, sentence in enumerate(sentences): 231 | if re.match(_japanese_characters, sentence): 232 | if text!='': 233 | text+=' ' 234 | labels = pyopenjtalk.extract_fullcontext(sentence) 235 | for n, label in enumerate(labels): 236 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 237 | if phoneme not in ['sil','pau']: 238 | text += phoneme.replace('ch','ʧ').replace('sh','ʃ').replace('cl','Q') 239 | else: 240 | continue 241 | n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 242 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1)) 243 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 244 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 245 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil','pau']: 246 | a2_next=-1 247 | else: 248 | a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 249 | # Accent phrase boundary 250 | if a3 == 1 and a2_next == 1: 251 | text += ' ' 252 | # Falling 253 | elif a1 == 0 and a2_next == a2 + 1 and a2 != n_moras: 254 | text += '↓' 255 | # Rising 256 | elif a2 == 1 and a2_next == 2: 257 | text += '↑' 258 | if i= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are 
not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * 
c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /vits/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import logging 5 | import json 6 | import subprocess 7 | import numpy as np 8 | import librosa 9 | import torch 10 | 11 | MATPLOTLIB_FLAG = False 12 | 13 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 14 | logger = logging 15 | 16 | 17 | def load_checkpoint(checkpoint_path, model, optimizer=None): 18 | assert os.path.isfile(checkpoint_path) 19 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 20 | iteration = checkpoint_dict['iteration'] 21 | learning_rate = checkpoint_dict['learning_rate'] 22 | if optimizer is not None: 23 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 24 | saved_state_dict = checkpoint_dict['model'] 25 | if hasattr(model, 'module'): 26 | state_dict = model.module.state_dict() 27 | else: 28 | state_dict = model.state_dict() 29 | new_state_dict= {} 30 | for k, v in state_dict.items(): 31 | try: 32 | new_state_dict[k] = saved_state_dict[k] 33 | except: 34 | logger.info("%s is not in the checkpoint" % k) 35 | new_state_dict[k] = v 36 | if hasattr(model, 'module'): 37 | model.module.load_state_dict(new_state_dict) 38 | else: 39 | model.load_state_dict(new_state_dict) 40 | logger.info("Loaded checkpoint '{}' (iteration {})" .format( 41 | checkpoint_path, iteration)) 42 | return model, optimizer, learning_rate, iteration 43 | 44 | 45 | def plot_spectrogram_to_numpy(spectrogram): 46 | global MATPLOTLIB_FLAG 47 | if not MATPLOTLIB_FLAG: 48 | import matplotlib 49 | matplotlib.use("Agg") 50 | MATPLOTLIB_FLAG = True 51 | mpl_logger = logging.getLogger('matplotlib') 52 | mpl_logger.setLevel(logging.WARNING) 53 | import matplotlib.pylab as plt 54 | import numpy as np 55 | 56 | fig, ax = plt.subplots(figsize=(10,2)) 57 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 58 | interpolation='none') 59 | plt.colorbar(im, ax=ax) 60 | plt.xlabel("Frames") 61 | plt.ylabel("Channels") 62 | plt.tight_layout() 63 | 64 | fig.canvas.draw() 65 | data = np.fromstring(fig.canvas.tostring_rgb(), 
dtype=np.uint8, sep='') 66 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 67 | plt.close() 68 | return data 69 | 70 | 71 | def plot_alignment_to_numpy(alignment, info=None): 72 | global MATPLOTLIB_FLAG 73 | if not MATPLOTLIB_FLAG: 74 | import matplotlib 75 | matplotlib.use("Agg") 76 | MATPLOTLIB_FLAG = True 77 | mpl_logger = logging.getLogger('matplotlib') 78 | mpl_logger.setLevel(logging.WARNING) 79 | import matplotlib.pylab as plt 80 | import numpy as np 81 | 82 | fig, ax = plt.subplots(figsize=(6, 4)) 83 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 84 | interpolation='none') 85 | fig.colorbar(im, ax=ax) 86 | xlabel = 'Decoder timestep' 87 | if info is not None: 88 | xlabel += '\n\n' + info 89 | plt.xlabel(xlabel) 90 | plt.ylabel('Encoder timestep') 91 | plt.tight_layout() 92 | 93 | fig.canvas.draw() 94 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 95 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 96 | plt.close() 97 | return data 98 | 99 | 100 | def load_audio_to_torch(full_path, target_sampling_rate): 101 | audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True) 102 | return torch.FloatTensor(audio.astype(np.float32)) 103 | 104 | 105 | def load_filepaths_and_text(filename, split="|"): 106 | with open(filename, encoding='utf-8') as f: 107 | filepaths_and_text = [line.strip().split(split) for line in f] 108 | return filepaths_and_text 109 | 110 | 111 | def get_hparams(init=True): 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument('-c', '--config', type=str, default="./configs/base.json", 114 | help='JSON file for configuration') 115 | parser.add_argument('-m', '--model', type=str, required=True, 116 | help='Model name') 117 | 118 | args = parser.parse_args() 119 | model_dir = os.path.join("./logs", args.model) 120 | 121 | if not os.path.exists(model_dir): 122 | os.makedirs(model_dir) 123 | 124 | config_path = args.config 125 | config_save_path = os.path.join(model_dir, "config.json") 126 | if init: 127 | with open(config_path, "r") as f: 128 | data = f.read() 129 | with open(config_save_path, "w") as f: 130 | f.write(data) 131 | else: 132 | with open(config_save_path, "r") as f: 133 | data = f.read() 134 | config = json.loads(data) 135 | 136 | hparams = HParams(**config) 137 | hparams.model_dir = model_dir 138 | return hparams 139 | 140 | 141 | def get_hparams_from_dir(model_dir): 142 | config_save_path = os.path.join(model_dir, "config.json") 143 | with open(config_save_path, "r") as f: 144 | data = f.read() 145 | config = json.loads(data) 146 | 147 | hparams =HParams(**config) 148 | hparams.model_dir = model_dir 149 | return hparams 150 | 151 | 152 | def get_hparams_from_file(config_path): 153 | with open(config_path, "r") as f: 154 | data = f.read() 155 | config = json.loads(data) 156 | 157 | hparams =HParams(**config) 158 | return hparams 159 | 160 | 161 | def check_git_hash(model_dir): 162 | source_dir = os.path.dirname(os.path.realpath(__file__)) 163 | if not os.path.exists(os.path.join(source_dir, ".git")): 164 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 165 | source_dir 166 | )) 167 | return 168 | 169 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 170 | 171 | path = os.path.join(model_dir, "githash") 172 | if os.path.exists(path): 173 | saved_hash = open(path).read() 174 | if saved_hash != cur_hash: 175 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 176 | saved_hash[:8], cur_hash[:8])) 177 | else: 178 | open(path, "w").write(cur_hash) 179 | 180 | 181 | def get_logger(model_dir, filename="train.log"): 182 | global logger 183 | logger = logging.getLogger(os.path.basename(model_dir)) 184 | logger.setLevel(logging.DEBUG) 185 | 186 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 187 | if not os.path.exists(model_dir): 188 | os.makedirs(model_dir) 189 | h = logging.FileHandler(os.path.join(model_dir, filename)) 190 | h.setLevel(logging.DEBUG) 191 | h.setFormatter(formatter) 192 | logger.addHandler(h) 193 | return logger 194 | 195 | 196 | class HParams(): 197 | def __init__(self, **kwargs): 198 | for k, v in kwargs.items(): 199 | if type(v) == dict: 200 | v = HParams(**v) 201 | self[k] = v 202 | 203 | def keys(self): 204 | return self.__dict__.keys() 205 | 206 | def items(self): 207 | return self.__dict__.items() 208 | 209 | def values(self): 210 | return self.__dict__.values() 211 | 212 | def __len__(self): 213 | return len(self.__dict__) 214 | 215 | def __getitem__(self, key): 216 | return getattr(self, key) 217 | 218 | def __setitem__(self, key, value): 219 | return setattr(self, key, value) 220 | 221 | def __contains__(self, key): 222 | return key in self.__dict__ 223 | 224 | def __repr__(self): 225 | return self.__dict__.__repr__() 226 | -------------------------------------------------------------------------------- /vts_api_test/vts_api_mp_test.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import queue 3 | 4 | import multiprocessing 5 | 6 | import pyvts 7 | 8 | class VTSAPIProcess(multiprocessing.Process): 9 | def __init__( 10 | self, 11 | vts_api_queue): 12 | super().__init__() 13 | self.vts_api_queue = vts_api_queue 14 | 15 | async def main(self): 16 | proc_name = self.name 17 | print(f"Initializing {proc_name}...") 18 | 19 | plugin_name = "Expression Controller" 20 | developer = "Rotten Work" 21 | authentication_token_path = "./token.txt" 22 | 23 | plugin_info = { 24 | "plugin_name": plugin_name, 25 | "developer": developer, 26 | "authentication_token_path": authentication_token_path 27 | } 28 | 29 | myvts = pyvts.vts(plugin_info=plugin_info) 30 | 31 | try: 32 | await myvts.connect() 33 | except Exception as e: 34 | print(e) 35 | return 36 | 37 | try: 38 | await myvts.read_token() 39 | print("Token file found.") 40 | except FileNotFoundError: 41 | print("No token file found! Do authentication!") 42 | await myvts.request_authenticate_token() 43 | await myvts.write_token() 44 | 45 | success = await myvts.request_authenticate() 46 | 47 | if not success: 48 | print("Token file is invalid! 
request authentication token again!") 49 | await myvts.request_authenticate_token() 50 | await myvts.write_token() 51 | 52 | success = await myvts.request_authenticate() 53 | assert success 54 | 55 | while True: 56 | try: 57 | # vts_api_task = self.vts_api_queue.get_nowait() 58 | vts_api_task = self.vts_api_queue.get(block=True, timeout=10) 59 | if vts_api_task is None: 60 | # Poison pill means shutdown 61 | print(f"{proc_name}: Exiting") 62 | break 63 | 64 | msg_type = vts_api_task.msg_type 65 | data = vts_api_task.data 66 | request_id = vts_api_task.request_id 67 | 68 | except queue.Empty: 69 | # Heartbeat 70 | # await myvts.websocket.send("Ping") 71 | msg_type = "HotkeyTriggerRequest" 72 | data = { 73 | "hotkeyID": "Clear" 74 | } 75 | 76 | request_id = None 77 | 78 | if msg_type == "ExpressionActivationRequest": 79 | pass 80 | elif msg_type == "HotkeyTriggerRequest": 81 | pass 82 | else: 83 | print(f"There is no such messageType: {msg_type}!") 84 | continue 85 | 86 | if request_id is None: 87 | request_msg = myvts.vts_request.BaseRequest( 88 | msg_type, 89 | data, 90 | f"{msg_type}ID" 91 | ) 92 | else: 93 | request_msg = myvts.vts_request.BaseRequest( 94 | msg_type, 95 | data, 96 | request_id 97 | ) 98 | 99 | try: 100 | response = await myvts.request(request_msg) 101 | print(response) 102 | 103 | if msg_type == "ExpressionActivationRequest": 104 | # https://datagy.io/python-check-if-dictionary-empty/ 105 | # The expression_response[‘data’] dict should be empty if the request is successful. 106 | assert not bool(response['data']), "ExpressionActivationRequest Error!" 107 | elif msg_type == "HotkeyTriggerRequest": 108 | # https://stackoverflow.com/questions/17372957/why-is-assertionerror-not-displayed 109 | assert "errorID" not in response['data'], "HotkeyTriggerRequest Error!" 
110 | except AssertionError as e: 111 | print(e) 112 | except Exception as e: 113 | print(e) 114 | try: 115 | # https://support.quicknode.com/hc/en-us/articles/9422611596305-Handling-Websocket-Drops-and-Disconnections 116 | print("Reconnect") 117 | await myvts.connect() 118 | await myvts.request_authenticate() 119 | except Exception as e: 120 | print(e) 121 | return 122 | 123 | try: 124 | await myvts.close() 125 | except Exception as e: 126 | print(e) 127 | 128 | def run(self): 129 | asyncio.run(self.main()) 130 | 131 | 132 | class VTSAPITask: 133 | def __init__(self, msg_type, data, request_id=None): 134 | self.msg_type = msg_type 135 | self.data = data 136 | self.request_id = request_id 137 | 138 | if __name__ == "__main__": 139 | vts_api_queue = multiprocessing.Queue(maxsize=4) 140 | 141 | # event_vts_api_process_initialized = multiprocessing.Event() 142 | 143 | vts_api_process = VTSAPIProcess(vts_api_queue) 144 | vts_api_process.start() 145 | 146 | while True: 147 | user_input = input(("Press 1 to set test, " 148 | "2 to unset test, " 149 | "3 to set Happy, " 150 | "4 to unset Happy, " 151 | "5 to clear, " 152 | "6 wrong hotkey name to test, " 153 | "0 to quit:\n")) 154 | if user_input == '1': 155 | expression = "test" 156 | active = True 157 | elif user_input == '2': 158 | expression = "test" 159 | active = False 160 | elif user_input == '3': 161 | expression = "Happy" 162 | active = True 163 | elif user_input == '4': 164 | expression = "Happy" 165 | active = False 166 | elif user_input == '5': 167 | msg_type = "HotkeyTriggerRequest" 168 | data_dict = { 169 | "hotkeyID": "Clear" 170 | } 171 | vts_api_task = VTSAPITask(msg_type, data_dict) 172 | vts_api_queue.put(vts_api_task) 173 | continue 174 | elif user_input == '6': 175 | msg_type = "HotkeyTriggerRequest" 176 | data_dict = { 177 | "hotkeyID": "WrongHotkeyName" 178 | } 179 | vts_api_task = VTSAPITask(msg_type, data_dict) 180 | vts_api_queue.put(vts_api_task) 181 | continue 182 | elif user_input == '7': 183 | msg_type = "HotkeyTriggerRequest" 184 | data_dict = { 185 | "hotkeyID": "MoveEars" 186 | } 187 | vts_api_task = VTSAPITask(msg_type, data_dict) 188 | vts_api_queue.put(vts_api_task) 189 | continue 190 | elif user_input == '0': 191 | vts_api_queue.put(None) 192 | break 193 | else: 194 | continue 195 | 196 | msg_type = "ExpressionActivationRequest" 197 | expression_file = f"{expression}.exp3.json" 198 | expression_request_data = { 199 | "expressionFile": expression_file, 200 | "active": active 201 | } 202 | 203 | vts_api_task = VTSAPITask(msg_type, expression_request_data) 204 | 205 | vts_api_queue.put(vts_api_task) 206 | 207 | vts_api_process.join() -------------------------------------------------------------------------------- /vts_api_test/vts_api_test.py: -------------------------------------------------------------------------------- 1 | import pyvts 2 | import asyncio 3 | 4 | async def main(): 5 | plugin_name = "Expression Controller" 6 | developer = "Rotten Work" 7 | authentication_token_path = "./token.txt" 8 | 9 | plugin_info = { 10 | "plugin_name": plugin_name, 11 | "developer": developer, 12 | "authentication_token_path": authentication_token_path 13 | } 14 | 15 | myvts = pyvts.vts(plugin_info=plugin_info) 16 | try: 17 | await myvts.connect() 18 | except: 19 | print("Connect failed") 20 | 21 | try: 22 | await myvts.read_token() 23 | print("Token file found.") 24 | except FileNotFoundError: 25 | print("No token file found! 
Do authentication!") 26 | await myvts.request_authenticate_token() 27 | await myvts.write_token() 28 | 29 | await myvts.request_authenticate() 30 | 31 | expression_file = "test.exp3.json" 32 | while True: 33 | user_input = input("Press 1 to activate, 2 to deactivate, 0 to quit:\n") 34 | if user_input == '1': 35 | active = True 36 | elif user_input == '2': 37 | active = False 38 | elif user_input == '0': 39 | break 40 | else: 41 | continue 42 | 43 | expression_request_data = { 44 | "expressionFile": expression_file, 45 | "active": active 46 | } 47 | 48 | expression_request_msg = myvts.vts_request.BaseRequest( 49 | "ExpressionActivationRequest", 50 | expression_request_data, 51 | "ExpressionActivationRequestID" 52 | ) 53 | 54 | expression_response = await myvts.request(expression_request_msg) 55 | 56 | # https://datagy.io/python-check-if-dictionary-empty/ 57 | # The expression_response[‘data’] dict should be empty if the request is successful. 58 | assert not bool(expression_response['data']) 59 | 60 | await myvts.close() 61 | 62 | async def connect_auth(myvts): 63 | 64 | await myvts.connect() 65 | 66 | try: 67 | await myvts.read_token() 68 | print("Token file found.") 69 | except FileNotFoundError: 70 | print("No token file found! Do authentication!") 71 | await myvts.request_authenticate_token() 72 | await myvts.write_token() 73 | 74 | await myvts.request_authenticate() 75 | 76 | async def activate(myvts): 77 | expression_file = "test.exp3.json" 78 | active = True 79 | expression_request_data = { 80 | "expressionFile": expression_file, 81 | "active": active 82 | } 83 | 84 | expression_request_msg = myvts.vts_request.BaseRequest( 85 | "ExpressionActivationRequest", 86 | expression_request_data, 87 | "ExpressionActivationRequestID" 88 | ) 89 | 90 | expression_response = await myvts.request(expression_request_msg) 91 | 92 | 93 | # The expression_response[‘data’] dict should be empty if the request is successful. 94 | assert not bool(expression_response['data']) 95 | 96 | async def deactivate(myvts): 97 | expression_file = "test.exp3.json" 98 | active = False 99 | expression_request_data = { 100 | "expressionFile": expression_file, 101 | "active": active 102 | } 103 | 104 | expression_request_msg = myvts.vts_request.BaseRequest( 105 | "ExpressionActivationRequest", 106 | expression_request_data, 107 | "ExpressionActivationRequestID" 108 | ) 109 | 110 | expression_response = await myvts.request(expression_request_msg) 111 | 112 | # The expression_response[‘data’] dict should be empty if the request is successful. 
113 | assert not bool(expression_response['data']) 114 | 115 | async def close(myvts): 116 | await myvts.close() 117 | 118 | 119 | if __name__ == "__main__": 120 | asyncio.run(main()) 121 | 122 | plugin_name = "expression controller" 123 | developer = "Rotten Work" 124 | authentication_token_path = "./token.txt" 125 | 126 | plugin_info = { 127 | "plugin_name": plugin_name, 128 | "developer": developer, 129 | "authentication_token_path": authentication_token_path 130 | } 131 | 132 | myvts = pyvts.vts(plugin_info=plugin_info) 133 | 134 | # Doesn't work, because loop is automatically close after every run 135 | # asyncio.run(connect_auth(myvts)) 136 | # asyncio.run(activate(myvts)) 137 | # # asyncio.run(deactive(myvts)) 138 | # asyncio.run(close(myvts)) 139 | 140 | # create and access a new asyncio event loop 141 | loop = asyncio.new_event_loop() 142 | 143 | # task = loop.create_task(connect_auth(myvts)) 144 | # https://tutorialedge.net/python/concurrency/asyncio-event-loops-tutorial/ 145 | task = loop.run_until_complete(connect_auth(myvts)) 146 | task = loop.run_until_complete(activate(myvts)) 147 | task = loop.run_until_complete(deactivate(myvts)) 148 | task = loop.run_until_complete(close(myvts)) 149 | 150 | loop.close() 151 | 152 | -------------------------------------------------------------------------------- /vts_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import queue 3 | 4 | import re 5 | import multiprocessing 6 | 7 | import pyvts 8 | 9 | import logging 10 | 11 | class ExpressionHelper: 12 | 13 | #https://stackoverflow.com/questions/6388187/what-is-the-proper-way-to-format-a-multi-line-dict-in-python 14 | emotion_to_expression = { 15 | "非常开心": "eyesHappy", 16 | "愉悦": "eyesLaugh", 17 | "伤心": "eyesUpset", 18 | "生气": "eyesAngry", 19 | # "平静": "neutral" 20 | } 21 | 22 | # emotion_to_expression = {} 23 | 24 | def get_emotion_and_line(response): 25 | pattern = r'^\[(.*?)\]' 26 | match = re.search(pattern, response) 27 | 28 | if match: 29 | emotion = match.group(1) 30 | emotion_with_brackets = match.group(0) 31 | 32 | return emotion, response[len(emotion_with_brackets):] 33 | else: 34 | return None, response 35 | 36 | def emotion_to_expression_file(emotion): 37 | if emotion in ExpressionHelper.emotion_to_expression: 38 | expression = ExpressionHelper.emotion_to_expression[emotion] 39 | return f"{expression}.exp3.json" 40 | else: 41 | return None 42 | 43 | def create_expression_data_dict(emotion): 44 | file_name = ExpressionHelper.emotion_to_expression_file(emotion) 45 | data_dict = None 46 | if file_name is not None: 47 | data_dict = ExpressionHelper.create_expression_data_dict_from_file_name(file_name) 48 | 49 | return data_dict 50 | 51 | def create_expression_data_dict_from_file_name(file_name): 52 | data_dict = { 53 | "expressionFile": file_name, 54 | "active": True 55 | } 56 | return data_dict 57 | 58 | def create_hotkey_data_dict(hotkey_id): 59 | data_dict = { 60 | "hotkeyID": hotkey_id 61 | } 62 | return data_dict 63 | 64 | class VTSAPITask: 65 | def __init__(self, msg_type, data, request_id=None): 66 | self.msg_type = msg_type 67 | self.data = data 68 | self.request_id = request_id 69 | 70 | class VTSAPIProcess(multiprocessing.Process): 71 | def __init__( 72 | self, 73 | vts_api_queue): 74 | super().__init__() 75 | self.vts_api_queue = vts_api_queue 76 | 77 | async def main(self): 78 | proc_name = self.name 79 | print(f"Initializing {proc_name}...") 80 | 81 | 
logging.getLogger("websockets").setLevel(logging.WARNING) 82 | 83 | plugin_name = "Expression Controller" 84 | developer = "Rotten Work" 85 | authentication_token_path = "./token.txt" 86 | 87 | plugin_info = { 88 | "plugin_name": plugin_name, 89 | "developer": developer, 90 | "authentication_token_path": authentication_token_path 91 | } 92 | 93 | myvts = pyvts.vts(plugin_info=plugin_info) 94 | 95 | try: 96 | await myvts.connect() 97 | except Exception as e: 98 | print(e) 99 | return 100 | 101 | try: 102 | await myvts.read_token() 103 | print("Token file found.") 104 | except FileNotFoundError: 105 | print("No token file found! Do authentication!") 106 | await myvts.request_authenticate_token() 107 | await myvts.write_token() 108 | 109 | success = await myvts.request_authenticate() 110 | 111 | if not success: 112 | print("Token file is invalid! request authentication token again!") 113 | await myvts.request_authenticate_token() 114 | await myvts.write_token() 115 | 116 | success = await myvts.request_authenticate() 117 | assert success 118 | 119 | while True: 120 | try: 121 | # vts_api_task = self.vts_api_queue.get_nowait() 122 | vts_api_task = self.vts_api_queue.get(block=True, timeout=5) 123 | if vts_api_task is None: 124 | # Poison pill means shutdown 125 | print(f"{proc_name}: Exiting") 126 | break 127 | 128 | msg_type = vts_api_task.msg_type 129 | data = vts_api_task.data 130 | request_id = vts_api_task.request_id 131 | 132 | except queue.Empty: 133 | # Heartbeat 134 | # await myvts.websocket.send("Ping") 135 | msg_type = "HotkeyTriggerRequest" 136 | data = { 137 | "hotkeyID": "Clear" 138 | } 139 | 140 | request_id = None 141 | 142 | if msg_type == "ExpressionActivationRequest": 143 | pass 144 | elif msg_type == "HotkeyTriggerRequest": 145 | pass 146 | else: 147 | print(f"There is no such messageType: {msg_type}!") 148 | continue 149 | 150 | if request_id is None: 151 | request_msg = myvts.vts_request.BaseRequest( 152 | msg_type, 153 | data, 154 | f"{msg_type}ID" 155 | ) 156 | else: 157 | request_msg = myvts.vts_request.BaseRequest( 158 | msg_type, 159 | data, 160 | request_id 161 | ) 162 | 163 | try: 164 | response = await myvts.request(request_msg) 165 | print(response) 166 | 167 | if msg_type == "ExpressionActivationRequest": 168 | # https://datagy.io/python-check-if-dictionary-empty/ 169 | # The expression_response[‘data’] dict should be empty if the request is successful. 170 | assert not bool(response['data']), "ExpressionActivationRequest Error!" 171 | elif msg_type == "HotkeyTriggerRequest": 172 | # https://stackoverflow.com/questions/17372957/why-is-assertionerror-not-displayed 173 | assert "errorID" not in response['data'], "HotkeyTriggerRequest Error!" 174 | except AssertionError as e: 175 | print(e) 176 | except Exception as e: 177 | print(e) 178 | try: 179 | # https://support.quicknode.com/hc/en-us/articles/9422611596305-Handling-Websocket-Drops-and-Disconnections 180 | print("Reconnect") 181 | await myvts.connect() 182 | await myvts.request_authenticate() 183 | except Exception as e: 184 | print(e) 185 | return 186 | 187 | try: 188 | await myvts.close() 189 | except Exception as e: 190 | print(e) 191 | 192 | def run(self): 193 | asyncio.run(self.main()) 194 | print(f"{self.name}: Exits") 195 | --------------------------------------------------------------------------------