├── ckpt └── .gitkeep ├── data ├── .gitkeep └── binary │ └── .gitkeep ├── segments └── .gitkeep ├── dictionary ├── .gitkeep └── opencpop-extension.txt ├── modules ├── layer │ ├── __init__.py │ ├── scaling │ │ ├── __init__.py │ │ ├── base.py │ │ └── stride_conv.py │ ├── backbone │ │ ├── __init__.py │ │ └── unet.py │ ├── activation │ │ ├── __init__.py │ │ └── GLU.py │ └── block │ │ ├── __init__.py │ │ ├── func_module.py │ │ ├── convolution.py │ │ ├── residual.py │ │ ├── resnet_block.py │ │ ├── conformer.py │ │ └── attention.py ├── utils │ ├── __init__.py │ ├── get_melspec.py │ ├── load_wav.py │ ├── plot.py │ ├── post_processing.py │ ├── label.py │ ├── export_tool.py │ ├── metrics.py │ └── metrics_test.py ├── loss │ ├── __init__.py │ ├── BinaryEMDLoss.py │ └── GHMLoss.py ├── __init__.py ├── g2p │ ├── __init__.py │ ├── phoneme_g2p.py │ ├── none_g2p.py │ ├── dictionary_g2p.py │ ├── base_g2p.py │ ├── readme_g2p_zh.md │ └── readme_g2p.md ├── scheduler │ ├── __init__.py │ ├── none_scheduler.py │ └── gaussian_ramp_up_scheduler.py ├── rmvpe │ ├── constants.py │ ├── __init__.py │ ├── seq.py │ ├── model.py │ ├── utils.py │ ├── spec.py │ ├── inference.py │ └── deepunet.py └── AP_detector │ ├── __init__.py │ ├── none_detector.py │ ├── base_detector.py │ └── loudnesss_pectralcentroid_detector.py ├── example.png ├── .gitignore ├── requirements.txt ├── configs ├── binarize_config.yaml └── train_config.yaml ├── LICENSE ├── infer.py ├── evaluate.py ├── train.py ├── README_zh.MD ├── export_onnx.py ├── README.MD ├── dataset.py ├── onnx_infer.py └── binarize.py /ckpt/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /segments/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/binary/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dictionary/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/layer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/layer/scaling/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/layer/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/layer/activation/__init__.py: -------------------------------------------------------------------------------- 1 | from .GLU import * 2 | -------------------------------------------------------------------------------- /example.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiuqiao/SOFA/HEAD/example.png -------------------------------------------------------------------------------- /modules/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .BinaryEMDLoss import BinaryEMDLoss 2 | from .GHMLoss import GHMLoss 3 | -------------------------------------------------------------------------------- /modules/layer/block/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .func_module import * 3 | from .residual import * 4 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | import modules.layer.activation 2 | import modules.layer.block 3 | import modules.loss 4 | import modules.rmvpe 5 | import modules.task 6 | -------------------------------------------------------------------------------- /modules/g2p/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.g2p.dictionary_g2p import DictionaryG2P 2 | from modules.g2p.none_g2p import NoneG2P 3 | from modules.g2p.phoneme_g2p import PhonemeG2P 4 | -------------------------------------------------------------------------------- /modules/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.scheduler.gaussian_ramp_up_scheduler import GaussianRampUpScheduler 2 | from modules.scheduler.none_scheduler import NoneScheduler 3 | -------------------------------------------------------------------------------- /modules/rmvpe/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 128 6 | MEL_FMIN = 30 7 | MEL_FMAX = 8000 8 | WINDOW_LENGTH = 1024 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /modules/rmvpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .model import E2E0 3 | from .utils import to_local_average_f0, to_viterbi_f0 4 | from .inference import RMVPE 5 | from .spec import MelSpectrogram 6 | -------------------------------------------------------------------------------- /modules/AP_detector/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.AP_detector.loudnesss_pectralcentroid_detector import ( 2 | LoudnessSpectralcentroidAPDetector, 3 | ) 4 | from modules.AP_detector.none_detector import NoneAPDetector 5 | -------------------------------------------------------------------------------- /modules/layer/block/func_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FuncModule(nn.Module): 5 | def __init__(self, func): 6 | super(FuncModule, self).__init__() 7 | self.func = func 8 | 9 | def forward(self, x): 10 | return self.func(x) 11 | -------------------------------------------------------------------------------- /modules/scheduler/none_scheduler.py: -------------------------------------------------------------------------------- 1 | class NoneScheduler: 2 | def __init__(self): 3 | pass 4 | 5 | def __call__(self): 6 | 
return 1 7 | 8 | def step(self): 9 | pass 10 | 11 | def resume(self, global_step): 12 | pass 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .*/ 2 | __*__ 3 | data/binary/* 4 | data/no_label/* 5 | data/weak_label/* 6 | data/full_label/* 7 | ckpt/* 8 | segments/* 9 | .vscode/* 10 | lightning_logs/* 11 | 12 | *:Zone.Identifier 13 | *.ipynb 14 | *.zip 15 | *.data 16 | *.idx 17 | *.wav 18 | *.csv 19 | *.TextGrid 20 | *.h5 21 | *.h5py 22 | **/tmp** 23 | **/test** 24 | 25 | !.gitkeep 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # It is recommended to install torch and torchaudio manually. 2 | 3 | click 4 | einops==0.6.1 5 | h5py 6 | librosa<0.10.0 7 | lightning>=2.0.0 8 | matplotlib~=3.7.3 9 | numpy~=1.24.1 10 | PyYAML~=6.0.1 11 | tensorboard 12 | tensorboardX 13 | tqdm~=4.66.1 14 | textgrid 15 | chardet 16 | numba 17 | 18 | # torch 19 | # torchaudio 20 | pandas~=2.0.3 21 | -------------------------------------------------------------------------------- /modules/rmvpe/seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiGRU(nn.Module): 5 | def __init__(self, input_features, hidden_features, num_layers): 6 | super(BiGRU, self).__init__() 7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | return self.gru(x)[0] 11 | -------------------------------------------------------------------------------- /modules/loss/BinaryEMDLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BinaryEMDLoss(torch.nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | self.loss = torch.nn.L1Loss() 8 | 9 | def forward(self, pred, target): 10 | # pred, target: [B,T] 11 | loss = self.loss(pred.cumsum(dim=-1), target.cumsum(dim=-1)) 12 | loss += self.loss( 13 | pred.flip([-1]).cumsum(dim=-1), target.flip([-1]).cumsum(dim=-1) 14 | ) 15 | return loss / 2 16 | -------------------------------------------------------------------------------- /modules/layer/scaling/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseDowmSampling(nn.Module): 5 | def __init__(self, input_dims, output_dims, down_sampling_factor=2): 6 | super(BaseDowmSampling, self).__init__() 7 | 8 | def forward(self, x): 9 | raise NotImplementedError 10 | 11 | 12 | class BaseUpSampling(nn.Module): 13 | def __init__(self, input_dims, output_dims, up_sampling_factor=2): 14 | super(BaseUpSampling, self).__init__() 15 | 16 | def forward(self, x): 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /configs/binarize_config.yaml: -------------------------------------------------------------------------------- 1 | melspec_config: 2 | n_mels: 128 3 | sample_rate: 44100 4 | win_length: 1024 5 | hop_length: 512 6 | n_fft: 2048 7 | fmin: 40 8 | fmax: 16000 9 | clamp: 0.00001 10 | scale_factor: 4 11 | 12 | data_folder: data/ 13 | valid_set_size: 15 14 | valid_set_preferred_folders: 15 | - test 16 | ignored_phonemes: 17 | - AP 18 | - SP 19 | - 20 | - 21 | - '' 22 | - pau 23 | - cl 24 | 
data_augmentation: 25 | size: 2 # If the data contains unlabeled data, it must be equal to or greater than 1. 26 | key_shift_choices: [ 1, 2, 3, 4, 5, 6,-1,-2,-3,-4,-5,-6 ] 27 | 28 | max_length: 45 # unit: second 29 | -------------------------------------------------------------------------------- /modules/layer/activation/GLU.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | import torch.nn as nn 4 | 5 | 6 | class GLU(nn.Module): 7 | def __init__(self, input_dim, output_dim): 8 | super(GLU, self).__init__() 9 | self.linear = nn.Linear(input_dim, output_dim) 10 | self.projection = ( 11 | nn.Conv1d(input_dim, output_dim, 1) 12 | if input_dim != output_dim 13 | else nn.Identity() 14 | ) 15 | 16 | def forward(self, x): 17 | # input: Tensor[batch_size seq_len, hidden_dims] 18 | # output: Tensor[batch_size seq_len, hidden_dims] 19 | gate = torch.sigmoid(self.linear(x)) 20 | output = self.projection(x) * gate 21 | return output 22 | -------------------------------------------------------------------------------- /modules/layer/block/convolution.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from attention import ConformerBlock 3 | 4 | 5 | class SeparableConv1d(nn.Module): 6 | def __init__( 7 | self, 8 | in_channels, 9 | out_channels, 10 | kernel_size, 11 | stride=1, 12 | padding=0, 13 | dilation=1, 14 | bias=True, 15 | ): 16 | super(SeparableConv1d, self).__init__() 17 | self.conv1 = nn.Conv1d( 18 | in_channels, 19 | in_channels, 20 | kernel_size, 21 | stride, 22 | padding, 23 | dilation, 24 | groups=in_channels, 25 | bias=bias, 26 | ) 27 | self.point_wise = nn.Conv1d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.point_wise(x) 32 | return x 33 | -------------------------------------------------------------------------------- /modules/g2p/phoneme_g2p.py: -------------------------------------------------------------------------------- 1 | from modules.g2p.base_g2p import BaseG2P 2 | 3 | 4 | class PhonemeG2P(BaseG2P): 5 | def __init__(self, **kwargs): 6 | pass 7 | 8 | def _g2p(self, input_text): 9 | word_seq = input_text.strip().split(" ") 10 | word_seq = [ph for ph in word_seq if ph != "SP"] 11 | ph_seq = ["SP"] 12 | ph_idx_to_word_idx = [-1] 13 | for word_idx, word in enumerate(word_seq): 14 | ph_seq.append(word) 15 | ph_idx_to_word_idx.append(word_idx) 16 | ph_seq.append("SP") 17 | ph_idx_to_word_idx.append(-1) 18 | return ph_seq, word_seq, ph_idx_to_word_idx 19 | 20 | 21 | if __name__ == "__main__": 22 | pass 23 | grapheme_to_phoneme = PhonemeG2P() 24 | text = "wo shi yi ge xue sheng SP SP SP" 25 | ph_seq, word_seq, ph_idx_to_word_idx = grapheme_to_phoneme(text) 26 | print(ph_seq) 27 | print(word_seq) 28 | print(ph_idx_to_word_idx) 29 | -------------------------------------------------------------------------------- /modules/g2p/none_g2p.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from modules.g2p.base_g2p import BaseG2P 4 | 5 | 6 | class NoneG2P(BaseG2P): 7 | def __init__(self, **kwargs): 8 | pass 9 | 10 | def _g2p(self, input_text): 11 | input_seq = input_text.strip().split(" ") 12 | 13 | ph_seq = ["SP"] 14 | for i, ph in enumerate(input_seq): 15 | if ph == "SP" and ph_seq[-1] == "SP": 16 | continue 17 | ph_seq.append(ph) 18 | if ph_seq[-1] != "SP": 19 | ph_seq.append("SP") 20 | 21 | word_seq 
= ph_seq 22 | ph_idx_to_word_idx = np.arange(len(ph_seq)) 23 | 24 | return ph_seq, word_seq, ph_idx_to_word_idx 25 | 26 | 27 | if __name__ == "__main__": 28 | pass 29 | grapheme_to_phoneme = NoneG2P() 30 | text = "wo shi SP yi ge xue sheng" 31 | ph_seq, word_seq, ph_idx_to_word_idx = grapheme_to_phoneme(text) 32 | print(ph_seq) 33 | print(word_seq) 34 | print(ph_idx_to_word_idx) 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 suco 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /modules/scheduler/gaussian_ramp_up_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class GaussianRampUpScheduler: 5 | def __init__(self, max_steps, start_steps=None, end_steps=None): 6 | if end_steps is None: 7 | end_steps = max_steps 8 | if start_steps is None: 9 | start_steps = 0 10 | self.max_steps = max_steps 11 | self.start_steps = start_steps 12 | self.end_steps = end_steps 13 | self.curr_steps = 0 14 | 15 | def __call__(self): 16 | if self.curr_steps < self.start_steps: 17 | return 0 18 | elif self.curr_steps < self.end_steps: 19 | return np.exp( 20 | -5 21 | * ( 22 | 1 23 | - (self.curr_steps - self.start_steps) 24 | / (self.end_steps - self.start_steps) 25 | ) 26 | ** 2 27 | ) 28 | else: 29 | return 1 30 | 31 | def step(self): 32 | self.curr_steps += 1 33 | 34 | def resume(self, global_step): 35 | self.curr_steps = global_step 36 | -------------------------------------------------------------------------------- /modules/utils/get_melspec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import modules.rmvpe 4 | 5 | melspec_transform = None 6 | 7 | 8 | class MelSpecExtractor: 9 | def __init__( 10 | self, 11 | n_mels, 12 | sample_rate, 13 | win_length, 14 | hop_length, 15 | n_fft, 16 | fmin, 17 | fmax, 18 | clamp, 19 | device=None, 20 | scale_factor=None, 21 | ): 22 | global melspec_transform 23 | if device is None: 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | if melspec_transform is None: 26 | melspec_transform = modules.rmvpe.MelSpectrogram( 27 | n_mel_channels=n_mels, 28 | sampling_rate=sample_rate, 29 | win_length=win_length, 30 | hop_length=hop_length, 
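                # The keyword names differ between this wrapper and
                # modules.rmvpe.MelSpectrogram: n_mels -> n_mel_channels,
                # sample_rate -> sampling_rate, fmin/fmax -> mel_fmin/mel_fmax.
                # scale_factor is accepted by MelSpecExtractor but, in the code
                # shown here, is not forwarded to the transform.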
31 | n_fft=n_fft, 32 | mel_fmin=fmin, 33 | mel_fmax=fmax, 34 | clamp=clamp, 35 | ).to(device) 36 | 37 | def __call__(self, waveform, key_shift=0): 38 | return melspec_transform(waveform.unsqueeze(0), key_shift).squeeze(0) 39 | -------------------------------------------------------------------------------- /modules/rmvpe/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .constants import * 4 | from .deepunet import DeepUnet0 5 | from .seq import BiGRU 6 | 7 | 8 | class E2E0(nn.Module): 9 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 10 | en_out_channels=16): 11 | super(E2E0, self).__init__() 12 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 13 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 14 | if n_gru: 15 | self.fc = nn.Sequential( 16 | BiGRU(3 * N_MELS, 256, n_gru), 17 | nn.Linear(512, N_CLASS), 18 | nn.Dropout(0.25), 19 | nn.Sigmoid() 20 | ) 21 | else: 22 | self.fc = nn.Sequential( 23 | nn.Linear(3 * N_MELS, N_CLASS), 24 | nn.Dropout(0.25), 25 | nn.Sigmoid() 26 | ) 27 | 28 | def forward(self, mel): 29 | mel = mel.transpose(-1, -2).unsqueeze(1) 30 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 31 | x = self.fc(x) 32 | return x 33 | -------------------------------------------------------------------------------- /modules/AP_detector/none_detector.py: -------------------------------------------------------------------------------- 1 | from modules.AP_detector.base_detector import BaseAPDetector 2 | 3 | 4 | class NoneAPDetector(BaseAPDetector): 5 | def __init__(self, **kwargs): 6 | # args: list of str 7 | pass 8 | 9 | def _process_one( 10 | self, 11 | wav_path, 12 | wav_length, 13 | confidence, 14 | ph_seq, 15 | ph_intervals, 16 | word_seq, 17 | word_intervals, 18 | ): 19 | # input: 20 | # wav_path: pathlib.Path 21 | # ph_seq: list of phonemes, SP is the silence phoneme. 22 | # ph_intervals: np.ndarray of shape (n_ph, 2), ph_intervals[i] = [start, end] 23 | # means the i-th phoneme starts at start and ends at end. 24 | # word_seq: list of words. 25 | # word_intervals: np.ndarray of shape (n_word, 2), word_intervals[i] = [start, end] 26 | 27 | # output: same as the input. 
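        # NoneAPDetector is a deliberate pass-through: it performs no AP detection
        # and hands the aligner's predictions back exactly as received. Select it
        # in infer.py with `--ap_detector NoneAPDetector` when AP marking is not
        # wanted (the default there is LoudnessSpectralcentroidAPDetector).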
28 | return ( 29 | wav_path, 30 | wav_length, 31 | confidence, 32 | ph_seq, 33 | ph_intervals, 34 | word_seq, 35 | word_intervals, 36 | ) 37 | -------------------------------------------------------------------------------- /configs/train_config.yaml: -------------------------------------------------------------------------------- 1 | model_name: mandarin_opencpop-extension 2 | 3 | # settings 4 | float32_matmul_precision: high 5 | random_seed: 114514 6 | 7 | # dataloader 8 | dataloader_workers: 0 9 | dataloader_prefetch_factor: 2 10 | oversampling_weights: [ 1, 1, 1 ] # full_label, weak_label, no_label 11 | batch_max_length: 100 # unit: seconds 12 | binning_length: 1000 # unit: seconds 13 | drop_last: False 14 | 15 | # model 16 | model: 17 | hidden_dims: 128 18 | down_sampling_factor: 3 19 | down_sampling_times: 7 20 | channels_scaleup_factor: 1.5 21 | 22 | optimizer_config: 23 | total_steps: 100000 24 | weight_decay: 0.1 25 | lr: 26 | backbone: 0.0005 27 | head: 0.0005 28 | freeze: 29 | backbone: False 30 | head: False 31 | 32 | loss_config: 33 | losses: 34 | weights: [10.0, 0.1, 0.01, 0.1, 1.0, 1.0, 5.0] 35 | enable_RampUpScheduler: [False,False,False,True,True,True,True] 36 | function: 37 | num_bins: 10 38 | alpha: 0.999 39 | label_smoothing: 0.08 40 | pseudo_label_ratio: 0.3 41 | 42 | # trainer 43 | accelerator: auto 44 | devices: auto # num_devices 45 | precision: bf16-mixed # bf16-mixed , 32-true 46 | gradient_clip_val: 1.0 47 | gradient_clip_algorithm: norm #value 48 | val_check_interval: 500 # 0.25 -------------------------------------------------------------------------------- /modules/utils/load_wav.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import librosa 4 | import torch 5 | 6 | 7 | def check_and_import(package_name): 8 | try: 9 | module = importlib.import_module(package_name) 10 | globals()[package_name] = importlib.import_module(package_name) 11 | print(f"'{package_name}' installed and imported.") 12 | return True, module 13 | except ImportError: 14 | print(f"'{package_name}' not installed.") 15 | return False, None 16 | 17 | 18 | installed_torchaudio, torchaudio = check_and_import("torchaudio") 19 | resample_transform_dict = {} 20 | 21 | 22 | def load_wav(path, device, sample_rate=None): 23 | global installed_torchaudio 24 | if installed_torchaudio: 25 | waveform, sr = torchaudio.load(str(path)) 26 | if sample_rate != sr and sample_rate is not None: 27 | global resample_transform_dict 28 | if sr not in resample_transform_dict: 29 | resample_transform_dict[sr] = torchaudio.transforms.Resample( 30 | sr, sample_rate 31 | ) 32 | 33 | waveform = resample_transform_dict[sr](waveform) 34 | 35 | waveform = waveform[0].to(device) 36 | 37 | else: 38 | waveform, _ = librosa.load(path, sr=sample_rate, mono=True) 39 | waveform = torch.from_numpy(waveform).to(device) 40 | 41 | return waveform 42 | -------------------------------------------------------------------------------- /modules/layer/scaling/stride_conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from modules.layer.scaling.base import BaseDowmSampling, BaseUpSampling 4 | 5 | 6 | class DownSampling(BaseDowmSampling): 7 | def __init__(self, input_dims, output_dims, down_sampling_factor=2): 8 | super(DownSampling, self).__init__( 9 | input_dims, output_dims, down_sampling_factor=2 10 | ) 11 | 12 | self.input_dims = input_dims 13 | self.output_dims = output_dims 14 | 
self.down_sampling_factor = down_sampling_factor 15 | 16 | self.conv = nn.Conv1d( 17 | self.input_dims, 18 | self.output_dims, 19 | kernel_size=down_sampling_factor, 20 | stride=down_sampling_factor, 21 | ) 22 | 23 | def forward(self, x): 24 | x = x.transpose(1, 2) 25 | padding_len = x.shape[-1] % self.down_sampling_factor 26 | if padding_len != 0: 27 | x = nn.functional.pad(x, (0, self.down_sampling_factor - padding_len)) 28 | return self.conv(x).transpose(1, 2) 29 | 30 | 31 | class UpSampling(BaseUpSampling): 32 | def __init__(self, input_dims, output_dims, up_sampling_factor=2): 33 | super(UpSampling, self).__init__(input_dims, output_dims, up_sampling_factor=2) 34 | 35 | self.input_dims = input_dims 36 | self.output_dims = output_dims 37 | self.up_sampling_factor = up_sampling_factor 38 | 39 | self.conv = nn.ConvTranspose1d( 40 | self.input_dims, 41 | self.output_dims, 42 | kernel_size=up_sampling_factor, 43 | stride=up_sampling_factor, 44 | ) 45 | 46 | def forward(self, x): 47 | return self.conv(x.transpose(1, 2)).transpose(1, 2) 48 | -------------------------------------------------------------------------------- /modules/rmvpe/utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from .constants import * 6 | 7 | 8 | def to_local_average_f0(hidden, center=None, thred=0.03): 9 | idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] 10 | idx_cents = idx * 20 + CONST # [B=1, N] 11 | if center is None: 12 | center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] 13 | start = torch.clip(center - 4, min=0) # [B, T, 1] 14 | end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] 15 | idx_mask = (idx >= start) & (idx < end) # [B, T, N] 16 | weights = hidden * idx_mask # [B, T, N] 17 | product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] 18 | weight_sum = torch.sum(weights, dim=2) # [B, T] 19 | cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 20 | f0 = 10 * 2 ** (cents / 1200) 21 | uv = hidden.max(dim=2)[0] < thred # [B, T] 22 | f0 = f0 * ~uv 23 | return f0.squeeze(0).cpu().numpy() 24 | 25 | 26 | def to_viterbi_f0(hidden, thred=0.03): 27 | # Create viterbi transition matrix 28 | if not hasattr(to_viterbi_f0, 'transition'): 29 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) 30 | transition = np.maximum(30 - abs(xx - yy), 0) 31 | transition = transition / transition.sum(axis=1, keepdims=True) 32 | to_viterbi_f0.transition = transition 33 | 34 | # Convert to probability 35 | prob = hidden.squeeze(0).cpu().numpy() 36 | prob = prob.T 37 | prob = prob / prob.sum(axis=0) 38 | 39 | # Perform viterbi decoding 40 | path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64) 41 | center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) 42 | 43 | return to_local_average_f0(hidden, center=center, thred=thred) 44 | -------------------------------------------------------------------------------- /modules/layer/block/residual.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class Residual(nn.Module): 6 | def __init__(self, input_dims: int, output_dims: int, dim: int = -1) -> None: 7 | """ 8 | Residual module for residual connections in neural networks. 9 | 10 | Args: 11 | input_dims (int): Number of input dimensions. 12 | output_dims (int): Number of output dimensions. 
13 | dim (int, optional): Dimension along which to apply the residual connection. Defaults to -1. 14 | """ 15 | super().__init__() 16 | self.dim: int = dim 17 | self.output_dims: int = output_dims 18 | self.input_dims: int = input_dims 19 | self.projection: nn.Module = ( 20 | nn.Linear(input_dims, output_dims) 21 | if input_dims != output_dims 22 | else nn.Identity() 23 | ) 24 | 25 | def forward(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: 26 | """ 27 | Forward pass of the residual module. 28 | 29 | Args: 30 | x (torch.Tensor): Input tensor. 31 | out (torch.Tensor): Output tensor. 32 | 33 | Returns: 34 | torch.Tensor: Result of the residual connection. 35 | """ 36 | x = torch.transpose(x, -1, self.dim) 37 | out = torch.transpose(out, -1, self.dim) 38 | if x.shape[-1] != self.input_dims: 39 | raise ValueError( 40 | f"Dimension mismatch: expected input dimension {self.input_dims}, but got {x.shape[-1]}." 41 | ) 42 | if out.shape[-1] != self.output_dims: 43 | raise ValueError( 44 | f"Dimension mismatch: expected output dimension {self.output_dims}, but got {out.shape[-1]}." 45 | ) 46 | return torch.transpose(out + self.projection(x), -1, self.dim) 47 | 48 | 49 | if __name__ == "__main__": 50 | model = Residual(2, 3) 51 | x1 = torch.randn(2, 2, 2) 52 | x2 = torch.randn(2, 2, 3) 53 | y = model(x1, x2) 54 | print(y.shape) 55 | -------------------------------------------------------------------------------- /modules/g2p/dictionary_g2p.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from modules.g2p.base_g2p import BaseG2P 4 | 5 | 6 | class DictionaryG2P(BaseG2P): 7 | def __init__(self, **kwargs): 8 | dict_path = kwargs["dictionary"] 9 | with open(dict_path, "r") as f: 10 | dictionary = f.read().strip().split("\n") 11 | self.dictionary = { 12 | item.split("\t")[0].strip(): item.split("\t")[1].strip().split(" ") 13 | for item in dictionary 14 | } 15 | 16 | def _g2p(self, input_text): 17 | word_seq_raw = input_text.strip().split(" ") 18 | word_seq = [] 19 | word_seq_idx = 0 20 | ph_seq = ["SP"] 21 | ph_idx_to_word_idx = [-1] 22 | for word in word_seq_raw: 23 | if word not in self.dictionary: 24 | warnings.warn(f"Word {word} is not in the dictionary. Ignored.") 25 | continue 26 | word_seq.append(word) 27 | phones = self.dictionary[word] 28 | for i, ph in enumerate(phones): 29 | if (i == 0 or i == len(phones) - 1) and ph == "SP": 30 | warnings.warn( 31 | f"The first or last phoneme of word {word} is SP, which is not allowed. " 32 | "Please check your dictionary." 
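                    # A leading or trailing SP inside a dictionary entry would sit
                    # next to the word-boundary SPs this method inserts itself,
                    # yielding two consecutive SPs, which the checks in
                    # BaseG2P.__call__ reject; such phonemes are skipped below.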
33 | ) 34 | continue 35 | ph_seq.append(ph) 36 | ph_idx_to_word_idx.append(word_seq_idx) 37 | if ph_seq[-1] != "SP": 38 | ph_seq.append("SP") 39 | ph_idx_to_word_idx.append(-1) 40 | word_seq_idx += 1 41 | 42 | return ph_seq, word_seq, ph_idx_to_word_idx 43 | 44 | 45 | if __name__ == "__main__": 46 | pass 47 | grapheme_to_phoneme = DictionaryG2P( 48 | **{"dictionary": "/home/qq/Project/SOFA/dictionary/opencpop-extension.txt"} 49 | ) 50 | text = "wo SP shi yi ge xue sheng a" 51 | ph_seq, word_seq, ph_idx_to_word_idx = grapheme_to_phoneme(text) 52 | print(ph_seq) 53 | print(word_seq) 54 | print(ph_idx_to_word_idx) 55 | -------------------------------------------------------------------------------- /modules/AP_detector/base_detector.py: -------------------------------------------------------------------------------- 1 | class BaseAPDetector: 2 | def __init__(self, **kwargs): 3 | # args: list of str 4 | pass 5 | 6 | def process(self, predictions): 7 | # input: list of predictions, each prediction is a tuple of: 8 | # wav_path: pathlib.Path 9 | # wav_length: float 10 | # confidence: float 11 | # ph_seq: list of phonemes, SP is the silence phoneme. 12 | # ph_intervals: np.ndarray of shape (n_ph, 2), ph_intervals[i] = [start, end] 13 | # means the i-th phoneme starts at start and ends at end. 14 | # word_seq: list of words. 15 | # word_intervals: np.ndarray of shape (n_word, 2), word_intervals[i] = [start, end] 16 | 17 | # output: same as the input. 18 | 19 | res = [] 20 | for ( 21 | wav_path, 22 | wav_length, 23 | confidence, 24 | ph_seq, 25 | ph_intervals, 26 | word_seq, 27 | word_intervals, 28 | ) in predictions: 29 | prediction = self._process_one( 30 | wav_path, 31 | wav_length, 32 | confidence, 33 | ph_seq, 34 | ph_intervals, 35 | word_seq, 36 | word_intervals, 37 | ) 38 | res.append(prediction) 39 | 40 | return res 41 | 42 | def _process_one( 43 | self, 44 | wav_path, 45 | wav_length, 46 | confidence, 47 | ph_seq, 48 | ph_intervals, 49 | word_seq, 50 | word_intervals, 51 | ): 52 | # input: 53 | # wav_path: pathlib.Path 54 | # wav_length: float 55 | # confidence: float 56 | # ph_seq: list of phonemes, SP is the silence phoneme. 57 | # ph_intervals: np.ndarray of shape (n_ph, 2), ph_intervals[i] = [start, end] 58 | # means the i-th phoneme starts at start and ends at end. 59 | # word_seq: list of words. 60 | # word_intervals: np.ndarray of shape (n_word, 2), word_intervals[i] = [start, end] 61 | 62 | # output: same as the input. 
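        # Concrete detectors override this hook. In this repository, NoneAPDetector
        # returns its inputs untouched, while LoudnessSpectralcentroidAPDetector
        # (implementation not shown here; behaviour inferred from its name) derives
        # AP segments from loudness and spectral-centroid features.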
63 | raise NotImplementedError 64 | -------------------------------------------------------------------------------- /modules/utils/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | 5 | def plot_for_valid( 6 | melspec, 7 | ph_seq, 8 | ph_intervals, 9 | frame_confidence, 10 | ph_frame_prob, 11 | ph_frame_id_gt, 12 | edge_prob, 13 | ): 14 | ph_seq = [i.split("/")[-1] for i in ph_seq] 15 | x = np.arange(melspec.shape[-1]) 16 | 17 | fig, (ax1, ax2) = plt.subplots(2) 18 | ax1.imshow(melspec[0], origin="lower", aspect="auto") 19 | 20 | for i, interval in enumerate(ph_intervals): 21 | if i == 0 or (i > 0 and ph_intervals[i - 1, 1] != interval[0]): 22 | if interval[0] > 0: 23 | ax1.axvline(interval[0], color="r", linewidth=1) 24 | if interval[1] < melspec.shape[-1]: 25 | ax1.axvline(interval[1], color="r", linewidth=1) 26 | if ph_seq[i] != "SP": 27 | if i % 2: 28 | ax1.text( 29 | (interval[0] + interval[1]) / 2 30 | - len(ph_seq[i]) * melspec.shape[-1] / 275, 31 | melspec.shape[-2] + 1, 32 | ph_seq[i], 33 | fontsize=11, 34 | color="black", 35 | ) 36 | else: 37 | ax1.text( 38 | (interval[0] + interval[1]) / 2 39 | - len(ph_seq[i]) * melspec.shape[-1] / 275, 40 | melspec.shape[-2] - 6, 41 | ph_seq[i], 42 | fontsize=11, 43 | color="white", 44 | ) 45 | 46 | ax1.plot( 47 | x, frame_confidence * melspec.shape[-2], color="black", linewidth=1, alpha=0.6 48 | ) 49 | ax1.fill_between(x, frame_confidence * melspec.shape[-2], color="black", alpha=0.3) 50 | 51 | ax2.imshow( 52 | ph_frame_prob.T, 53 | origin="lower", 54 | aspect="auto", 55 | interpolation="nearest", 56 | # vmin=0, 57 | # vmax=1, 58 | ) 59 | 60 | ax2.plot(x, ph_frame_id_gt, color="red", linewidth=1.5) 61 | # ax2.scatter(x, ph_frame_id_gt, s=5, marker='s', color="red") 62 | 63 | ax2.plot(x, edge_prob * ph_frame_prob.shape[-1], color="black", linewidth=1) 64 | ax2.fill_between(x, edge_prob * ph_frame_prob.shape[-1], color="black", alpha=0.3) 65 | 66 | fig.set_size_inches(13, 7) 67 | plt.subplots_adjust(hspace=0) 68 | plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05) 69 | 70 | return fig 71 | -------------------------------------------------------------------------------- /modules/rmvpe/spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from librosa.filters import mel 5 | 6 | 7 | class MelSpectrogram(torch.nn.Module): 8 | def __init__( 9 | self, 10 | n_mel_channels, 11 | sampling_rate, 12 | win_length, 13 | hop_length, 14 | n_fft=None, 15 | mel_fmin=0, 16 | mel_fmax=None, 17 | clamp=1e-5 18 | ): 19 | super().__init__() 20 | n_fft = win_length if n_fft is None else n_fft 21 | self.hann_window = {} 22 | mel_basis = mel( 23 | sr=sampling_rate, 24 | n_fft=n_fft, 25 | n_mels=n_mel_channels, 26 | fmin=mel_fmin, 27 | fmax=mel_fmax, 28 | htk=True) 29 | mel_basis = torch.from_numpy(mel_basis).float() 30 | self.register_buffer("mel_basis", mel_basis) 31 | self.n_fft = win_length if n_fft is None else n_fft 32 | self.hop_length = hop_length 33 | self.win_length = win_length 34 | self.sampling_rate = sampling_rate 35 | self.n_mel_channels = n_mel_channels 36 | self.clamp = clamp 37 | 38 | def forward(self, audio, keyshift=0, speed=1, center=True): 39 | factor = 2 ** (keyshift / 12) 40 | n_fft_new = int(np.round(self.n_fft * factor)) 41 | win_length_new = int(np.round(self.win_length * factor)) 42 | hop_length_new = 
int(np.round(self.hop_length * speed)) 43 | 44 | keyshift_key = str(keyshift) + '_' + str(audio.device) 45 | if keyshift_key not in self.hann_window: 46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) 47 | if center: 48 | pad_left = n_fft_new // 2 49 | pad_right = (n_fft_new + 1) // 2 50 | audio = F.pad(audio, (pad_left, pad_right)) 51 | 52 | fft = torch.stft( 53 | audio, 54 | n_fft=n_fft_new, 55 | hop_length=hop_length_new, 56 | win_length=win_length_new, 57 | window=self.hann_window[keyshift_key], 58 | center=False, 59 | return_complex=True 60 | ) 61 | magnitude = fft.abs() 62 | 63 | if keyshift != 0: 64 | size = self.n_fft // 2 + 1 65 | resize = magnitude.size(1) 66 | if resize < size: 67 | magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) 68 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new 69 | 70 | mel_output = torch.matmul(self.mel_basis, magnitude) 71 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 72 | return log_mel_spec 73 | -------------------------------------------------------------------------------- /modules/rmvpe/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import Resample 5 | 6 | # from utils.pitch_utils import interp_f0, resample_align_curve 7 | from .constants import * 8 | from .model import E2E0 9 | from .spec import MelSpectrogram 10 | from .utils import to_local_average_f0, to_viterbi_f0 11 | 12 | 13 | class RMVPE: 14 | def __init__(self, model_path, hop_length=160, device=None): 15 | self.resample_kernel = {} 16 | if device is None: 17 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 18 | else: 19 | self.device = device 20 | self.model = E2E0(4, 1, (2, 2)).eval().to(self.device) 21 | ckpt = torch.load(model_path, map_location=self.device) 22 | self.model.load_state_dict(ckpt["model"], strict=False) 23 | self.mel_extractor = MelSpectrogram( 24 | N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX 25 | ).to(self.device) 26 | 27 | @torch.no_grad() 28 | def mel2hidden(self, mel): 29 | n_frames = mel.shape[-1] 30 | mel = F.pad( 31 | mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="constant" 32 | ) 33 | hidden = self.model(mel) 34 | return hidden[:, :n_frames] 35 | 36 | def decode(self, hidden, thred=0.03, use_viterbi=False): 37 | if use_viterbi: 38 | f0 = to_viterbi_f0(hidden, thred=thred) 39 | else: 40 | f0 = to_local_average_f0(hidden, thred=thred) 41 | return f0 42 | 43 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False): 44 | audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) 45 | if sample_rate == 16000: 46 | audio_res = audio 47 | else: 48 | key_str = str(sample_rate) 49 | if key_str not in self.resample_kernel: 50 | self.resample_kernel[key_str] = Resample( 51 | sample_rate, 16000, lowpass_filter_width=128 52 | ) 53 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to( 54 | self.device 55 | ) 56 | audio_res = self.resample_kernel[key_str](audio) 57 | mel = self.mel_extractor(audio_res, center=True) 58 | hidden = self.mel2hidden(mel) 59 | f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) 60 | return f0 61 | 62 | def get_pitch(self, waveform, sample_rate, hop_size, length, interp_uv=False): 63 | f0 = self.infer_from_audio(waveform, sample_rate=sample_rate) 64 | uv = f0 == 0 65 | f0, uv = interp_f0(f0, uv) 
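        # NOTE: interp_f0 and resample_align_curve are not defined in this module;
        # the import that would provide them (utils.pitch_utils) is commented out
        # near the top of this file, so calling get_pitch as-is raises a NameError
        # unless those helpers are supplied from elsewhere.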
66 | 67 | time_step = hop_size / sample_rate 68 | f0_res = resample_align_curve(f0, 0.01, time_step, length) 69 | uv_res = ( 70 | resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5 71 | ) 72 | if not interp_uv: 73 | f0_res[uv_res] = 0 74 | return f0_res, uv_res 75 | -------------------------------------------------------------------------------- /modules/g2p/base_g2p.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | 4 | 5 | class DataFrameDataset(torch.utils.data.Dataset): 6 | def __init__(self, dataframe): 7 | self.dataset = dataframe 8 | 9 | def __getitem__(self, index): 10 | return tuple(self.dataset.iloc[index]) 11 | 12 | def __len__(self): 13 | return len(self.dataset) 14 | 15 | 16 | class BaseG2P: 17 | def __init__(self, **kwargs): 18 | # args: list of str 19 | self.in_format = "lab" 20 | 21 | def _g2p(self, input_text): 22 | # input text, return phoneme sequence, word sequence, and phoneme index to word index mapping 23 | # ph_seq: list of phonemes, SP is the silence phoneme. 24 | # word_seq: list of words. 25 | # ph_idx_to_word_idx: ph_idx_to_word_idx[i] = j means the i-th phoneme belongs to the j-th word. 26 | # if ph_idx_to_word_idx[i] = -1, the i-th phoneme is a silence phoneme. 27 | # example: ph_seq = ['SP', 'ay', 'SP', 'ae', 'm', 'SP', 'ah', 'SP', 's', 't', 'uw', 'd', 'ah', 'n', 't', 'SP'] 28 | # word_seq = ['I', 'am', 'a', 'student'] 29 | # ph_idx_to_word_idx = [-1, 0, -1, 1, 1, -1, 2, -1, 3, 3, 3, 3, 3, 3, 3, -1] 30 | raise NotImplementedError 31 | 32 | def __call__(self, text): 33 | ph_seq, word_seq, ph_idx_to_word_idx = self._g2p(text) 34 | 35 | # The first and last phonemes should be `SP`, 36 | # and there should not be more than two consecutive `SP`s at any position. 37 | assert ph_seq[0] == "SP" and ph_seq[-1] == "SP" 38 | assert all( 39 | ph_seq[i] != "SP" or ph_seq[i + 1] != "SP" for i in range(len(ph_seq) - 1) 40 | ) 41 | return ph_seq, word_seq, ph_idx_to_word_idx 42 | 43 | def set_in_format(self, in_format): 44 | self.in_format = in_format 45 | 46 | def get_dataset(self, wav_paths): 47 | # dataset is a pandas dataframe with columns: wav_path, ph_seq, word_seq, ph_idx_to_word_idx 48 | dataset = [] 49 | for wav_path in wav_paths: 50 | try: 51 | if wav_path.with_suffix("." + self.in_format).exists(): 52 | with open( 53 | wav_path.with_suffix("." 
+ self.in_format), 54 | "r", 55 | encoding="utf-8", 56 | ) as f: 57 | lab_text = f.read().strip() 58 | ph_seq, word_seq, ph_idx_to_word_idx = self(lab_text) 59 | dataset.append((wav_path, ph_seq, word_seq, ph_idx_to_word_idx)) 60 | except Exception as e: 61 | e.args = (f" Error when processing {wav_path}: {e} ",) 62 | raise e 63 | if len(dataset) <= 0: 64 | raise ValueError("No valid data found.") 65 | print(f"Loaded {len(dataset)} samples.") 66 | 67 | dataset = pd.DataFrame( 68 | dataset, columns=["wav_path", "ph_seq", "word_seq", "ph_idx_to_word_idx"] 69 | ) 70 | dataset = DataFrameDataset(dataset) 71 | return dataset 72 | -------------------------------------------------------------------------------- /modules/layer/block/resnet_block.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ResidualBasicBlock(nn.Module): 5 | def __init__(self, input_dims, output_dims, hidden_dims=None, n_groups=16): 6 | super(ResidualBasicBlock, self).__init__() 7 | 8 | self.input_dims = input_dims 9 | self.output_dims = output_dims 10 | self.hidden_dims = ( 11 | hidden_dims 12 | if hidden_dims is not None 13 | else max(n_groups * (output_dims // n_groups), n_groups) 14 | ) 15 | self.n_groups = n_groups 16 | 17 | self.block = nn.Sequential( 18 | nn.Conv1d( 19 | self.input_dims, 20 | self.hidden_dims, 21 | kernel_size=3, 22 | padding=1, 23 | bias=False, 24 | ), 25 | nn.GroupNorm(self.n_groups, self.hidden_dims), 26 | nn.Hardswish(), 27 | nn.Conv1d( 28 | self.hidden_dims, 29 | self.output_dims, 30 | kernel_size=3, 31 | padding=1, 32 | bias=False, 33 | ), 34 | ) 35 | 36 | self.shortcut = nn.Sequential( 37 | nn.Linear(self.input_dims, self.output_dims, bias=False) 38 | if self.input_dims != self.output_dims 39 | else nn.Identity() 40 | ) 41 | 42 | self.out = nn.Sequential( 43 | nn.LayerNorm(self.output_dims), 44 | nn.Hardswish(), 45 | ) 46 | 47 | def forward(self, x): 48 | x = self.block(x.transpose(1, 2)).transpose(1, 2) + self.shortcut(x) 49 | x = self.out(x) 50 | return x 51 | 52 | 53 | class ResidualBottleNeckBlock(nn.Module): 54 | def __init__(self, input_dims, output_dims, hidden_dims=None, n_groups=16): 55 | super(ResidualBottleNeckBlock, self).__init__() 56 | 57 | self.input_dims = input_dims 58 | self.output_dims = output_dims 59 | self.hidden_dims = ( 60 | hidden_dims 61 | if hidden_dims is not None 62 | else max(n_groups * ((output_dims // 4) // n_groups), n_groups) 63 | ) 64 | self.n_groups = n_groups 65 | 66 | self.input_proj = nn.Linear(self.input_dims, self.hidden_dims, bias=False) 67 | self.conv = nn.Sequential( 68 | nn.GroupNorm(self.n_groups, self.hidden_dims), 69 | nn.Hardswish(), 70 | nn.Conv1d( 71 | self.hidden_dims, 72 | self.hidden_dims, 73 | kernel_size=3, 74 | padding=1, 75 | bias=False, 76 | ), 77 | nn.GroupNorm(self.n_groups, self.hidden_dims), 78 | nn.Hardswish(), 79 | ) 80 | self.output_proj = nn.Linear(self.hidden_dims, self.output_dims, bias=False) 81 | 82 | self.shortcut = nn.Sequential( 83 | nn.Linear(self.input_dims, self.output_dims) 84 | if self.input_dims != self.output_dims 85 | else nn.Identity() 86 | ) 87 | 88 | self.out = nn.Sequential( 89 | nn.LayerNorm(self.output_dims), 90 | nn.Hardswish(), 91 | ) 92 | 93 | def forward(self, x): 94 | h = self.input_proj(x) 95 | h = self.conv(h.transpose(1, 2)).transpose(1, 2) 96 | h = self.output_proj(h) 97 | return self.out(h + self.shortcut(x)) 98 | -------------------------------------------------------------------------------- /infer.py: 
-------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import click 4 | import lightning as pl 5 | import torch 6 | 7 | import modules.AP_detector 8 | import modules.g2p 9 | from modules.utils.export_tool import Exporter 10 | from modules.utils.post_processing import post_processing 11 | from train import LitForcedAlignmentTask 12 | 13 | 14 | @click.command() 15 | @click.option( 16 | "--ckpt", 17 | "-c", 18 | default=None, 19 | required=True, 20 | type=str, 21 | help="path to the checkpoint", 22 | ) 23 | @click.option( 24 | "--folder", "-f", default="segments", type=str, help="path to the input folder" 25 | ) 26 | @click.option( 27 | "--mode", "-m", default="force", type=click.Choice(["force", "match"]) 28 | ) # TODO: add asr mode 29 | @click.option( 30 | "--g2p", "-g", default="Dictionary", type=str, help="name of the g2p class" 31 | ) 32 | @click.option( 33 | "--ap_detector", 34 | "-a", 35 | default="LoudnessSpectralcentroidAPDetector", # "NoneAPDetector", 36 | type=str, 37 | help="name of the AP detector class", 38 | ) 39 | @click.option( 40 | "--in_format", 41 | "-if", 42 | default="lab", 43 | required=False, 44 | type=str, 45 | help="File extension of input transcriptions. Default: lab", 46 | ) 47 | @click.option( 48 | "--out_formats", 49 | "-of", 50 | default="textgrid,htk,trans", 51 | required=False, 52 | type=str, 53 | help="Types of output file, separated by comma. Supported types:" 54 | "textgrid(praat)," 55 | " htk(lab,nnsvs,sinsy)," 56 | " transcriptions.csv(diffsinger,trans,transcription,transcriptions)", 57 | ) 58 | @click.option( 59 | "--save_confidence", 60 | "-sc", 61 | is_flag=True, 62 | default=False, 63 | show_default=True, 64 | help="save confidence.csv", 65 | ) 66 | @click.option( 67 | "--dictionary", 68 | "-d", 69 | default="dictionary/opencpop-extension.txt", 70 | type=str, 71 | help="(only used when --g2p=='Dictionary') path to the dictionary", 72 | ) 73 | def main( 74 | ckpt, 75 | folder, 76 | mode, 77 | g2p, 78 | ap_detector, 79 | in_format, 80 | out_formats, 81 | save_confidence, 82 | **kwargs, 83 | ): 84 | if not g2p.endswith("G2P"): 85 | g2p += "G2P" 86 | g2p_class = getattr(modules.g2p, g2p) 87 | grapheme_to_phoneme = g2p_class(**kwargs) 88 | out_formats = [i.strip().lower() for i in out_formats.split(",")] 89 | 90 | if not ap_detector.endswith("APDetector"): 91 | ap_detector += "APDetector" 92 | AP_detector_class = getattr(modules.AP_detector, ap_detector) 93 | get_AP = AP_detector_class(**kwargs) 94 | 95 | grapheme_to_phoneme.set_in_format(in_format) 96 | dataset = grapheme_to_phoneme.get_dataset(pathlib.Path(folder).rglob("*.wav")) 97 | 98 | torch.set_grad_enabled(False) 99 | model = LitForcedAlignmentTask.load_from_checkpoint(ckpt) 100 | model.set_inference_mode(mode) 101 | trainer = pl.Trainer(logger=False) 102 | predictions = trainer.predict(model, dataloaders=dataset, return_predictions=True) 103 | 104 | predictions = get_AP.process(predictions) 105 | predictions, log = post_processing(predictions) 106 | exporter = Exporter(predictions, log) 107 | 108 | if save_confidence: 109 | out_formats.append('confidence') 110 | 111 | exporter.export(out_formats) 112 | 113 | print("Output files are saved to the same folder as the input wav files.") 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /modules/g2p/readme_g2p_zh.md: -------------------------------------------------------------------------------- 1 | # 
g2p模块使用说明 2 | 3 | ## 1. 简介 4 | 5 | g2p模块是用于推理时,从`.lab`文件中读取文本,然后将文本转换为音素的模块。 6 | 7 | ## 2. 使用方法 8 | 9 | 在使用`infer.py`进行命令行推理时,添加`--g2p`参数,指定g2p模块名称(默认为Dictionary,可以省略名称的G2P后缀) 10 | 11 | 有些g2p模块需要额外的参数。例如,Dictionary G2P需要额外指定参数`--dictionary_path`,用于指定字典文件的路径。 12 | 13 | 不同模块需要的参数不同,具体参数请参考各个模块的说明。 14 | 15 | 例:如果你想使用Dictionary G2P,你可以这样运行`infer.py`: 16 | 17 | ```shell 18 | python infer.py --g2p Dictionary --dictionary_path /path/to/dictionary.txt 19 | ``` 20 | 21 | 或:如果你想使用Phoneme G2P,你可以这样运行`infer.py`: 22 | 23 | ```shell 24 | python infer.py --g2p PhonemeG2P 25 | ``` 26 | 27 | ## 3. 模块列表 28 | 29 | ### 3.1 Dictionary G2P 30 | 31 | #### 3.1.1 简介 32 | 33 | Dictionary G2P是一个基于字典的g2p模块。它会从`.lab`文件中读取词语序列,然后将词语序列转换为音素序列。 34 | 35 | 在词语和词语之间,会自动插入SP音素,在词语之内,不会插入SP音素。 36 | 37 | #### 3.1.2 输入格式 38 | 39 | 输入的`.lab`文件仅有一行,内容为词语序列,词语之间用空格分隔。 40 | 41 | 例: 42 | 43 | ```text 44 | I am a student 45 | ``` 46 | 47 | #### 3.1.3 参数 48 | 49 | | 参数名 | 类型 | 默认值 | 说明 | 50 | |-----------------|--------|-----|---------------------------------------------------------------------------------------------------------------------| 51 | | dictionary_path | string | 无 | 字典文件的路径。字典文件的格式为:每一行一个词条,每个词条包含一个单词和一个或多个音素,单词和音素之间用`\t`隔开,多个音素之间用空格隔开。有关示例,请参阅`dictionary/opencpop-extension.txt`。 | 52 | 53 | ### 3.2 Phoneme G2P 54 | 55 | #### 3.2.1 简介 56 | 57 | Phoneme G2P是一个基于音素的g2p模块。它会直接从`.lab`文件中读取音素序列,并在每个音素之间插入`SP`音素。 58 | 59 | #### 3.2.2 输入格式 60 | 61 | 输入的`.lab`文件仅有一行,内容为音素序列,音素之间用空格分隔,最好不要包含`SP`音素。 62 | 63 | 例: 64 | 65 | ```text 66 | ay ae m ah s t uw d ah n t 67 | ``` 68 | 69 | #### 3.2.3 参数 70 | 71 | 无 72 | 73 | ### 3.3 None G2P 74 | 75 | #### 3.3.1 简介 76 | 77 | None G2P是一个空的g2p模块。它会直接从`.lab`文件中读取音素序列,并且不会进行任何处理。 78 | 79 | 你可以把它视为不插入`SP`音素的Phoneme G2P。 80 | 81 | #### 3.3.2 输入格式 82 | 83 | 输入的`.lab`文件仅有一行,内容为音素序列,音素之间用空格分隔。 84 | 85 | 例: 86 | 87 | ```text 88 | SP ay SP ae m SP ah SP s t uw d ah n t SP 89 | ``` 90 | 91 | #### 3.3.3 参数 92 | 93 | 无 94 | 95 | ## 4. 自定义g2p模块 96 | 97 | 如果你想自定义g2p模块,你需要继承`base_g2p.py`中的`BaseG2P`类,并实现其中的`__init__`和`_g2p`方法。 98 | 99 | `__init__`方法用于初始化模块,`_g2p`方法用于将文本转换为音素序列、词语序列以及两个序列之间的对应关系。 100 | 101 | ### 4.1 `__init__`方法 102 | 103 | `__init__`方法的参数为`**kwargs`,由`infer.py`中的命令行参数`**g2p_args`传入。 104 | 105 | 如果你想使用额外的参数,你可以在`infer.py`中添加额外的命令行参数,然后在`__init__`方法中接收这些参数。 106 | 107 | 例:如果你想使用额外的参数`--my_param`,你可以在infer.py中添加这个参数: 108 | 109 | ```python 110 | ... 111 | 112 | 113 | @click.option('--my_param', type=str, default=None, help='My parameter for my g2p module') 114 | def main(ckpt, folder, g2p, match_mode, **kwargs): 115 | ... 116 | ``` 117 | 118 | 然后在`__init__`方法中接收这个参数: 119 | 120 | ```python 121 | class MyG2P(BaseG2P): 122 | def __init__(self, **kwargs): 123 | super().__init__(**kwargs) 124 | self.my_param = kwargs['my_param'] 125 | ... 
126 | ``` 127 | 128 | **注意:g2p模块和AP_detector模块共用kwargs参数。** 129 | 130 | ### 4.2 `_g2p`方法 131 | 132 | `_g2p`方法的参数为`text`,是一个字符串,表示从.lab文件中读取的文本。 133 | 134 | `_g2p`方法的返回值为一个元组,包含三个元素: 135 | 136 | - `ph_seq`:音素列表,SP是静音音素。第一个音素和最后一个音素应当为SP,并且任何位置都不能有连续两个以上的SP。 137 | - `word_seq`:单词列表。 138 | - `ph_idx_to_word_idx`:`ph_idx_to_word_idx[i] = j` 意味着第i个音素属于第j个单词。如果`ph_idx_to_word_idx[i] = -1` 139 | ,则第i个音素是一个静音音素。 140 | 141 | 示例: 142 | 143 | ```python 144 | text = 'I am a student' 145 | ph_seq = ['SP', 'ay', 'SP', 'ae', 'm', 'SP', 'ah', 'SP', 's', 't', 'uw', 'd', 'ah', 'n', 't', 'SP'] 146 | word_seq = ['I', 'am', 'a', 'student'] 147 | ph_idx_to_word_idx = [-1, 0, -1, 1, 1, -1, 2, -1, 3, 3, 3, 3, 3, 3, 3, -1] 148 | ``` 149 | 150 | ### 4.3 使用 151 | 152 | 使用时,你可以这样运行`infer.py`: 153 | 154 | ```shell 155 | python infer.py --g2p My --my_param my_value 156 | ``` 157 | -------------------------------------------------------------------------------- /modules/utils/post_processing.py: -------------------------------------------------------------------------------- 1 | MIN_SP_LENGTH = 0.1 2 | SP_MERGE_LENGTH = 0.3 3 | 4 | 5 | def add_SP(word_seq, word_intervals, wav_length): 6 | word_seq_res = [] 7 | word_intervals_res = [] 8 | if len(word_seq) == 0: 9 | word_seq_res.append("SP") 10 | word_intervals_res.append([0, wav_length]) 11 | return word_seq_res, word_intervals_res 12 | 13 | word_seq_res.append("SP") 14 | word_intervals_res.append([0, word_intervals[0, 0]]) 15 | for word, (start, end) in zip(word_seq, word_intervals): 16 | if word_intervals_res[-1][1] < start: 17 | word_seq_res.append("SP") 18 | word_intervals_res.append([word_intervals_res[-1][1], start]) 19 | word_seq_res.append(word) 20 | word_intervals_res.append([start, end]) 21 | if word_intervals_res[-1][1] < wav_length: 22 | word_seq_res.append("SP") 23 | word_intervals_res.append([word_intervals_res[-1][1], wav_length]) 24 | if word_intervals[0, 0] <= 0: 25 | word_seq_res = word_seq_res[1:] 26 | word_intervals_res = word_intervals_res[1:] 27 | 28 | return word_seq_res, word_intervals_res 29 | 30 | 31 | def fill_small_gaps(word_seq, word_intervals, wav_length): 32 | if word_intervals[0, 0] > 0: 33 | if word_intervals[0, 0] < MIN_SP_LENGTH: 34 | word_intervals[0, 0] = 0 35 | 36 | for idx in range(len(word_seq) - 1): 37 | if word_intervals[idx, 1] < word_intervals[idx + 1, 0]: 38 | if word_intervals[idx + 1, 0] - word_intervals[idx, 1] < SP_MERGE_LENGTH: 39 | if word_seq[idx] == "AP": 40 | if word_seq[idx + 1] == "AP": 41 | # 情况1:gap的左右都是AP 42 | mean = (word_intervals[idx, 1] + word_intervals[idx + 1, 0]) / 2 43 | word_intervals[idx, 1] = mean 44 | word_intervals[idx + 1, 0] = mean 45 | else: 46 | # 情况2:只有左边是AP 47 | word_intervals[idx, 1] = word_intervals[idx + 1, 0] 48 | elif word_seq[idx + 1] == "AP": 49 | # 情况3:只有右边是AP 50 | word_intervals[idx + 1, 0] = word_intervals[idx, 1] 51 | else: 52 | # 情况4:gap的左右都不是AP 53 | if ( 54 | word_intervals[idx + 1, 0] - word_intervals[idx, 1] 55 | < MIN_SP_LENGTH 56 | ): 57 | mean = (word_intervals[idx, 1] + word_intervals[idx + 1, 0]) / 2 58 | word_intervals[idx, 1] = mean 59 | word_intervals[idx + 1, 0] = mean 60 | 61 | if word_intervals[-1, 1] < wav_length: 62 | if wav_length - word_intervals[-1, 1] < MIN_SP_LENGTH: 63 | word_intervals[-1, 1] = wav_length 64 | 65 | return word_seq, word_intervals 66 | 67 | 68 | def post_processing(predictions): 69 | print("Post-processing...") 70 | 71 | res = [] 72 | error_log = [] 73 | for ( 74 | wav_path, 75 | wav_length, 76 | confidence, 77 | ph_seq, 78 | ph_intervals, 79 | 
word_seq, 80 | word_intervals, 81 | ) in predictions: 82 | try: 83 | # fill small gaps 84 | word_seq, word_intervals = fill_small_gaps( 85 | word_seq, word_intervals, wav_length 86 | ) 87 | ph_seq, ph_intervals = fill_small_gaps(ph_seq, ph_intervals, wav_length) 88 | # add SP 89 | word_seq, word_intervals = add_SP(word_seq, word_intervals, wav_length) 90 | ph_seq, ph_intervals = add_SP(ph_seq, ph_intervals, wav_length) 91 | 92 | res.append( 93 | [ 94 | wav_path, 95 | wav_length, 96 | confidence, 97 | ph_seq, 98 | ph_intervals, 99 | word_seq, 100 | word_intervals, 101 | ] 102 | ) 103 | except Exception as e: 104 | error_log.append([wav_path, e]) 105 | return res, error_log 106 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | import warnings 4 | from typing import Dict 5 | 6 | import click 7 | import tqdm 8 | from textgrid import PointTier 9 | 10 | from modules.utils import label 11 | from modules.utils.metrics import ( 12 | BoundaryEditRatio, 13 | IntersectionOverUnion, 14 | Metric, 15 | VlabelerEditRatio, 16 | ) 17 | 18 | 19 | def remove_ignored_phonemes(ignored_phonemes_list: str, point_tier: PointTier): 20 | res_tier = PointTier(name=point_tier.name) 21 | if point_tier[0].mark not in ignored_phonemes_list: 22 | res_tier.addPoint(point_tier[0]) 23 | for i in range(len(point_tier) - 1): 24 | if ( 25 | point_tier[i].mark in ignored_phonemes_list 26 | and point_tier[i + 1].mark in ignored_phonemes_list 27 | ): 28 | continue 29 | 30 | res_tier.addPoint(point_tier[i + 1]) 31 | 32 | return res_tier 33 | 34 | 35 | @click.command( 36 | help="Calculate metrics between the FA predictions and the targets (ground truth)." 
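    # Example invocation (paths are illustrative); both directories must contain
    # .TextGrid files with matching names and relative paths:
    #   python evaluate.py path/to/predictions path/to/targets -r --ignore AP,SP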
37 | ) 38 | @click.argument( 39 | "pred", 40 | type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True), 41 | metavar="PRED_DIR", 42 | ) 43 | @click.argument( 44 | "target", 45 | type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True), 46 | metavar="TARGET_DIR", 47 | ) 48 | @click.option( 49 | "--recursive", 50 | "-r", 51 | is_flag=True, 52 | help="Compare files in subdirectories recursively", 53 | ) 54 | @click.option( 55 | "--strict", "-s", is_flag=True, help="Raise errors on mismatching phone sequences" 56 | ) 57 | @click.option( 58 | "--ignore", 59 | type=str, 60 | default="", # AP,SP,,,,pau,cl 61 | help="Ignored phone marks, split by commas", 62 | show_default=True, 63 | ) 64 | def main(pred: str, target: str, recursive: bool, strict: bool, ignore: str): 65 | pred_dir = pathlib.Path(pred) 66 | target_dir = pathlib.Path(target) 67 | if recursive: 68 | iterable = pred_dir.rglob("*.TextGrid") 69 | else: 70 | iterable = pred_dir.glob("*.TextGrid") 71 | ignored = ignore.split(",") 72 | metrics: Dict[str, Metric] = { 73 | "BoundaryEditRatio": BoundaryEditRatio(), 74 | "VlabelerEditRatio10ms": VlabelerEditRatio(move_tolerance=0.01), 75 | "VlabelerEditRatio20ms": VlabelerEditRatio(move_tolerance=0.02), 76 | "VlabelerEditRatio50ms": VlabelerEditRatio(move_tolerance=0.05), 77 | "IntersectionOverUnion": IntersectionOverUnion(), 78 | } 79 | 80 | cnt = 0 81 | for pred_file in tqdm.tqdm(iterable): 82 | target_file = target_dir / pred_file.relative_to(pred_dir) 83 | if not target_file.exists(): 84 | warnings.warn( 85 | f'The prediction file "{pred_file}" has no matching target file, ' 86 | f'which should be "{target_file}".', 87 | category=UserWarning, 88 | ) 89 | warnings.filterwarnings("default") 90 | continue 91 | 92 | pred_tier = label.textgrid_from_file(pred_file)[-1] 93 | target_tier = label.textgrid_from_file(target_file)[-1] 94 | pred_tier = remove_ignored_phonemes(ignored, pred_tier) 95 | target_tier = remove_ignored_phonemes(ignored, target_tier) 96 | 97 | for metric in metrics.values(): 98 | try: 99 | metric.update(pred_tier, target_tier) 100 | except AssertionError as e: 101 | if not strict: 102 | warnings.warn( 103 | f"Failed to evaluate metric {metric.__class__.__name__} for file {pred_file}: {e}", 104 | category=UserWarning, 105 | ) 106 | warnings.filterwarnings("default") 107 | continue 108 | else: 109 | raise e 110 | 111 | cnt += 1 112 | 113 | if cnt == 0: 114 | raise RuntimeError( 115 | "Unable to compare any files in the given directories. " 116 | "Matching files should have same names and same relative paths, " 117 | "containing the same phone sequences except for spaces." 
118 | ) 119 | result = {key: metric.compute() for key, metric in metrics.items()} 120 | print(json.dumps(result, indent=4, ensure_ascii=False)) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /modules/layer/backbone/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from modules.layer.block.resnet_block import ResidualBasicBlock 5 | from modules.layer.scaling.base import BaseDowmSampling, BaseUpSampling 6 | from modules.layer.scaling.stride_conv import DownSampling, UpSampling 7 | 8 | 9 | class UNetBackbone(nn.Module): 10 | def __init__( 11 | self, 12 | input_dims, 13 | output_dims, 14 | hidden_dims, 15 | block, 16 | down_sampling, 17 | up_sampling, 18 | down_sampling_factor=2, 19 | down_sampling_times=5, 20 | channels_scaleup_factor=2, 21 | **kwargs 22 | ): 23 | """_summary_ 24 | 25 | Args: 26 | input_dims (int): 27 | output_dims (int): 28 | hidden_dims (int): 29 | block (nn.Module): shape: (B, T, C) -> shape: (B, T, C) 30 | down_sampling (nn.Module): shape: (B, T, C) -> shape: (B, T/down_sampling_factor, C*2) 31 | up_sampling (nn.Module): shape: (B, T, C) -> shape: (B, T*down_sampling_factor, C/2) 32 | """ 33 | super(UNetBackbone, self).__init__() 34 | assert issubclass(block, nn.Module) 35 | assert issubclass(down_sampling, BaseDowmSampling) 36 | assert issubclass(up_sampling, BaseUpSampling) 37 | 38 | self.input_dims = input_dims 39 | self.output_dims = output_dims 40 | self.hidden_dims = hidden_dims 41 | 42 | self.divisible_factor = down_sampling_factor**down_sampling_times 43 | 44 | self.encoders = nn.ModuleList() 45 | self.encoders.append(block(input_dims, hidden_dims, **kwargs)) 46 | for i in range(down_sampling_times - 1): 47 | i += 1 48 | self.encoders.append( 49 | nn.Sequential( 50 | down_sampling( 51 | int(channels_scaleup_factor ** (i - 1)) * hidden_dims, 52 | int(channels_scaleup_factor**i) * hidden_dims, 53 | down_sampling_factor, 54 | ), 55 | block( 56 | int(channels_scaleup_factor**i) * hidden_dims, 57 | int(channels_scaleup_factor**i) * hidden_dims, 58 | **kwargs 59 | ), 60 | ) 61 | ) 62 | 63 | self.bottle_neck = nn.Sequential( 64 | down_sampling( 65 | int(channels_scaleup_factor ** (down_sampling_times - 1)) * hidden_dims, 66 | int(channels_scaleup_factor**down_sampling_times) * hidden_dims, 67 | down_sampling_factor, 68 | ), 69 | block( 70 | int(channels_scaleup_factor**down_sampling_times) * hidden_dims, 71 | int(channels_scaleup_factor**down_sampling_times) * hidden_dims, 72 | **kwargs 73 | ), 74 | up_sampling( 75 | int(channels_scaleup_factor**down_sampling_times) * hidden_dims, 76 | int(channels_scaleup_factor ** (down_sampling_times - 1)) * hidden_dims, 77 | down_sampling_factor, 78 | ), 79 | ) 80 | 81 | self.decoders = nn.ModuleList() 82 | for i in range(down_sampling_times - 1): 83 | i += 1 84 | self.decoders.append( 85 | nn.Sequential( 86 | block( 87 | int(channels_scaleup_factor ** (down_sampling_times - i)) 88 | * hidden_dims, 89 | int(channels_scaleup_factor ** (down_sampling_times - i)) 90 | * hidden_dims, 91 | **kwargs 92 | ), 93 | up_sampling( 94 | int(channels_scaleup_factor ** (down_sampling_times - i)) 95 | * hidden_dims, 96 | int(channels_scaleup_factor ** (down_sampling_times - i - 1)) 97 | * hidden_dims, 98 | down_sampling_factor, 99 | ), 100 | ) 101 | ) 102 | self.decoders.append(block(hidden_dims, output_dims, **kwargs)) 103 | 104 | def forward(self, x): 105 | T = 
x.shape[1] 106 | padding_len = T % self.divisible_factor 107 | if padding_len != 0: 108 | x = nn.functional.pad(x, (0, 0, 0, self.divisible_factor - padding_len)) 109 | 110 | h = [x] 111 | for encoder in self.encoders: 112 | h.append(encoder(h[-1])) 113 | 114 | h_ = [self.bottle_neck(h[-1])] 115 | for i, decoder in enumerate(self.decoders): 116 | h_.append(decoder(h_[-1] + h[-1 - i])) 117 | 118 | out = h_[-1] 119 | out = out[:, :T, :] 120 | 121 | return out 122 | 123 | 124 | if __name__ == "__main__": 125 | # pass 126 | model = UNetBackbone(1, 2, 64, ResidualBasicBlock, DownSampling, UpSampling) 127 | print(model) 128 | x = torch.randn(16, 320, 1) 129 | out = model(x) 130 | print(x.shape, out.shape) 131 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import click 5 | import lightning as pl 6 | import torch 7 | import yaml 8 | from torch.utils.data import DataLoader 9 | 10 | from dataset import MixedDataset, WeightedBinningAudioBatchSampler, collate_fn 11 | from modules.task.forced_alignment import LitForcedAlignmentTask 12 | 13 | 14 | @click.command() 15 | @click.option( 16 | "--config_path", 17 | "-c", 18 | type=str, 19 | default="configs/train_config.yaml", 20 | show_default=True, 21 | help="training config path", 22 | ) 23 | @click.option( 24 | "--data_folder", 25 | "-d", 26 | type=str, 27 | default="data", 28 | show_default=True, 29 | help="data folder path", 30 | ) 31 | @click.option( 32 | "--pretrained_model_path", 33 | "-p", 34 | type=str, 35 | default=None, 36 | show_default=True, 37 | help="pretrained model path. if None, training from scratch", 38 | ) 39 | @click.option( 40 | "--resume", 41 | "-r", 42 | is_flag=True, 43 | default=False, 44 | show_default=True, 45 | help="resume training from checkpoint", 46 | ) 47 | def main(config_path: str, data_folder: str, pretrained_model_path, resume): 48 | data_folder = pathlib.Path(data_folder) 49 | os.environ[ 50 | "TORCH_CUDNN_V8_API_ENABLED" 51 | ] = "1" # Prevent unacceptable slowdowns when using 16 precision 52 | 53 | with open(config_path, "r") as f: 54 | config = yaml.safe_load(f) 55 | 56 | with open(data_folder / "binary" / "vocab.yaml") as f: 57 | vocab = yaml.safe_load(f) 58 | vocab_text = yaml.safe_dump(vocab) 59 | 60 | with open(data_folder / "binary" / "global_config.yaml") as f: 61 | config_global = yaml.safe_load(f) 62 | config.update(config_global) 63 | 64 | torch.set_float32_matmul_precision(config["float32_matmul_precision"]) 65 | pl.seed_everything(config["random_seed"], workers=True) 66 | 67 | # define dataset 68 | num_workers = config['dataloader_workers'] 69 | train_dataset = MixedDataset( 70 | config["data_augmentation_size"], data_folder / "binary", prefix="train" 71 | ) 72 | train_sampler = WeightedBinningAudioBatchSampler( 73 | train_dataset.get_label_types(), 74 | train_dataset.get_wav_lengths(), 75 | config["oversampling_weights"], 76 | config["batch_max_length"] / (2 if config["data_augmentation_size"] > 0 else 1), 77 | config["binning_length"], 78 | config["drop_last"], 79 | ) 80 | train_dataloader = DataLoader( 81 | dataset=train_dataset, 82 | batch_sampler=train_sampler, 83 | collate_fn=collate_fn, 84 | num_workers=num_workers, 85 | persistent_workers=num_workers > 0, 86 | pin_memory=True, 87 | prefetch_factor=(2 if num_workers > 0 else None), 88 | ) 89 | 90 | valid_dataset = MixedDataset(0, data_folder / "binary", prefix="valid") 91 | 
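    # Validation uses a plain DataLoader: batch_size=1, no weighted binning sampler,
    # so clips are evaluated one at a time in a fixed order.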
valid_dataloader = DataLoader( 92 | dataset=valid_dataset, 93 | batch_size=1, 94 | shuffle=False, 95 | collate_fn=collate_fn, 96 | num_workers=num_workers, 97 | persistent_workers=num_workers > 0, 98 | ) 99 | 100 | # model 101 | lightning_alignment_model = LitForcedAlignmentTask( 102 | vocab_text, 103 | config["model"], 104 | config["melspec_config"], 105 | config["optimizer_config"], 106 | config["loss_config"], 107 | config["data_augmentation_size"] > 0, 108 | ) 109 | 110 | # trainer 111 | trainer = pl.Trainer( 112 | accelerator=config["accelerator"], 113 | devices=config["devices"], 114 | precision=config["precision"], 115 | gradient_clip_val=config["gradient_clip_val"], 116 | gradient_clip_algorithm=config["gradient_clip_algorithm"], 117 | default_root_dir=str(pathlib.Path("ckpt") / config["model_name"]), 118 | val_check_interval=config["val_check_interval"], 119 | check_val_every_n_epoch=None, 120 | max_epochs=-1, 121 | max_steps=config["optimizer_config"]["total_steps"], 122 | ) 123 | 124 | ckpt_path = None 125 | if pretrained_model_path is not None: 126 | # use pretrained model TODO: load pretrained model 127 | pretrained = LitForcedAlignmentTask.load_from_checkpoint(pretrained_model_path) 128 | lightning_alignment_model.load_pretrained(pretrained) 129 | elif resume: 130 | # resume training state 131 | ckpt_path_list = (pathlib.Path("ckpt") / config["model_name"]).rglob("*.ckpt") 132 | ckpt_path_list = sorted( 133 | ckpt_path_list, key=lambda x: int(x.stem.split("step=")[-1]), reverse=True 134 | ) 135 | ckpt_path = str(ckpt_path_list[0]) if len(ckpt_path_list) > 0 else None 136 | 137 | # start training 138 | trainer.fit( 139 | model=lightning_alignment_model, 140 | train_dataloaders=train_dataloader, 141 | val_dataloaders=valid_dataloader, 142 | ckpt_path=ckpt_path, 143 | ) 144 | 145 | # Discard the optimizer and save 146 | trainer.save_checkpoint( 147 | str(pathlib.Path("ckpt") / config["model_name"]) + ".ckpt", weights_only=True 148 | ) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /modules/utils/label.py: -------------------------------------------------------------------------------- 1 | # Conversion between various label formats (.csv, .lab, .textgrid, etc.) and a unified format. 2 | # TextGrid and PointTier are used as the unified format for efficient calculation. 3 | # Point.time indicates the start time of the phoneme, consistent with Vlabeler's behavior. 
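# Example of the unified representation (as built by durations_to_tier below):
#   marks = ["SP", "a", "SP"], durations = [0.10, 0.25, 0.05]
#   -> PointTier points: (0.00, "SP"), (0.10, "a"), (0.35, "SP"), (0.40, "")
#   Each point marks the onset of a phoneme; a final empty-mark point closes the tier.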
4 | 5 | from typing import List, Tuple, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | import textgrid as tg 10 | 11 | 12 | def durations_to_tier( 13 | marks: List, 14 | durarions: Union[List, np.ndarray], 15 | name="phones", 16 | start_time=0.0, 17 | ) -> tg.PointTier: 18 | assert len(marks) == len(durarions) 19 | 20 | durarions = np.insert(durarions, 0, start_time) 21 | times = np.cumsum(durarions) 22 | marks.append("") 23 | 24 | tier = tg.PointTier(name=name) 25 | for time, mark in zip(times, marks): 26 | tier.add(time, mark) 27 | 28 | return tier 29 | 30 | 31 | def interval_tier_to_point_tier(tier: tg.IntervalTier) -> tg.PointTier: 32 | point_tier = tg.PointTier(name=tier.name) 33 | point_tier.add(0.0, "") 34 | for interval in tier: 35 | if point_tier[-1].mark == "" and point_tier[-1].time == interval.minTime: 36 | point_tier[-1].mark = interval.mark 37 | else: 38 | point_tier.add(interval.minTime, interval.mark) 39 | point_tier.add(interval.maxTime, "") 40 | 41 | return point_tier 42 | 43 | 44 | def point_tier_to_interval_tier(tier: tg.PointTier) -> tg.IntervalTier: 45 | interval_tier = tg.IntervalTier(name=tier.name) 46 | for idx in range(len(tier) - 1): 47 | interval_tier.add(tier[idx].time, tier[idx + 1].time, tier[idx].mark) 48 | return interval_tier 49 | 50 | 51 | def tier_from_htk(lab_path: str, tier_name="phones") -> tg.PointTier: 52 | """Read a htk label file (nnsvs format) and return a PointTier object.""" 53 | tier = tg.IntervalTier(name=tier_name) 54 | 55 | with open(lab_path, "r", encoding="utf-8") as f: 56 | for line in f: 57 | start, end, mark = line.strip().split() 58 | tier.add(int(start) / 1e7, int(end) / 1e7, mark) 59 | 60 | return interval_tier_to_point_tier(tier) 61 | 62 | 63 | def textgrid_from_file(textgrid_path: str) -> tg.TextGrid: 64 | """Read a TextGrid file and return a TextGrid object.""" 65 | textgrid = tg.TextGrid() 66 | textgrid.read(textgrid_path, encoding="utf-8") 67 | for idx, tier in enumerate(textgrid): 68 | if isinstance(tier, tg.IntervalTier): 69 | textgrid.tiers[idx] = interval_tier_to_point_tier(tier) 70 | 71 | return textgrid 72 | 73 | 74 | def textgrids_from_csv(csv_path: str) -> List[Tuple[str, tg.TextGrid]]: 75 | """Read a CSV file and return a list of (filename, TextGrid) tuples.""" 76 | textgrids = [] 77 | 78 | df = pd.read_csv(csv_path) 79 | df = df.loc[:, ["name", "ph_seq", "ph_dur"]] 80 | 81 | for _, row in df.iterrows(): 82 | textgrid = tg.TextGrid() 83 | tier = durations_to_tier( 84 | row["ph_seq"].split(), list(map(float, row["ph_dur"].split())) 85 | ) 86 | textgrid.append(tier) 87 | 88 | textgrids.append((row["name"], textgrid)) 89 | 90 | return textgrids 91 | 92 | 93 | def save_tier_to_htk(tier: tg.PointTier, lab_path: str) -> None: 94 | """Save a PointTier object to a htk label file.""" 95 | with open(lab_path, "w", encoding="utf-8") as f: 96 | for i in range(len(tier) - 1): 97 | f.write( 98 | "{:.0f} {:.0f} {}\n".format( 99 | tier[i].time * 1e7, tier[i + 1].time * 1e7, tier[i].mark 100 | ) 101 | ) 102 | 103 | 104 | def save_textgrid(path: str, textgrid: tg.TextGrid) -> None: 105 | """Save a TextGrid object to a TextGrid file.""" 106 | for i in range(len(textgrid)): 107 | if textgrid[i].maxTime is None: 108 | textgrid[i].maxTime = textgrid[i][-1].time 109 | if isinstance(textgrid[i], tg.PointTier): 110 | textgrid.tiers[i] = point_tier_to_interval_tier(textgrid[i]) 111 | textgrid.write(path) 112 | 113 | 114 | def save_textgrids_to_csv( 115 | path: str, 116 | textgrids: List[Tuple[str, tg.TextGrid]], 117 | precision=6, 
118 | ) -> None: 119 | """Save a list of (filename, TextGrid) tuples to a CSV file.""" 120 | rows = [] 121 | for name, textgrid in textgrids: 122 | tier = textgrid[-1] 123 | ph_seq = " ".join( 124 | ["" if point.mark == "" else point.mark for point in tier[:-1]] 125 | ) 126 | ph_dur = " ".join( 127 | [ 128 | "{:.{}n}".format(ed.time - st.time, precision) 129 | for st, ed in zip(tier[:-1], tier[1:]) 130 | ] 131 | ) 132 | rows.append([name, ph_seq, ph_dur]) 133 | 134 | df = pd.DataFrame(rows, columns=["name", "ph_seq", "ph_dur"]) 135 | df.to_csv(path, index=False, encoding="utf-8") 136 | 137 | 138 | if __name__ == "__main__": 139 | textgrid = textgrid_from_file("test/label/tg.TextGrid") 140 | save_textgrid("test/label/tg_out.TextGrid", textgrid) 141 | # # Save the TextGrid object to a htk label file 142 | # # save_htk(textgrid, "example_out.lab") 143 | # # Convert a TextGrid file to a TextGrid object 144 | # textgrid = from_textgrid("example.TextGrid") 145 | # # Save the TextGrid object to a TextGrid file 146 | # textgrid.write("example_out.TextGrid") 147 | -------------------------------------------------------------------------------- /modules/layer/block/conformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from .attention import MultiHeadSelfAttention 5 | from .func_module import FuncModule 6 | from .residual import Residual 7 | from modules.layer.activation.GLU import GLU 8 | 9 | 10 | class ConformerBlock(nn.Module): 11 | def __init__( 12 | self, 13 | input_dims=128, 14 | output_dims=128, 15 | hidden_dims=64, 16 | kernel_size=3, 17 | dropout=0.1, 18 | num_heads=8, 19 | max_seq_len=3200, 20 | mask="none", 21 | ): 22 | super(ConformerBlock, self).__init__() 23 | self.feed_forward_1 = nn.Sequential( 24 | nn.LayerNorm(input_dims), # nn.LayerNorm(input_dims), 25 | nn.Linear(input_dims, hidden_dims), 26 | nn.Hardswish(), 27 | nn.Dropout(dropout), 28 | nn.Linear(hidden_dims, hidden_dims), 29 | nn.Dropout(dropout), 30 | ) 31 | self.multi_head_self_attention = nn.Sequential( 32 | nn.LayerNorm(hidden_dims), 33 | MultiHeadSelfAttention(hidden_dims, num_heads=num_heads, mask=mask, max_seq_len=max_seq_len), 34 | ) 35 | self.convolution = nn.Sequential( 36 | nn.LayerNorm(hidden_dims), 37 | FuncModule(lambda x: rearrange(x, "b t c -> b c t")), 38 | nn.Conv1d( 39 | hidden_dims, 40 | hidden_dims, 41 | kernel_size, 42 | padding=kernel_size // 2, 43 | groups=hidden_dims, 44 | ), 45 | FuncModule(lambda x: rearrange(x, "b c t -> b t c")), 46 | GLU(hidden_dims, hidden_dims), 47 | nn.Linear(hidden_dims, hidden_dims), 48 | # nn.Conv1d(hidden_dims, hidden_dims, 1, 1, 0, 1, 1), 49 | nn.LayerNorm(hidden_dims), 50 | nn.Hardswish(), 51 | FuncModule(lambda x: rearrange(x, "b t c -> b c t")), 52 | nn.Conv1d( 53 | hidden_dims, 54 | hidden_dims, 55 | kernel_size, 56 | padding=kernel_size // 2, 57 | groups=hidden_dims, 58 | ), 59 | FuncModule(lambda x: rearrange(x, "b c t -> b t c")), 60 | nn.Dropout(dropout), 61 | ) 62 | self.feed_forward_2 = nn.Sequential( 63 | nn.LayerNorm(hidden_dims), 64 | nn.Linear(hidden_dims, hidden_dims), 65 | nn.Hardswish(), 66 | nn.Dropout(dropout), 67 | nn.Linear(hidden_dims, output_dims), 68 | nn.Dropout(dropout), 69 | ) 70 | self.norm = nn.LayerNorm(output_dims) 71 | 72 | self.residual_i_h = Residual(input_dims, hidden_dims) 73 | self.residual_h_h = Residual(hidden_dims, hidden_dims) 74 | self.residual_h_o = Residual(hidden_dims, output_dims) 75 | 76 | def forward(self, 
x): 77 | x = self.residual_i_h(x, (1 / 2) * self.feed_forward_1(x)) 78 | # Multi-head self-attention 79 | x = self.residual_h_h(x, self.multi_head_self_attention(x)) 80 | # Convolution 81 | x = self.residual_h_h(x, self.convolution(x)) 82 | # Feed Forward 83 | x = self.residual_h_o(x, (1 / 2) * self.feed_forward_2(x)) 84 | # Norm 85 | x = self.norm(x) 86 | return x 87 | 88 | 89 | class ForwardBackwardConformerBlock(nn.Module): 90 | def __init__( 91 | self, 92 | input_dims=128, 93 | output_dims=128, 94 | hidden_dims=128, 95 | kernel_size=3, 96 | dropout=0.1, 97 | num_heads=8, 98 | max_seq_len=3200, 99 | ): 100 | super(ForwardBackwardConformerBlock, self).__init__() 101 | self.forward_block = ConformerBlock(input_dims, 102 | hidden_dims, 103 | hidden_dims, 104 | kernel_size, 105 | dropout, 106 | num_heads, 107 | max_seq_len, 108 | mask='upper', 109 | ) 110 | self.backward_block = ConformerBlock(hidden_dims, 111 | output_dims, 112 | hidden_dims, 113 | kernel_size, 114 | dropout, 115 | num_heads, 116 | max_seq_len, 117 | mask='lower', 118 | ) 119 | 120 | def forward(self, x): 121 | x = self.forward_block(x) 122 | x = self.backward_block(x) 123 | return x 124 | 125 | 126 | if __name__ == "__main__": 127 | # test 128 | # bs, l, dims = 16, 320, 128 129 | # input_tensor = torch.randn(bs, l, dims) 130 | # model = ForwardBackwardConformerBlock(dims, 64) 131 | # y = model(input_tensor) 132 | # print(input_tensor.shape, y.shape) 133 | pass 134 | -------------------------------------------------------------------------------- /README_zh.MD: -------------------------------------------------------------------------------- 1 | # SOFA: Singing-Oriented Forced Aligner 2 | 3 | [English](README.MD) | 简体中文 4 | 5 | ![example](example.png) 6 | 7 | # 介绍 8 | 9 | SOFA(Singing-Oriented Forced Aligner)是一个专为歌声设计的强制对齐器,同时也兼容非歌声的对齐。 10 | 11 | 在歌声数据上,相比于MFA(Montreal Forced Aligner),SOFA具有以下优点: 12 | 13 | * 容易安装 14 | * 效果更好 15 | * 推理速度更快 16 | 17 | # 使用方法 18 | 19 | ## 环境配置 20 | 21 | 1. 使用`git clone`​​下载本仓库的代码 22 | 2. 安装conda 23 | 3. 创建conda环境,python版本的要求为`3.8` 24 | ```bash 25 | conda create -n SOFA python=3.8 -y 26 | conda activate SOFA 27 | ``` 28 | 4. 去[pytorch官网](https://pytorch.org/get-started/locally/)安装torch 29 | 5. (可选,用于提高wav读取速度)去[pytorch官网](https://pytorch.org/get-started/locally/)安装torchaudio 30 | 6. 安装其他python库 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## 推理 36 | 37 | 1. 下载模型文件。你可以在本仓库discussion的[模型分享板块](https://github.com/qiuqiao/SOFA/discussions/categories/pretrained-model-sharing)中找到训练好的模型,文件后缀为`.ckpt`。 38 | 2. 把字典文件放入`/dictionary`​​文件夹中。默认字典为`opencpop-extension.txt`​​ 39 | 3. 准备需要强制对齐的数据,放入一个文件夹中(默认放在`/segments`​​文件夹),格式如下 40 | ```text 41 | - segments 42 | - singer1 43 | - segment1.lab 44 | - segment1.wav 45 | - segment2.lab 46 | - segment2.wav 47 | - ... 48 | - singer2 49 | - segment1.lab 50 | - segment1.wav 51 | - ... 52 | ``` 53 | 保证`.wav`文件和对应的`.lab`在同一个文件夹即可。 54 | 55 | 其中,`.lab`文件是同名`.wav`文件的录音文本。录音文本的文件后缀名可以通过`--in_format`参数更改。 56 | 57 | 录音文本经过`g2p`模块转化为音素序列后,输入模型进行对齐。 58 | 59 | 例如,默认情况下使用`DictionaryG2P`模块和`opencpop-extension`词典时,加入录音文本的内容是:`gan shou ting zai wo fa duan de zhi jian`,`g2p`模块根据词典转化为音素序列`g an sh ou t ing z ai w o f a d uan d e zh ir j ian`。其他`g2p`模块的使用方法,参见[g2p模块使用说明](modules/g2p/readme_g2p_zh.md)。 60 | 61 | 4. 
命令行推理 62 | 63 | 使用`python infer.py`进行推理。 64 | 65 | 需要指定的参数: 66 | - `--ckpt`:(必须指定)模型权重路径; 67 | - `--folder`:存放待对齐数据的文件夹​(默认为`segments`); 68 | - `--in_format`: 录音文本的文件后缀名(默认为`lab`); 69 | - `--out_formats`:推理出来的文件的标注格式,可指定多个,使用逗号分隔(默认为`TextGrid,htk,trans`) 70 | - `--save_confidence`:输出置信度。 71 | - `--dictionary`:字典文件​(默认为`dictionary/opencpop-extension.txt`​); 72 | 73 | ```bash 74 | python infer.py -c checkpoint_path -s segments_path -d dictionary_path -of output_format1,output_format2... 75 | ``` 76 | 5. 获取最终标注 77 | 78 | 最终的标注保存在文件夹中,文件夹的名称是你选择的标注格式,这个文件夹的位置和推理所用的wav文件处于同一个文件夹中。 79 | 80 | ### 高级功能 81 | 82 | - 使用自定义的g2p,而不是使用词典 83 | - 参见[g2p模块使用说明](modules/g2p/readme_g2p_zh.md) 84 | - matching模式,推理时指定`-m`即可开启,会在给定的音素序列中找到一个使得概率最大的连续序列片段,而非必须用上所有音素。 85 | 86 | ## onnx推理 87 | 88 | 1. 安装onnxruntime 89 | ```bash 90 | pip install onnxruntime-gpu 91 | ``` 92 | 2. 执行推理 93 | ```bash 94 | python onnx_infer.py \ 95 | --onnx /path/to/model.onnx \ 96 | --folder /input/audio/folder \ 97 | --g2p Dictionary \ 98 | --ap_detector LoudnessSpectralcentroidAPDetector \ 99 | --out_formats textgrid,htk 100 | ``` 101 | 102 | 参数说明: 103 | 104 | - `--onnx` ONNX模型文件路径 105 | - `--folder` 包含.wav音频和.lab标注的输入目录 106 | - `--g2p` 字素到音素转换器(默认:Dictionary) 107 | - `--ap_detector` 静音段检测算法(默认:LoudnessSpectralcentroidAPDetector) 108 | - `--out_formats` 输出格式列表(逗号分隔) 109 | 110 | ## 训练 111 | 112 | 1. 参照上文进行环境配置。建议安装torchaudio以获得更快的binarize速度; 113 | 2. 把训练数据按以下格式放入`data`文件夹: 114 | 115 | ``` 116 | - data 117 | - full_label 118 | - singer1 119 | - wavs 120 | - audio1.wav 121 | - audio2.wav 122 | - ... 123 | - transcriptions.csv 124 | - singer2 125 | - wavs 126 | - ... 127 | - transcriptions.csv 128 | - weak_label 129 | - singer3 130 | - wavs 131 | - ... 132 | - transcriptions.csv 133 | - singer4 134 | - wavs 135 | - ... 136 | - transcriptions.csv 137 | - no_label 138 | - audio1.wav 139 | - audio2.wav 140 | - ... 141 | ``` 142 | 关于`transcriptions.csv`的格式,参见:https://github.com/qiuqiao/SOFA/discussions/5 143 | 144 | 其中: 145 | 146 | `transcriptions.csv`只需要和`wavs`文件夹的相对位置正确即可; 147 | 148 | `weak_label`中的`transcriptions.csv`无需拥有`ph_dur`这一个`column`; 149 | 3. 按需修改`binarize_config.yaml`,然后执行`python binarize.py`; 150 | 4. 在release中下载你需要的预训练模型,并按需修改`train_config.yaml`,然后执行`python train.py -p path_to_your_pretrained_model`; 151 | 5. 训练可视化:`tensorboard --logdir=ckpt/`。 152 | 153 | ## 导出onnx模型 154 | 155 | 1. 安装onnxruntime 156 | ```bash 157 | pip install onnxruntime-gpu 158 | ``` 159 | 2. 导出onnx模型 160 | ```bash 161 | python export_onnx.py --ckpt_path /path/to/checkpoint.ckpt --onnx_path /output/model.onnx 162 | ``` 163 | 164 | ## 评估(适用于模型开发者) 165 | 166 | 可通过在预测(强制对齐)标注与目标(人工)标注之间计算特定的客观评价指标(尤其是在k折交叉验证中)来评估模型性能。 167 | 168 | 一些有用的指标包括: 169 | 170 | - 边界编辑距离:从预测的边界到目标边界的总移动距离。 171 | - 边界编辑比率:边界编辑距离除以目标的总时长。 172 | - 边界错误率:在给定的容差值下,位置错误的边界数占据目标边界总数的比例。 173 | 174 | 若要在特定的数据集上验证你的模型,请先运行推理以得到所有的预测标注。随后,你需要将预测标注与目标标注放置在不同的文件夹中,相对应的标注文件需要保持相同文件名、相同相对路径,并包含相同的音素序列(空白音素除外)。此脚本当前仅支持TextGrid格式。 175 | 176 | 运行以下命令: 177 | 178 | ```bash 179 | python evaluate.py -r -s 180 | ``` 181 | 182 | 其中 `PRED_DIR` 是包含所有预测标注的目录,`TARGET_DIR` 是包含所有目标标注的目录。 183 | 184 | 可选项: 185 | - `-r`, `--recursive`: 递归对比子文件夹中的文件 186 | - `-s`, `--strict`: 使用严格模式(当音素序列不相同时报错而非跳过) 187 | - `--ignore`: 忽略部分音素记号(默认值:`AP,SP,,,,pau,cl`) 188 | -------------------------------------------------------------------------------- /modules/g2p/readme_g2p.md: -------------------------------------------------------------------------------- 1 | # g2p Module Documentation 2 | 3 | ## 1. 
Introduction 4 | 5 | The g2p module is used for inference. It reads text from `.lab` files and then converts the text into a phoneme 6 | sequence. 7 | 8 | ## 2. Usage 9 | 10 | When using `infer.py` for command-line inference, add the `--g2p` argument to specify the g2p module name (default is 11 | Dictionary, and the G2P suffix in the name can be omitted). 12 | 13 | Some g2p modules require additional parameters. For instance, the Dictionary G2P requires the `--dictionary_path` 14 | parameter to specify the path to the dictionary file. 15 | 16 | Different modules require different parameters; for specifics, please refer to the documentation for each module. 17 | 18 | Example: if you want to use Dictionary G2P, you would run `infer.py` like this: 19 | 20 | ```shell 21 | python infer.py --g2p Dictionary --dictionary_path /path/to/dictionary.txt 22 | ``` 23 | 24 | Or: if you want to use Phoneme G2P, you would run `infer.py` like this: 25 | 26 | ```shell 27 | python infer.py --g2p PhonemeG2P 28 | ``` 29 | 30 | ## 3. Module List 31 | 32 | ### 3.1 Dictionary G2P 33 | 34 | #### 3.1.1 Introduction 35 | 36 | Dictionary G2P is a dictionary-based g2p module. It reads a sequence of words from a `.lab` file and then converts this 37 | sequence of words into a sequence of phonemes. 38 | 39 | Between words, the SP phoneme is automatically inserted, but it is not inserted within words. 40 | 41 | #### 3.1.2 Input Format 42 | 43 | The input `.lab` file only has one line, which contains a sequence of words separated by spaces. 44 | 45 | Example: 46 | 47 | ```text 48 | I am a student 49 | ``` 50 | 51 | #### 3.1.3 Parameters 52 | 53 | | Parameter Name | Type | Default | Description | 54 | |-----------------|--------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 55 | | dictionary_path | string | none | Path to the dictionary file. The format of the dictionary file: each line contains a word entry with a word and one or more phonemes. Words and phonemes are separated by `\t`, and phonemes by spaces. See `dictionary/opencpop-extension.txt` for an example. | 56 | 57 | ### 3.2 Phoneme G2P 58 | 59 | #### 3.2.1 Introduction 60 | 61 | Phoneme G2P is a phoneme-based g2p module. It reads a phoneme sequence directly from the `.lab` file and inserts a `SP` 62 | phoneme between each phoneme. 63 | 64 | #### 3.2.2 Input Format 65 | 66 | The input `.lab` file only has one line, which contains a phoneme sequence separated by spaces, preferably not including 67 | the 'SP' phoneme. 68 | 69 | Example: 70 | 71 | ```text 72 | ay ae m ah s t uw d ah n t 73 | ``` 74 | 75 | #### 3.2.3 Parameters 76 | 77 | None 78 | 79 | ### 3.3 None G2P 80 | 81 | #### 3.3.1 Introduction 82 | 83 | None G2P is a null g2p module. It reads a phoneme sequence directly from the `.lab` file and does not process it. 84 | 85 | It can be regarded as Phoneme G2P without inserting `SP` phonemes. 86 | 87 | #### 3.3.2 Input Format 88 | 89 | The input `.lab` file only has one line, which contains a phoneme sequence separated by spaces. 90 | 91 | Example: 92 | 93 | ```text 94 | SP ay SP ae m SP ah SP s t uw d ah n t SP 95 | ``` 96 | 97 | #### 3.3.3 Parameters 98 | 99 | None 100 | 101 | ## 4. 
Custom g2p Module 102 | 103 | If you wish to create a custom g2p module, you need to inherit from the `BaseG2P` class in `base_g2p.py` and implement 104 | the `__init__` and `_g2p` methods. 105 | 106 | The `__init__` method is for initializing the module, and the `_g2p` method is for converting text into a phoneme 107 | sequence, word sequence, and the mapping between the two sequences. 108 | 109 | ### 4.1 `__init__` Method 110 | 111 | The `__init__` method takes `**kwargs` as its parameters, which are passed from the command-line arguments `**g2p_args` 112 | in `infer.py`. 113 | 114 | If you want to use additional parameters, you can add them in `infer.py` and then receive these parameters in 115 | the `__init__` method. 116 | 117 | Example: If you want to use an additional parameter `--my_param`, you would add it in infer.py like this: 118 | 119 | ```python 120 | ... 121 | 122 | 123 | @click.option('--my_param', type=str, default=None, help='My parameter for my g2p module') 124 | def main(ckpt, folder, g2p, match_mode, **kwargs): 125 | ... 126 | ``` 127 | 128 | Then receive this parameter in the `__init__` method: 129 | 130 | ```python 131 | class MyG2P(BaseG2P): 132 | def __init__(self, **kwargs): 133 | super().__init__(**kwargs) 134 | self.my_param = kwargs['my_param'] 135 | ... 136 | ``` 137 | **Note: The g2p module and the AP_detector module share the kwargs parameters.** 138 | 139 | ### 4.2 `_g2p` Method 140 | 141 | The `_g2p` method takes `text` as a parameter, which is a string representing the text read from a .lab file. 142 | 143 | The return value of the `_g2p` method is a tuple containing three elements: 144 | 145 | - `ph_seq`: A list of phonemes, with SP as the silence phoneme. The first and last phonemes should be `SP`, and there should not be more than two consecutive `SP`s at any position. 146 | - `word_seq`: A list of words. 147 | - `ph_idx_to_word_idx`: `ph_idx_to_word_idx[i] = j` means that the ith phoneme belongs to the jth word. If 148 | `ph_idx_to_word_idx[i] = -1`, then the ith phoneme is a silence phoneme. 
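
Putting the two methods together, the following sketch shows one possible complete custom module. The `word_to_ph` dictionary is hypothetical and would need to be supplied through `kwargs`; only the `BaseG2P` interface described above is assumed. With the mapping `{'I': ['ay'], 'am': ['ae', 'm'], 'a': ['ah'], 'student': ['s', 't', 'uw', 'd', 'ah', 'n', 't']}`, it reproduces the example shown below.

```python
from modules.g2p.base_g2p import BaseG2P


class MyG2P(BaseG2P):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Hypothetical mapping from words to phoneme lists, e.g. {'am': ['ae', 'm'], ...}
        self.word_to_ph = kwargs['word_to_ph']

    def _g2p(self, text):
        ph_seq = ['SP']
        word_seq = []
        ph_idx_to_word_idx = [-1]
        for word in text.strip().split():
            word_idx = len(word_seq)
            word_seq.append(word)
            phones = self.word_to_ph[word]
            ph_seq.extend(phones)
            ph_idx_to_word_idx.extend([word_idx] * len(phones))
            # a single SP between words and after the last word
            ph_seq.append('SP')
            ph_idx_to_word_idx.append(-1)
        return ph_seq, word_seq, ph_idx_to_word_idx
```
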
149 | 150 | Example: 151 | 152 | ```python 153 | text = 'I am a student' 154 | ph_seq = ['SP', 'ay', 'SP', 'ae', 'm', 'SP', 'ah', 'SP', 's', 't', 'uw', 'd', 'ah', 'n', 't', 'SP'] 155 | word_seq = ['I', 'am', 'a', 'student'] 156 | ph_idx_to_word_idx = [-1, 0, -1, 1, 1, -1, 2, -1, 3, 3, 3, 3, 3, 3, 3, -1] 157 | ``` 158 | 159 | ### 4.3 Usage 160 | 161 | When in use, you can run `infer.py` like this: 162 | 163 | ```shell 164 | python infer.py --g2p My --my_param my_value 165 | ``` 166 | -------------------------------------------------------------------------------- /modules/AP_detector/loudnesss_pectralcentroid_detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from modules.AP_detector.base_detector import BaseAPDetector 6 | from modules.utils.load_wav import load_wav 7 | 8 | 9 | class LoudnessSpectralcentroidAPDetector(BaseAPDetector): 10 | def __init__(self, **kwargs): 11 | self.spectral_centroid_threshold = 40 12 | self.spl_threshold = 20 13 | 14 | self.device = "cpu" if not torch.cuda.is_available() else "cuda" 15 | 16 | self.sample_rate = 44100 17 | self.hop_length = 512 18 | self.n_fft = 2048 19 | self.win_length = 1024 20 | self.hann_window = torch.hann_window(self.win_length).to(self.device) 21 | 22 | self.conv = nn.Conv1d( 23 | 1, 1, self.hop_length, self.hop_length, self.hop_length // 2, bias=False 24 | ).to(self.device) 25 | self.conv.requires_grad_(False) 26 | self.conv.weight.data.fill_(1.0 / self.hop_length) 27 | 28 | def _get_spl(self, wav): 29 | out = self.conv(wav.pow(2).unsqueeze(0).unsqueeze(0)) 30 | out = 20 * torch.log10(out.sqrt() / 2 * 10e5) 31 | return out.squeeze(0).squeeze(0) 32 | 33 | def _get_spectral_centroid(self, wav): 34 | wav = nn.functional.pad(wav, (self.n_fft // 2, (self.n_fft + 1) // 2)) 35 | fft = torch.stft( 36 | wav, 37 | n_fft=self.n_fft, 38 | hop_length=self.hop_length, 39 | win_length=self.win_length, 40 | window=self.hann_window, 41 | center=False, 42 | return_complex=True, 43 | ) 44 | magnitude = fft.abs().pow(2) 45 | magnitude = magnitude / magnitude.sum(dim=-2, keepdim=True) 46 | 47 | spectral_centroid = torch.sum( 48 | (1 + torch.arange(0, self.n_fft // 2 + 1)) 49 | .float() 50 | .unsqueeze(-1) 51 | .to(self.device) 52 | * magnitude, 53 | dim=0, 54 | ) 55 | 56 | return spectral_centroid 57 | 58 | def _get_diff_intervals(self, intervals_a, intervals_b): 59 | # get complement of interval_b 60 | if intervals_a.shape[0] <= 0: 61 | return np.array([]) 62 | if intervals_b.shape[0] <= 0: 63 | return intervals_a 64 | intervals_b = np.stack( 65 | [ 66 | np.concatenate([[0.0], intervals_b[:, 1]]), 67 | np.concatenate([intervals_b[:, 0], intervals_a[[-1], [-1]]]), 68 | ], 69 | axis=-1, 70 | ) 71 | intervals_b = intervals_b[(intervals_b[:, 0] < intervals_b[:, 1]), :] 72 | 73 | idx_a = 0 74 | idx_b = 0 75 | intersection_intervals = [] 76 | while idx_a < intervals_a.shape[0] and idx_b < intervals_b.shape[0]: 77 | start_a, end_a = intervals_a[idx_a] 78 | start_b, end_b = intervals_b[idx_b] 79 | if end_a <= start_b: 80 | idx_a += 1 81 | continue 82 | if end_b <= start_a: 83 | idx_b += 1 84 | continue 85 | intersection_intervals.append([max(start_a, start_b), min(end_a, end_b)]) 86 | if end_a < end_b: 87 | idx_a += 1 88 | else: 89 | idx_b += 1 90 | 91 | return np.array(intersection_intervals) 92 | 93 | def _process_one( 94 | self, 95 | wav_path, 96 | wav_length, 97 | confidence, 98 | ph_seq, 99 | ph_intervals, 100 | word_seq, 101 | word_intervals, 
102 | ): 103 | # input: 104 | # wav_path: pathlib.Path 105 | # ph_seq: list of phonemes, SP is the silence phoneme. 106 | # ph_intervals: np.ndarray of shape (n_ph, 2), ph_intervals[i] = [start, end] 107 | # means the i-th phoneme starts at start and ends at end. 108 | # word_seq: list of words. 109 | # word_intervals: np.ndarray of shape (n_word, 2), word_intervals[i] = [start, end] 110 | 111 | # output: same as the input. 112 | wav = load_wav(wav_path, self.device, self.sample_rate) 113 | wav = 0.01 * (wav - wav.mean()) / wav.std() 114 | 115 | # ap intervals 116 | spl = self._get_spl(wav) 117 | spectral_centroid = self._get_spectral_centroid(wav) 118 | ap_frame = (spl > self.spl_threshold) & ( 119 | spectral_centroid > self.spectral_centroid_threshold 120 | ) 121 | ap_frame_diff = torch.diff( 122 | torch.cat( 123 | [ 124 | torch.tensor([0], device=self.device), 125 | ap_frame, 126 | torch.tensor([0], device=self.device), 127 | ] 128 | ), 129 | dim=0, 130 | ) 131 | ap_start_idx = torch.where(ap_frame_diff == 1)[0] 132 | ap_end_idx = torch.where(ap_frame_diff == -1)[0] 133 | ap_intervals = torch.stack([ap_start_idx, ap_end_idx], dim=-1) * ( 134 | self.hop_length / self.sample_rate 135 | ) 136 | ap_intervals = self._get_diff_intervals( 137 | ap_intervals.cpu().numpy(), word_intervals 138 | ) 139 | if ap_intervals.shape[0] <= 0: 140 | return ( 141 | wav_path, 142 | wav_length, 143 | confidence, 144 | ph_seq, 145 | ph_intervals, 146 | word_seq, 147 | word_intervals, 148 | ) 149 | ap_intervals = ap_intervals[(ap_intervals[:, 1] - ap_intervals[:, 0]) > 0.1, :] 150 | 151 | # merge 152 | ap_tuple_list = [ 153 | ("AP", ap_start, ap_end) 154 | for (ap_start, ap_end) in zip(ap_intervals[:, 0], ap_intervals[:, 1]) 155 | ] 156 | word_tuple_list = [ 157 | (word, word_start, word_end) 158 | for (word, (word_start, word_end)) in zip(word_seq, word_intervals) 159 | ] 160 | word_tuple_list.extend(ap_tuple_list) 161 | ph_tuple_list = [ 162 | (ph, ph_start, ph_end) 163 | for (ph, (ph_start, ph_end)) in zip(ph_seq, ph_intervals) 164 | ] 165 | ph_tuple_list.extend(ap_tuple_list) 166 | 167 | # sort 168 | word_tuple_list.sort(key=lambda x: x[1]) 169 | ph_tuple_list.sort(key=lambda x: x[1]) 170 | 171 | ph_seq = [ph for (ph, _, _) in ph_tuple_list] 172 | ph_intervals = np.array([(start, end) for (_, start, end) in ph_tuple_list]) 173 | 174 | word_seq = [word for (word, _, _) in word_tuple_list] 175 | word_intervals = np.array([(start, end) for (_, start, end) in word_tuple_list]) 176 | 177 | return ( 178 | wav_path, 179 | wav_length, 180 | confidence, 181 | ph_seq, 182 | ph_intervals, 183 | word_seq, 184 | word_intervals, 185 | ) 186 | -------------------------------------------------------------------------------- /modules/rmvpe/deepunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .constants import N_MELS 4 | 5 | 6 | class ConvBlockRes(nn.Module): 7 | def __init__(self, in_channels, out_channels, momentum=0.01): 8 | super(ConvBlockRes, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_channels=in_channels, 11 | out_channels=out_channels, 12 | kernel_size=(3, 3), 13 | stride=(1, 1), 14 | padding=(1, 1), 15 | bias=False), 16 | nn.BatchNorm2d(out_channels, momentum=momentum), 17 | nn.ReLU(), 18 | 19 | nn.Conv2d(in_channels=out_channels, 20 | out_channels=out_channels, 21 | kernel_size=(3, 3), 22 | stride=(1, 1), 23 | padding=(1, 1), 24 | bias=False), 25 | nn.BatchNorm2d(out_channels, momentum=momentum), 26 | 
nn.ReLU(), 27 | ) 28 | if in_channels != out_channels: 29 | self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) 30 | self.is_shortcut = True 31 | else: 32 | self.is_shortcut = False 33 | 34 | def forward(self, x): 35 | if self.is_shortcut: 36 | return self.conv(x) + self.shortcut(x) 37 | else: 38 | return self.conv(x) + x 39 | 40 | 41 | class ResEncoderBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): 43 | super(ResEncoderBlock, self).__init__() 44 | self.n_blocks = n_blocks 45 | self.conv = nn.ModuleList() 46 | self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) 47 | for i in range(n_blocks - 1): 48 | self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) 49 | self.kernel_size = kernel_size 50 | if self.kernel_size is not None: 51 | self.pool = nn.AvgPool2d(kernel_size=kernel_size) 52 | 53 | def forward(self, x): 54 | for i in range(self.n_blocks): 55 | x = self.conv[i](x) 56 | if self.kernel_size is not None: 57 | return x, self.pool(x) 58 | else: 59 | return x 60 | 61 | 62 | class ResDecoderBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): 64 | super(ResDecoderBlock, self).__init__() 65 | out_padding = (0, 1) if stride == (1, 2) else (1, 1) 66 | self.n_blocks = n_blocks 67 | self.conv1 = nn.Sequential( 68 | nn.ConvTranspose2d(in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=(3, 3), 71 | stride=stride, 72 | padding=(1, 1), 73 | output_padding=out_padding, 74 | bias=False), 75 | nn.BatchNorm2d(out_channels, momentum=momentum), 76 | nn.ReLU(), 77 | ) 78 | self.conv2 = nn.ModuleList() 79 | self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) 80 | for i in range(n_blocks-1): 81 | self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) 82 | 83 | def forward(self, x, concat_tensor): 84 | x = self.conv1(x) 85 | x = torch.cat((x, concat_tensor), dim=1) 86 | for i in range(self.n_blocks): 87 | x = self.conv2[i](x) 88 | return x 89 | 90 | 91 | class Encoder(nn.Module): 92 | def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): 93 | super(Encoder, self).__init__() 94 | self.n_encoders = n_encoders 95 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) 96 | self.layers = nn.ModuleList() 97 | self.latent_channels = [] 98 | for i in range(self.n_encoders): 99 | self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) 100 | self.latent_channels.append([out_channels, in_size]) 101 | in_channels = out_channels 102 | out_channels *= 2 103 | in_size //= 2 104 | self.out_size = in_size 105 | self.out_channel = out_channels 106 | 107 | def forward(self, x): 108 | concat_tensors = [] 109 | x = self.bn(x) 110 | for i in range(self.n_encoders): 111 | _, x = self.layers[i](x) 112 | concat_tensors.append(_) 113 | return x, concat_tensors 114 | 115 | 116 | class Intermediate(nn.Module): 117 | def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): 118 | super(Intermediate, self).__init__() 119 | self.n_inters = n_inters 120 | self.layers = nn.ModuleList() 121 | self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) 122 | for i in range(self.n_inters-1): 123 | self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) 124 | 125 | def forward(self, x): 126 | for i in range(self.n_inters): 127 | x = 
self.layers[i](x) 128 | return x 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): 133 | super(Decoder, self).__init__() 134 | self.layers = nn.ModuleList() 135 | self.n_decoders = n_decoders 136 | for i in range(self.n_decoders): 137 | out_channels = in_channels // 2 138 | self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)) 139 | in_channels = out_channels 140 | 141 | def forward(self, x, concat_tensors): 142 | for i in range(self.n_decoders): 143 | x = self.layers[i](x, concat_tensors[-1-i]) 144 | return x 145 | 146 | 147 | class TimbreFilter(nn.Module): 148 | def __init__(self, latent_rep_channels): 149 | super(TimbreFilter, self).__init__() 150 | self.layers = nn.ModuleList() 151 | for latent_rep in latent_rep_channels: 152 | self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) 153 | 154 | def forward(self, x_tensors): 155 | out_tensors = [] 156 | for i, layer in enumerate(self.layers): 157 | out_tensors.append(layer(x_tensors[i])) 158 | return out_tensors 159 | 160 | 161 | class DeepUnet0(nn.Module): 162 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): 163 | super(DeepUnet0, self).__init__() 164 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) 165 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) 166 | self.tf = TimbreFilter(self.encoder.latent_channels) 167 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) 168 | 169 | def forward(self, x): 170 | x, concat_tensors = self.encoder(x) 171 | x = self.intermediate(x) 172 | x = self.decoder(x, concat_tensors) 173 | return x 174 | -------------------------------------------------------------------------------- /modules/layer/block/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange, repeat 5 | 6 | 7 | class MultiHeadSelfAttention(nn.Module): 8 | def __init__( 9 | self, 10 | model_dim: int, 11 | num_heads: int, 12 | max_seq_len: int = 3200, 13 | dropout: float = 0.0, 14 | mask: str = "none", 15 | init_type: str = "kaiming_uniform", 16 | ): 17 | super().__init__() 18 | 19 | assert num_heads > 0, "num_heads must be positive" 20 | assert model_dim % num_heads == 0, "model_dim must be divisible by num_heads" 21 | assert ( 22 | model_dim // num_heads 23 | ) % 2 == 0, "model_dim // num_heads must be divisible by 2" 24 | assert max_seq_len > 0, "max_seq_len must be positive" 25 | assert 0.0 <= dropout < 1.0, "dropout must be in range [0.0, 1.0)" 26 | assert mask in [ 27 | "none", 28 | "upper", 29 | "lower", 30 | ], "mask must be one of [none, upper, lower]" 31 | assert init_type in [ 32 | "xavier_uniform", 33 | "xavier_normal", 34 | "kaiming_uniform", 35 | "kaiming_normal", 36 | ], "init_type must be one of [xavier_uniform, xavier_normal, kaiming_uniform, kaiming_normal]" 37 | 38 | self.max_seq_len = max_seq_len 39 | self.model_dim = model_dim 40 | self.d_k = model_dim // num_heads 41 | self.num_heads = num_heads 42 | self.mask = mask 43 | 44 | self.wq = nn.Linear(model_dim, model_dim) 45 | self.wk = nn.Linear(model_dim, model_dim) 46 | self.wv = nn.Linear(model_dim, model_dim) 47 | self.linear = nn.Linear(model_dim, model_dim) 48 | self.dropout = nn.Dropout(dropout) 49 | 
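        # Rotary position embedding (RoPE): precompute per-position cos/sin rotation
        # matrices up to max_seq_len and register them as buffers so they move with the module.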
self.register_buffer( 50 | "theta_base", torch.Tensor([10000.0]) 51 | ) # 常量tensor需要用register_buffer注册,否则.to(device)不起作用 52 | cos, sin = self.precompute_rotation_matrix(self.max_seq_len, self.theta_base) 53 | self.register_buffer("rotation_matrix_cos", cos) 54 | self.register_buffer("rotation_matrix_sin", sin) 55 | 56 | self.init_type = init_type 57 | self.apply(self.init_weights) 58 | 59 | def init_weights(self, module): 60 | init_type = self.init_type 61 | if isinstance(module, nn.Linear): 62 | if init_type == "xavier_uniform": 63 | nn.init.xavier_uniform_(module.weight) 64 | elif init_type == "xavier_normal": 65 | nn.init.xavier_normal_(module.weight) 66 | elif init_type == "kaiming_uniform": 67 | nn.init.kaiming_uniform_(module.weight) 68 | elif init_type == "kaiming_normal": 69 | nn.init.kaiming_normal_(module.weight) 70 | nn.init.constant_(module.bias, 0) 71 | 72 | def precompute_rotation_matrix(self, seq_len: int, theta_base): 73 | dim = self.d_k 74 | power = torch.arange(dim // 2) * (-2) / dim 75 | theta_vector = torch.pow(theta_base, power) 76 | position_vector = torch.arange(seq_len) 77 | rotation_angle_matrix = torch.outer(position_vector, theta_vector) 78 | rotation_angle_matrix = repeat( 79 | rotation_angle_matrix, "l d -> l (d repeat)", repeat=2 80 | ) 81 | rotation_matrix_cos = torch.cos(rotation_angle_matrix).unsqueeze(0).unsqueeze(0) 82 | rotation_matrix_sin = torch.sin(rotation_angle_matrix).unsqueeze(0).unsqueeze(0) 83 | return rotation_matrix_cos, rotation_matrix_sin 84 | 85 | def apply_rotary_emb(self, xq: torch.Tensor, xk: torch.Tensor): 86 | # xq.shape = [batch_size, num_heads, seq_len, d_k] 87 | def get_sin_weight(q): 88 | q = rearrange(q, "b h t (d1 d2) -> b h t d2 d1", d2=2) 89 | q = q.clone() 90 | q[:, :, :, 1, :] = -1 * q[:, :, :, 1, :] 91 | q = q[:, :, :, [1, 0], :] 92 | q = rearrange(q, "b h t d2 d1 -> b h t (d1 d2)") 93 | return q 94 | 95 | # print(xq.shape, self.rotation_matrix_cos.shape) 96 | xq_ = get_sin_weight(xq) 97 | xk_ = get_sin_weight(xk) 98 | xq_out = ( 99 | xq 100 | * self.rotation_matrix_cos[ 101 | :, 102 | :, 103 | xq.shape[2], 104 | :, 105 | ] 106 | + xq_ 107 | * self.rotation_matrix_sin[ 108 | :, 109 | :, 110 | xq.shape[2], 111 | :, 112 | ] 113 | ) 114 | xk_out = ( 115 | xk 116 | * self.rotation_matrix_cos[ 117 | :, 118 | :, 119 | xk.shape[2], 120 | :, 121 | ] 122 | + xk_ 123 | * self.rotation_matrix_sin[ 124 | :, 125 | :, 126 | xk.shape[2], 127 | :, 128 | ] 129 | ) 130 | 131 | return xq_out, xk_out 132 | 133 | def _update_RoPE(self, seq_len): 134 | cos, sin = self.precompute_rotation_matrix(seq_len, self.theta_base) 135 | self.cos = cos 136 | self.sin = sin 137 | self.max_seq_len = seq_len 138 | 139 | def forward(self, x: torch.Tensor): # , lengths=None 140 | # input: Tensor[batch_size seq_len, hidden_dims], lengths: Tensor[batch_size] 141 | # output: Tensor[batch_size seq_len, hidden_dims] 142 | batch_size, seq_len, _ = x.shape 143 | 144 | if self.max_seq_len < seq_len: 145 | self._update_RoPE(seq_len) 146 | 147 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 148 | xq = xq.view(batch_size, self.num_heads, seq_len, self.d_k) 149 | xk = xk.view(batch_size, self.num_heads, seq_len, self.d_k) 150 | xv = xv.view(batch_size, self.num_heads, seq_len, self.d_k) 151 | 152 | xq, xk = self.apply_rotary_emb(xq, xk) 153 | 154 | scores = torch.matmul(xq, xk.transpose(-2, -1)) / np.sqrt( 155 | self.d_k 156 | ) # size: (batch_size, num_heads, seq_len, seq_len) 157 | if self.mask == "upper": 158 | mask = torch.triu(torch.ones_like(scores[0, 0]).float(), 
diagonal=1) 159 | scores.masked_fill_(mask == 1, -1e9) 160 | elif self.mask == "lower": 161 | mask = torch.tril(torch.ones_like(scores[0, 0]).float(), diagonal=-1) 162 | scores.masked_fill_(mask == 1, -1e9) 163 | # if lengths is not None: 164 | # mask = torch.arange(seq_len).to(x.device)[None, :] >= lengths[:, None] 165 | # scores.masked_fill_(mask[:, None, None, :] == 1, -1e9) 166 | # scores.masked_fill_(mask[:, None, :, None] == 1, -1e9) 167 | scores = torch.nn.functional.softmax(scores, dim=-1) 168 | scores = self.dropout(scores) 169 | 170 | output = torch.matmul(scores, xv) # size: (batch_size, num_heads, seq_len, d_k) 171 | output = rearrange(output, "b h t d -> b t (h d)") 172 | 173 | output = self.linear(output) # size: (batch_size, seq_len, model_dim) 174 | return output 175 | 176 | 177 | if __name__ == "__main__": 178 | model = MultiHeadSelfAttention(128, 8) 179 | tensor_x = torch.randn(4, 32, 128) 180 | y = model(tensor_x) 181 | print(y.shape) 182 | -------------------------------------------------------------------------------- /modules/utils/export_tool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import textgrid 4 | 5 | 6 | class Exporter: 7 | def __init__(self, predictions, log): 8 | self.predictions = predictions 9 | self.log = log 10 | 11 | def save_textgrids(self): 12 | print("Saving TextGrids...") 13 | 14 | for ( 15 | wav_path, 16 | wav_length, 17 | confidence, 18 | ph_seq, 19 | ph_intervals, 20 | word_seq, 21 | word_intervals, 22 | ) in self.predictions: 23 | tg = textgrid.TextGrid() 24 | word_tier = textgrid.IntervalTier(name="words") 25 | ph_tier = textgrid.IntervalTier(name="phones") 26 | 27 | for word, (start, end) in zip(word_seq, word_intervals): 28 | word_tier.add(start, end, word) 29 | 30 | for ph, (start, end) in zip(ph_seq, ph_intervals): 31 | ph_tier.add(minTime=float(start), maxTime=end, mark=ph) 32 | 33 | tg.append(word_tier) 34 | tg.append(ph_tier) 35 | 36 | label_path = ( 37 | wav_path.parent / "TextGrid" / wav_path.with_suffix(".TextGrid").name 38 | ) 39 | label_path.parent.mkdir(parents=True, exist_ok=True) 40 | tg.write(label_path) 41 | 42 | def save_htk(self): 43 | print("Saving htk labels...") 44 | 45 | for ( 46 | wav_path, 47 | wav_length, 48 | confidence, 49 | ph_seq, 50 | ph_intervals, 51 | word_seq, 52 | word_intervals, 53 | ) in self.predictions: 54 | label = "" 55 | for ph, (start, end) in zip(ph_seq, ph_intervals): 56 | start_time = int(float(start) * 10000000) 57 | end_time = int(float(end) * 10000000) 58 | label += f"{start_time} {end_time} {ph}\n" 59 | label_path = ( 60 | wav_path.parent / "htk" / "phones" / wav_path.with_suffix(".lab").name 61 | ) 62 | label_path.parent.mkdir(parents=True, exist_ok=True) 63 | with open(label_path, "w", encoding="utf-8") as f: 64 | f.write(label) 65 | f.close() 66 | 67 | label = "" 68 | for word, (start, end) in zip(word_seq, word_intervals): 69 | start_time = int(float(start) * 10000000) 70 | end_time = int(float(end) * 10000000) 71 | label += f"{start_time} {end_time} {word}\n" 72 | label_path = ( 73 | wav_path.parent / "htk" / "words" / wav_path.with_suffix(".lab").name 74 | ) 75 | label_path.parent.mkdir(parents=True, exist_ok=True) 76 | with open(label_path, "w", encoding="utf-8") as f: 77 | f.write(label) 78 | f.close() 79 | 80 | def save_transcriptions(self): 81 | print("Saving transcriptions.csv...") 82 | 83 | folder_to_data = {} 84 | 85 | for ( 86 | wav_path, 87 | wav_length, 88 | confidence, 89 | ph_seq, 90 | 
ph_intervals, 91 | word_seq, 92 | word_intervals, 93 | ) in self.predictions: 94 | folder = wav_path.parent 95 | if folder in folder_to_data: 96 | curr_data = folder_to_data[folder] 97 | else: 98 | curr_data = { 99 | "name": [], 100 | "word_seq": [], 101 | "word_dur": [], 102 | "ph_seq": [], 103 | "ph_dur": [], 104 | } 105 | 106 | name = wav_path.with_suffix("").name 107 | word_seq = " ".join(word_seq) 108 | ph_seq = " ".join(ph_seq) 109 | word_dur = [] 110 | ph_dur = [] 111 | 112 | last_word_end = 0 113 | for start, end in word_intervals: 114 | dur = np.round(end - last_word_end, 5) 115 | word_dur.append(dur) 116 | last_word_end += dur 117 | 118 | last_ph_end = 0 119 | for start, end in ph_intervals: 120 | dur = np.round(end - last_ph_end, 5) 121 | ph_dur.append(dur) 122 | last_ph_end += dur 123 | 124 | word_dur = " ".join([str(i) for i in word_dur]) 125 | ph_dur = " ".join([str(i) for i in ph_dur]) 126 | 127 | curr_data["name"].append(name) 128 | curr_data["word_seq"].append(word_seq) 129 | curr_data["word_dur"].append(word_dur) 130 | curr_data["ph_seq"].append(ph_seq) 131 | curr_data["ph_dur"].append(ph_dur) 132 | 133 | folder_to_data[folder] = curr_data 134 | 135 | for folder, data in folder_to_data.items(): 136 | df = pd.DataFrame(data) 137 | path = folder / "transcriptions" 138 | if not path.exists(): 139 | path.mkdir(parents=True, exist_ok=True) 140 | df.to_csv(path / "transcriptions.csv", index=False) 141 | 142 | def save_confidence_fn(self): 143 | print("saving confidence...") 144 | 145 | folder_to_data = {} 146 | 147 | for ( 148 | wav_path, 149 | wav_length, 150 | confidence, 151 | ph_seq, 152 | ph_intervals, 153 | word_seq, 154 | word_intervals, 155 | ) in self.predictions: 156 | folder = wav_path.parent 157 | if folder in folder_to_data: 158 | curr_data = folder_to_data[folder] 159 | else: 160 | curr_data = { 161 | "name": [], 162 | "confidence": [], 163 | } 164 | 165 | name = wav_path.with_suffix("").name 166 | curr_data["name"].append(name) 167 | curr_data["confidence"].append(confidence) 168 | 169 | folder_to_data[folder] = curr_data 170 | 171 | for folder, data in folder_to_data.items(): 172 | df = pd.DataFrame(data) 173 | path = folder / "confidence" 174 | if not path.exists(): 175 | path.mkdir(parents=True, exist_ok=True) 176 | df.to_csv(path / "confidence.csv", index=False) 177 | 178 | def export(self, out_formats): 179 | if "textgrid" in out_formats or "praat" in out_formats: 180 | self.save_textgrids() 181 | if ( 182 | "htk" in out_formats 183 | or "lab" in out_formats 184 | or "nnsvs" in out_formats 185 | or "sinsy" in out_formats 186 | ): 187 | self.save_htk() 188 | if ( 189 | "trans" in out_formats 190 | or "transcription" in out_formats 191 | or "transcriptions" in out_formats 192 | or "transcriptions.csv" in out_formats 193 | or "diffsinger" in out_formats 194 | ): 195 | self.save_transcriptions() 196 | 197 | if "confidence" in out_formats: 198 | self.save_confidence_fn() 199 | 200 | if self.log: 201 | print("error:") 202 | for line in self.log: 203 | print(line) 204 | -------------------------------------------------------------------------------- /modules/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from functools import lru_cache 2 | 3 | import textgrid as tg 4 | 5 | 6 | class Metric: 7 | """ 8 | A torchmetrics.Metric-like class with similar methods but lowered computing overhead. 
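    Subclasses implement update(pred, target), compute() and reset().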
9 | """ 10 | 11 | def update(self, pred, target): 12 | raise NotImplementedError() 13 | 14 | def compute(self): 15 | raise NotImplementedError() 16 | 17 | def reset(self): 18 | raise NotImplementedError() 19 | 20 | 21 | class VlabelerEditsCount(Metric): 22 | """ 23 | 在vlabeler中,将pred编辑为target所需要的最少次数 24 | The edit distance between pred and target in vlabeler. 25 | """ 26 | 27 | def __init__(self, move_tolerance=0.02): 28 | self.move_tolerance = move_tolerance 29 | self.counts = 0 30 | 31 | def update(self, pred: tg.PointTier, target: tg.PointTier): 32 | # 获得从pred编辑到target所需要的最少次数 33 | # 注意这是一个略微简化的模型,不一定和vlabeler完全一致。 34 | # 编辑操作包括: 35 | # 插入边界 36 | # 删除边界及其前一个音素(和vlabeler的操作对应) 37 | # 移动边界(如果边界距离大于move_tolerance s,就要移动) 38 | # 音素替换 39 | 40 | # vlabeler中,对TextGrid有要求,如果要满足要求的话, 41 | # PointTier中的第一个边界位置不需要编辑,最后一个音素必定为空 42 | assert len(pred) >= 2 and len(target) >= 2 43 | assert pred[0].time == target[0].time 44 | # assert target[-1].time == pred[-1].time 45 | assert pred[-1].mark == "" and target[-1].mark == "" 46 | 47 | @lru_cache(maxsize=None) 48 | def dfs(i, j): 49 | # 返回将pred[:i]更改为target[:j]所需的编辑次数 50 | 51 | # 边界条件 52 | if i == 0: 53 | # 一直插入边界直到j个边界,每次插入一个边界还要修改一个音素,所以是2j 54 | return j * 2 55 | if j == 0: 56 | # 删除边界的同时会删除前方的音素,删除i次 57 | return i 58 | 59 | # case1: 插入边界,pred[:i+1]只能覆盖到target[:j],所以要插入一个边界,和target[j+1]对应 60 | # 如果和上一个音素相同,那么就无需修改音素 61 | insert = dfs(i, j - 1) + 1 62 | if j == 1 or target[j - 1].mark != target[j - 2].mark: 63 | insert += 1 64 | # case2: 删除边界,pred[:i]已经能覆盖到target[:j+1],pred[i+1]完全无用,可以删了 65 | # 这里跟vlabeler的操作是一致的,vlabeler删除一个音素会同时删除前面的边界,这里是删除边界会同时删除后一个音素 66 | # 因为被删除了,所以无需修改音素 67 | delete = dfs(i - 1, j) + 1 68 | # case3:移动(也可以不移动)边界 69 | # 如果边界距离大于move_tolerance s,就要移动,否则不需要 70 | # 如果音素不一致就要修改,否则不需要 71 | move = dfs(i - 1, j - 1) 72 | if abs(pred[i - 1].time - target[j - 1].time) > self.move_tolerance: 73 | move += 1 74 | if pred[i - 1].mark != target[j - 1].mark: 75 | move += 1 76 | 77 | return min(insert, delete, move) 78 | 79 | self.counts += dfs(len(pred), len(target)) 80 | 81 | def compute(self): 82 | return self.counts 83 | 84 | def reset(self): 85 | self.counts = 0 86 | 87 | 88 | class VlabelerEditRatio(Metric): 89 | """ 90 | 编辑距离除以target的总长度 91 | Edit distance divided by total length of target. 92 | """ 93 | 94 | def __init__(self, move_tolerance=0.02): 95 | self.edit_distance = VlabelerEditsCount(move_tolerance) 96 | self.total = 0 97 | 98 | def update(self, pred: tg.PointTier, target: tg.PointTier): 99 | self.edit_distance.update(pred, target) 100 | # PointTier中的第一个边界位置不需要编辑,最后一个音素必定为空 101 | self.total += 2 * len(target) - 2 102 | 103 | def compute(self): 104 | if self.total == 0: 105 | return None 106 | return round(self.edit_distance.compute() / self.total, 6) 107 | 108 | def reset(self): 109 | self.edit_distance.reset() 110 | self.total = 0 111 | 112 | 113 | class IntersectionOverUnion(Metric): 114 | """ 115 | 所有音素的交并比 116 | Intersection over union of all phonemes. 
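    compute() returns intersection / union per phoneme mark; pass a single phoneme
    or a list of phonemes to restrict the output.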
117 | """ 118 | 119 | def __init__(self): 120 | self.intersection = {} 121 | self.sum = {} 122 | 123 | def update(self, pred: tg.PointTier, target: tg.PointTier): 124 | len_pred = len(pred) - 1 125 | len_target = len(target) - 1 126 | for i in range(len_pred): 127 | if pred[i].mark not in self.sum: 128 | self.sum[pred[i].mark] = pred[i + 1].time - pred[i].time 129 | self.intersection[pred[i].mark] = 0 130 | else: 131 | self.sum[pred[i].mark] += pred[i + 1].time - pred[i].time 132 | for j in range(len_target): 133 | if target[j].mark not in self.sum: 134 | self.sum[target[j].mark] = target[j + 1].time - target[j].time 135 | self.intersection[target[j].mark] = 0 136 | else: 137 | self.sum[target[j].mark] += target[j + 1].time - target[j].time 138 | 139 | i = 0 140 | j = 0 141 | while i < len_pred and j < len_target: 142 | if pred[i].mark == target[j].mark: 143 | intersection = min(pred[i + 1].time, target[j + 1].time) - max( 144 | pred[i].time, target[j].time 145 | ) 146 | self.intersection[pred[i].mark] += ( 147 | intersection if intersection > 0 else 0 148 | ) 149 | 150 | if pred[i + 1].time < target[j + 1].time: 151 | i += 1 152 | elif pred[i + 1].time > target[j + 1].time: 153 | j += 1 154 | else: 155 | i += 1 156 | j += 1 157 | 158 | def compute(self, phonemes=None): 159 | if phonemes is None: 160 | return { 161 | k: round(v / (self.sum[k] - v), 6) for k, v in self.intersection.items() 162 | } 163 | 164 | if isinstance(phonemes, str): 165 | if phonemes in self.intersection: 166 | return round( 167 | self.intersection[phonemes] 168 | / (self.sum[phonemes] - self.intersection[phonemes]), 169 | 6, 170 | ) 171 | else: 172 | return None 173 | else: 174 | return { 175 | ph: ( 176 | round( 177 | self.intersection[ph] / (self.sum[ph] - self.intersection[ph]), 178 | 6, 179 | ) 180 | if ph in self.intersection 181 | else None 182 | ) 183 | for ph in phonemes 184 | } 185 | 186 | def reset(self): 187 | self.intersection = {} 188 | self.sum = {} 189 | 190 | 191 | class BoundaryEditDistance(Metric): 192 | """ 193 | The total moving distance from the predicted boundaries to the target boundaries. 194 | """ 195 | 196 | def __init__(self): 197 | self.distance = 0.0 198 | 199 | def update(self, pred: tg.PointTier, target: tg.PointTier): 200 | # 确保音素完全一致 201 | assert len(pred) == len(target) 202 | for i in range(len(pred)): 203 | assert pred[i].mark == target[i].mark 204 | 205 | # 计算边界距离 206 | for pred_point, target_point in zip(pred, target): 207 | self.distance += abs(pred_point.time - target_point.time) 208 | 209 | def compute(self): 210 | return round(self.distance, 6) 211 | 212 | def reset(self): 213 | self.distance = 0.0 214 | 215 | 216 | class BoundaryEditRatio(Metric): 217 | """ 218 | The boundary edit distance divided by the total duration of target intervals. 
219 | """ 220 | 221 | def __init__(self): 222 | self.distance_metric = BoundaryEditDistance() 223 | self.duration = 0.0 224 | 225 | def update(self, pred: tg.PointTier, target: tg.PointTier): 226 | self.distance_metric.update(pred, target) 227 | self.duration += target[-1].time - target[0].time 228 | 229 | def compute(self): 230 | if self.duration == 0.0: 231 | return None 232 | return round(self.distance_metric.compute() / self.duration, 6) 233 | -------------------------------------------------------------------------------- /export_onnx.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import pathlib 3 | 4 | import click 5 | import onnx 6 | import onnxsim 7 | 8 | from typing import Any 9 | 10 | import lightning as pl 11 | import yaml 12 | from einops import repeat 13 | from modules.layer.backbone.unet import UNetBackbone 14 | from modules.layer.block.resnet_block import ResidualBasicBlock 15 | from modules.layer.scaling.stride_conv import DownSampling, UpSampling 16 | 17 | import torch 18 | import torch.nn as nn 19 | from librosa.filters import mel 20 | 21 | 22 | class MelSpectrogram_ONNX(nn.Module): 23 | def __init__( 24 | self, 25 | n_mel_channels, 26 | sampling_rate, 27 | win_length, 28 | hop_length, 29 | n_fft=None, 30 | mel_fmin=0, 31 | mel_fmax=None, 32 | clamp=1e-5 33 | ): 34 | super().__init__() 35 | n_fft = win_length if n_fft is None else n_fft 36 | mel_basis = mel( 37 | sr=sampling_rate, 38 | n_fft=n_fft, 39 | n_mels=n_mel_channels, 40 | fmin=mel_fmin, 41 | fmax=mel_fmax, 42 | htk=True) 43 | mel_basis = torch.from_numpy(mel_basis).float() 44 | self.register_buffer("mel_basis", mel_basis) 45 | self.n_fft = win_length if n_fft is None else n_fft 46 | self.hop_length = hop_length 47 | self.win_length = win_length 48 | self.sampling_rate = sampling_rate 49 | self.n_mel_channels = n_mel_channels 50 | self.clamp = clamp 51 | 52 | def forward(self, audio, center=True): 53 | fft = torch.stft( 54 | audio, 55 | n_fft=self.n_fft, 56 | hop_length=self.hop_length, 57 | win_length=self.win_length, 58 | window=torch.hann_window(self.win_length, device=audio.device), 59 | center=center, 60 | return_complex=False 61 | ) 62 | magnitude = torch.sqrt(torch.sum(fft ** 2, dim=-1)) 63 | mel_output = torch.matmul(self.mel_basis, magnitude) 64 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 65 | return log_mel_spec 66 | 67 | 68 | class LitForcedAlignmentOnnx(pl.LightningModule): 69 | def __init__( 70 | self, 71 | vocab_text, 72 | model_config, 73 | melspec_config 74 | ): 75 | super().__init__() 76 | self.save_hyperparameters() 77 | 78 | self.vocab = yaml.safe_load(vocab_text) 79 | self.melspec_config = melspec_config 80 | 81 | self.backbone = UNetBackbone( 82 | melspec_config["n_mels"], 83 | model_config["hidden_dims"], 84 | model_config["hidden_dims"], 85 | ResidualBasicBlock, 86 | DownSampling, 87 | UpSampling, 88 | down_sampling_factor=model_config["down_sampling_factor"], # 3 89 | down_sampling_times=model_config["down_sampling_times"], # 7 90 | channels_scaleup_factor=model_config["channels_scaleup_factor"], # 1.5 91 | ) 92 | self.head = nn.Linear( 93 | model_config["hidden_dims"], self.vocab[""] + 2 94 | ) 95 | self.mel_extractor = MelSpectrogram_ONNX( 96 | melspec_config["n_mels"], melspec_config["sample_rate"], melspec_config["win_length"], 97 | melspec_config["hop_length"], melspec_config["n_fft"], melspec_config["fmin"], melspec_config["fmax"] 98 | ) 99 | 100 | def forward(self, waveform, num_frames, ph_seq_id) -> 
Any: 101 | melspec = self.mel_extractor(waveform).detach() 102 | melspec = (melspec - melspec.mean()) / melspec.std() 103 | melspec = repeat( 104 | melspec, "B C T -> B C (T N)", N=self.melspec_config["scale_factor"] 105 | ) 106 | 107 | h = self.backbone(melspec.transpose(1, 2)) 108 | logits = self.head(h) 109 | ph_frame_logits = logits[:, :, 2:] 110 | ph_edge_logits = logits[:, :, 0] 111 | ctc_logits = torch.cat([logits[:, :, [1]], logits[:, :, 3:]], dim=-1) 112 | 113 | ph_mask = torch.zeros(self.vocab[""]) 114 | ph_mask[ph_seq_id] = 1 115 | ph_mask[0] = 1 116 | 117 | ph_frame_logits = ph_frame_logits[:, :num_frames, :] 118 | ph_edge_logits = ph_edge_logits[:, :num_frames] 119 | ctc_logits = ctc_logits[:, :num_frames, :] 120 | 121 | ph_mask = ph_mask.to(ph_frame_logits.device).unsqueeze(0).unsqueeze(0).logical_not() * 1e9 122 | ph_frame_pred = torch.nn.functional.softmax(ph_frame_logits.float() - ph_mask.float(), dim=-1).squeeze(0) 123 | ph_prob_log = torch.log_softmax(ph_frame_logits.float() - ph_mask.float(), dim=-1).squeeze(0) 124 | ph_edge_pred = ((torch.nn.functional.sigmoid(ph_edge_logits.float()) - 0.1) / 0.8).clamp(0.0, 1.0) 125 | ph_edge_pred = ph_edge_pred.squeeze(0) 126 | ctc_logits = ctc_logits.float().squeeze(0) # (ctc_logits.squeeze(0) - ph_mask) 127 | 128 | T, vocab_size = ph_frame_pred.shape 129 | 130 | # decode 131 | diff_ph_edge_pred = ph_edge_pred[1:] - ph_edge_pred[:-1] 132 | edge_diff = torch.cat((diff_ph_edge_pred, torch.tensor([0.0], device=ph_edge_pred.device)), dim=0) 133 | edge_prob = (ph_edge_pred + torch.cat( 134 | (torch.tensor([0.0], device=ph_edge_pred.device), ph_edge_pred[:-1]))).clamp(0, 1) 135 | return edge_diff, edge_prob, ph_prob_log, ctc_logits, T 136 | 137 | 138 | @torch.no_grad() 139 | @click.command(help='') 140 | @click.option('--ckpt_path', required=True, metavar='DIR', help='Path to the checkpoint') 141 | @click.option('--onnx_path', required=True, metavar='DIR', help='Path to the onnx') 142 | def export(ckpt_path, onnx_path): 143 | assert ckpt_path is not None, "Checkpoint directory (ckpt_dir) cannot be None" 144 | 145 | assert os.path.exists(ckpt_path), f"Checkpoint path does not exist: {ckpt_path}" 146 | 147 | os.makedirs(pathlib.Path(onnx_path).parent, exist_ok=True) 148 | 149 | output_config = pathlib.Path(onnx_path).with_name('config.yaml') 150 | assert not os.path.exists(onnx_path), f"Error: The file '{onnx_path}' already exists." 151 | assert not output_config.exists(), f"Error: The file '{output_config}' already exists." 
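    # The export below wraps the LitForcedAlignmentOnnx defined above: the resulting ONNX graph
    # takes 'waveform', 'num_frames' and 'ph_seq_id' as inputs and returns 'edge_diff', 'edge_prob',
    # 'ph_prob_log', 'ctc_logits' and 'T', matching the outputs of forward().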
152 | 153 | model = LitForcedAlignmentOnnx.load_from_checkpoint(ckpt_path, strict=False) 154 | 155 | waveform = torch.randn((1, 44100), dtype=torch.float32) 156 | ph_seq_id = torch.zeros((1, 37), dtype=torch.int64) 157 | num_frames = torch.tensor(500, dtype=torch.int64) 158 | 159 | if torch.cuda.is_available(): 160 | model.cuda() 161 | waveform = waveform.cuda() 162 | 163 | with torch.no_grad(): 164 | torch.onnx.export( 165 | model, 166 | (waveform, num_frames, ph_seq_id), 167 | onnx_path, 168 | input_names=['waveform', 'num_frames', 'ph_seq_id'], 169 | output_names=['edge_diff', 'edge_prob', 'ph_prob_log', 'ctc_logits', 'T'], 170 | dynamic_axes={ 171 | 'waveform': {1: 'n_samples'}, 172 | 'ph_seq_id': {1: 'n_samples'}, 173 | 'edge_diff': {1: 'n_samples'}, 174 | 'edge_prob': {1: 'n_samples'}, 175 | 'ph_prob_log': {1: 'n_samples'}, 176 | 'ctc_logits': {1: 'n_samples'} 177 | }, 178 | opset_version=17 179 | ) 180 | onnx_model, check = onnxsim.simplify(onnx_path, include_subgraph=True) 181 | assert check, 'Simplified ONNX model could not be validated' 182 | onnx.save(onnx_model, onnx_path) 183 | print(f'Model saved to: {onnx_path}') 184 | 185 | out_config = {'melspec_config': model.hparams.melspec_config, 186 | 'model_config': model.hparams.model_config, 187 | 'vocab': yaml.safe_load(model.hparams.vocab_text) 188 | } 189 | with open(output_config, 'w') as file: 190 | yaml.dump(out_config, file, default_flow_style=False, allow_unicode=True) 191 | 192 | 193 | if __name__ == '__main__': 194 | export() 195 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # SOFA: Singing-Oriented Forced Aligner 2 | 3 | English | [简体中文](README_zh.MD) 4 | 5 | ![example](example.png) 6 | 7 | # Introduction 8 | 9 | SOFA (Singing-Oriented Forced Aligner) is a forced alignment tool designed specifically for singing voice. 10 | 11 | On singing data, SOFA has the following advantages over MFA (Montreal Forced Aligner): 12 | 13 | * Easier installation 14 | * Better performance 15 | * Faster inference speed 16 | 17 | # How to Use 18 | 19 | ## Environment Setup 20 | 21 | 1. Use `git clone` to download the code from this repository 22 | 2. Install conda 23 | 3. Create a conda environment, requiring Python version `3.8` 24 | ```bash 25 | conda create -n SOFA python=3.8 -y 26 | conda activate SOFA 27 | ``` 28 | 4. Go to the [pytorch official website](https://pytorch.org/get-started/locally/) to install torch 29 | 5. (Optional, to improve wav file reading speed) Go to the [pytorch official website](https://pytorch.org/get-started/locally/) to install torchaudio 30 | 6. Install other Python libraries 31 | ```bash 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | ## Inference 36 | 37 | 1. Download the model file. You can find the trained models in the [pretrained model sharing category of the discussion section](https://github.com/qiuqiao/SOFA/discussions/categories/pretrained-model-sharing), with the file extension `.ckpt`. 38 | 2. Place the dictionary file in the `/dictionary` folder. The default dictionary is `opencpop-extension.txt` 39 | 3. Prepare the data for forced alignment and place it in a folder (by default in the `/segments` folder), with the following format 40 | ```text 41 | - segments 42 | - singer1 43 | - segment1.lab 44 | - segment1.wav 45 | - segment2.lab 46 | - segment2.wav 47 | - ... 48 | - singer2 49 | - segment1.lab 50 | - segment1.wav 51 | - ... 
52 | ``` 53 | Ensure that the `.wav` files and their corresponding `.lab` files are in the same folder. 54 | 55 | The `.lab` file is the transcription for the `.wav` file with the same name. The file extension for the transcription can be changed using the `--in_format` parameter. 56 | 57 | After the transcription is converted into a phoneme sequence by the `g2p` module, it is fed into the model for alignment. 58 | 59 | For example, when using the `DictionaryG2P` module and the `opencpop-extension` dictionary by default, if the content of the transcription is: `gan shou ting zai wo fa duan de zhi jian`, the `g2p` module will convert it based on the dictionary into the phoneme sequence `g an sh ou t ing z ai w o f a duan d e zh ir j ian`. For how to use other `g2p` modules, see [g2p module usage instructions](modules/g2p/readme_g2p.md). 60 | 61 | 4. Command-line inference 62 | 63 | Use `python infer.py` to perform inference. 64 | 65 | Parameters that need to be specified: 66 | - `--ckpt`: (must be specified) The path to the model weights; 67 | - `--folder`: The folder where the data to be aligned is stored (default is `segments`); 68 | - `--in_format`: The file extension of the transcription (default is `lab`); 69 | - `--out_formats`: The annotation format of the inferred files, multiple formats can be specified, separated by commas (default is `TextGrid,htk,trans`). 70 | - `--save_confidence`: Output confidence scores. 71 | - `--dictionary`: The dictionary file (default is `dictionary/opencpop-extension.txt`); 72 | 73 | ```bash 74 | python infer.py --ckpt checkpoint_path --folder segments_path --dictionary dictionary_path --out_formats output_format1,output_format2... 75 | ``` 76 | 5. Retrieve the Final Annotation 77 | 78 | The final annotation is saved in a folder, the name of which is the annotation format you have chosen. This folder is located in the same directory as the wav files used for inference. 79 | 80 | ### Advanced Features 81 | 82 | - Using a custom g2p instead of a dictionary 83 | - See [g2p module instructions](modules/g2p/readme_g2p.md) 84 | - In the matching mode, you can activate it by specifying `-m` during inference. It finds the most probable contiguous sequence segment within the given phoneme sequence, rather than having to use all the phonemes. 85 | 86 | ## ONNX Inference 87 | 88 | 1. Install onnxruntime 89 | ```bash 90 | pip install onnxruntime-gpu 91 | ``` 92 | 2. Execute inference 93 | ```bash 94 | python onnx_infer.py \ 95 | --onnx /path/to/model.onnx \ 96 | --folder /input/audio/folder \ 97 | --g2p Dictionary \ 98 | --ap_detector LoudnessSpectralcentroidAPDetector \ 99 | --out_formats textgrid,htk 100 | ``` 101 | 102 | Parameter descriptions: 103 | 104 | - `--onnx` Path to the ONNX model file 105 | - `--folder` Input directory containing .wav audio and .lab annotations 106 | - `--g2p` Grapheme-to-phoneme converter (default: Dictionary) 107 | - `--ap_detector` Silence segment detection algorithm (default: LoudnessSpectralcentroidAPDetector) 108 | - `--out_formats` List of output formats (comma-separated) 109 | 110 | ## Training 111 | 112 | 1. Follow the steps above for setting up the environment. It is recommended to install torchaudio for faster binarization speed; 113 | 2. Place the training data in the `data` folder in the following format: 114 | 115 | ``` 116 | - data 117 | - full_label 118 | - singer1 119 | - wavs 120 | - audio1.wav 121 | - audio2.wav 122 | - ... 123 | - transcriptions.csv 124 | - singer2 125 | - wavs 126 | - ... 
127 | - transcriptions.csv 128 | - weak_label 129 | - singer3 130 | - wavs 131 | - ... 132 | - transcriptions.csv 133 | - singer4 134 | - wavs 135 | - ... 136 | - transcriptions.csv 137 | - no_label 138 | - audio1.wav 139 | - audio2.wav 140 | - ... 141 | ``` 142 | Regarding the format of `transcriptions.csv`, see: https://github.com/qiuqiao/SOFA/discussions/5 143 | 144 | Where: 145 | 146 | `transcriptions.csv` only needs to have the correct relative path to the `wavs` folder; 147 | 148 | The `transcriptions.csv` in `weak_label` does not need to have a `ph_dur` column; 149 | 3. Modify `binarize_config.yaml` as needed, then execute `python binarize.py`; 150 | 4. Download the pre-trained model you need from releases, modify `train_config.yaml` as needed, then execute `python train.py -p path_to_your_pretrained_model`; 151 | 5. For training visualization: `tensorboard --logdir=ckpt/`. 152 | 153 | ## Export ONNX Model 154 | 155 | 1. Install onnxruntime 156 | ```bash 157 | pip install onnxruntime-gpu 158 | ``` 159 | 2. Export ONNX model 160 | ```bash 161 | python export_onnx.py --ckpt_path /path/to/checkpoint.ckpt --onnx_path /output/model.onnx 162 | ``` 163 | 164 | ## Evaluation (for model developers) 165 | 166 | To measure the performance of a model, it is useful to calculate some objective evaluation metrics between the predictions (force-aligned labels) and the targets (manual labels), especially in a k-fold cross-validation. 167 | 168 | Some useful metrics are: 169 | 170 | - Boundary Edit Distance: the total moving distance from the predicted boundaries to the target boundaries. 171 | - Boundary Edit Ratio: the boundary edit distance divided by the total duration of target intervals. 172 | - Boundary Error Rate: the proportion of misplaced boundaries to all target boundaries under a given tolerance of distance. 173 | 174 | To evaluate your model on a specific dataset, please first run inference to get all predictions. You should put your predictions and targets in different folders, with the same filenames and relative paths, containing the same phone sequences except for spaces. The script only supports the TextGrid format currently. 175 | 176 | Run the following command: 177 | 178 | ```bash 179 | python evaluate.py <PRED_DIR> <TARGET_DIR> -r -s 180 | ``` 181 | 182 | where `PRED_DIR` is a directory containing all predictions and `TARGET_DIR` is a directory containing all targets.
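For reference, the same metrics can also be computed programmatically with the classes in `modules/utils/metrics.py`. The sketch below builds two tiny `textgrid` PointTiers by hand (times in seconds); the helper mirrors the one in the unit tests, and the values are illustrative only, not part of the repository:

```python
import textgrid as tg

from modules.utils.metrics import BoundaryEditRatio, VlabelerEditRatio


def point_tier_from_list(points, name=""):
    # Build a PointTier from (time, mark) pairs, as modules/utils/metrics_test.py does.
    tier = tg.PointTier(name)
    for time, mark in points:
        tier.add(time, mark)
    return tier


pred = point_tier_from_list([(0.0, "a"), (0.49, "b"), (1.0, "")])
target = point_tier_from_list([(0.0, "a"), (0.50, "b"), (1.0, "")])

vlabeler_ratio = VlabelerEditRatio(move_tolerance=0.02)
vlabeler_ratio.update(pred, target)
print(vlabeler_ratio.compute())  # 0.0 -- the 10 ms offset is within the 20 ms tolerance

boundary_ratio = BoundaryEditRatio()
boundary_ratio.update(pred, target)
print(boundary_ratio.compute())  # 0.01 -- 10 ms of boundary movement over a 1.0 s target
```

The command-line options of `evaluate.py` are listed below.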
183 | 184 | Options: 185 | - `-r`, `--recursive`: compare the files in subdirectories recursively 186 | - `-s`, `--strict`: use strict mode (raise errors instead of skipping if the phones are not identical) 187 | - `--ignore`: ignore some phone marks (default: `AP,SP,,,,pau,cl`) 188 | 189 | The script will calculate: 190 | 191 | - The boundary edit ratio 192 | - The boundary error rate, under 10ms, 20ms and 50ms tolerance 193 | -------------------------------------------------------------------------------- /modules/utils/metrics_test.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import textgrid as tg 4 | 5 | from modules.utils.metrics import ( 6 | BoundaryEditDistance, 7 | BoundaryEditRatio, 8 | IntersectionOverUnion, 9 | VlabelerEditRatio, 10 | VlabelerEditsCount, 11 | ) 12 | 13 | 14 | def point_tier_from_list(list: List[Tuple[float, str]], name="") -> tg.PointTier: 15 | tier = tg.PointTier(name) 16 | for time, mark in list: 17 | tier.add(time, mark) 18 | return tier 19 | 20 | 21 | def get_vlabeler_edit_ratio(pred_tier, target_tier): 22 | dist = VlabelerEditsCount(move_tolerance=20) 23 | dist.update(pred_tier, target_tier) 24 | 25 | ratio = VlabelerEditRatio(move_tolerance=20) 26 | ratio.update(pred_tier, target_tier) 27 | return ratio.compute(), dist.compute() 28 | 29 | 30 | class TestVlabelerEditRatio: 31 | # 测试用例 1:完全相同 32 | def test_same(self): 33 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 34 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 35 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 36 | assert edit_ratio == 0.0 37 | assert edit_num == 0 38 | 39 | # 测试用例 2:插入边界 40 | def test_insert(self): 41 | pred_tier = point_tier_from_list([(0, "a"), (100, "")]) 42 | target_tier = point_tier_from_list([(0, "a"), (50, "a"), (100, "")]) 43 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 44 | assert edit_ratio == round(1 / 3, 6) 45 | assert edit_num == 1 46 | 47 | pred_tier = point_tier_from_list([(0, "a"), (100, "")]) 48 | target_tier = point_tier_from_list([(0, "a"), (50, "b"), (100, "")]) 49 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 50 | assert edit_ratio == round(2 / 3, 6) 51 | assert edit_num == 2 52 | 53 | # 测试用例 3:删除边界 54 | def test_delete(self): 55 | pred_tier = point_tier_from_list([(0, "a"), (50, "b"), (100, "")]) 56 | target_tier = point_tier_from_list([(0, "a"), (100, "")]) 57 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 58 | assert edit_ratio == 1.0 59 | assert edit_num == 1, f"{edit_num}!= 1" 60 | 61 | # 测试用例 4:移动边界 62 | def test_move(self): 63 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 64 | target_tier = point_tier_from_list([(0, "a"), (121, "b"), (200, "")]) 65 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 66 | assert edit_ratio == round(1 / 3, 6) 67 | assert edit_num == 1 68 | 69 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 70 | target_tier = point_tier_from_list([(0, "a"), (120, "b"), (200, "")]) 71 | edit_ratio, edit_num = get_vlabeler_edit_ratio(pred_tier, target_tier) 72 | assert edit_ratio == 0.0 73 | assert edit_num == 0 74 | 75 | # 测试用例 5:音素替换 76 | def test_replace(self): 77 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 78 | target_tier = point_tier_from_list([(0, "b"), (100, "c"), (200, "")]) 79 | edit_ratio, edit_num = 
get_vlabeler_edit_ratio(pred_tier, target_tier) 80 | assert edit_ratio == round(2 / 3, 6) 81 | assert edit_num == 2 82 | 83 | 84 | class TestIntersectionOverUnion: 85 | def test_same(self): 86 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 87 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 88 | answer = { 89 | "a": 1.0, 90 | "b": 1.0, 91 | } 92 | 93 | iou = IntersectionOverUnion() 94 | iou.update(pred_tier, target_tier) 95 | pred_answer = iou.compute() 96 | 97 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 98 | 99 | def test_zero(self): 100 | pred_tier = point_tier_from_list([(0, "a"), (50, "b"), (250, "")]) 101 | target_tier = point_tier_from_list([(0, "c"), (150, "d"), (200, "")]) 102 | answer = { 103 | "a": 0.0, 104 | "b": 0.0, 105 | "c": 0.0, 106 | "d": 0.0, 107 | } 108 | 109 | iou = IntersectionOverUnion() 110 | iou.update(pred_tier, target_tier) 111 | pred_answer = iou.compute() 112 | 113 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 114 | 115 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 116 | target_tier = point_tier_from_list([(0, "b"), (100, "a"), (200, "")]) 117 | answer = { 118 | "a": 0.0, 119 | "b": 0.0, 120 | } 121 | 122 | iou.reset() 123 | iou.update(pred_tier, target_tier) 124 | pred_answer = iou.compute() 125 | 126 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 127 | 128 | def test_half(self): 129 | pred_tier = point_tier_from_list([(0, "a"), (50, "b"), (250, "")]) 130 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 131 | answer = { 132 | "a": 0.5, 133 | "b": 0.5, 134 | } 135 | 136 | iou = IntersectionOverUnion() 137 | iou.update(pred_tier, target_tier) 138 | pred_answer = iou.compute() 139 | 140 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 141 | 142 | 143 | class TestBoundaryEditDistance: 144 | def test_same(self): 145 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 146 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 147 | answer = 0.0 148 | 149 | metric = BoundaryEditDistance() 150 | metric.update(pred_tier, target_tier) 151 | pred_answer = metric.compute() 152 | 153 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 154 | 155 | def test_2(self): 156 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 157 | target_tier = point_tier_from_list([(0, "a"), (150, "b"), (200, "")]) 158 | answer = 50.0 159 | 160 | metric = BoundaryEditDistance() 161 | metric.update(pred_tier, target_tier) 162 | pred_answer = metric.compute() 163 | 164 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 165 | 166 | def test_3(self): 167 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 168 | target_tier = point_tier_from_list([(50, "a"), (100, "b"), (150, "")]) 169 | answer = 100.0 170 | 171 | metric = BoundaryEditDistance() 172 | metric.update(pred_tier, target_tier) 173 | pred_answer = metric.compute() 174 | 175 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 176 | 177 | def test_assert(self): 178 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 179 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "c")]) 180 | 181 | metric = BoundaryEditDistance() 182 | try: 183 | metric.update(pred_tier, target_tier) 184 | except AssertionError: 185 | return 186 | assert False, "AssertionError not raised" 187 | 188 | 189 | class TestBoundaryEditRatio: 190 | def test_same(self): 191 | pred_tier = 
point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 192 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 193 | answer = 0.0 194 | 195 | metric = BoundaryEditRatio() 196 | metric.update(pred_tier, target_tier) 197 | pred_answer = metric.compute() 198 | 199 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 200 | 201 | def test_2(self): 202 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 203 | target_tier = point_tier_from_list([(0, "a"), (150, "b"), (200, "")]) 204 | answer = 50.0 / 200.0 205 | 206 | metric = BoundaryEditRatio() 207 | metric.update(pred_tier, target_tier) 208 | pred_answer = metric.compute() 209 | 210 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 211 | 212 | def test_3(self): 213 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 214 | target_tier = point_tier_from_list([(50, "a"), (100, "b"), (150, "")]) 215 | answer = 100.0 / (150.0 - 50.0) 216 | 217 | metric = BoundaryEditRatio() 218 | metric.update(pred_tier, target_tier) 219 | pred_answer = metric.compute() 220 | 221 | assert pred_answer == answer, f"{pred_answer}!= {answer}" 222 | 223 | def test_assert(self): 224 | pred_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "")]) 225 | target_tier = point_tier_from_list([(0, "a"), (100, "b"), (200, "c")]) 226 | 227 | metric = BoundaryEditRatio() 228 | try: 229 | metric.update(pred_tier, target_tier) 230 | except AssertionError: 231 | return 232 | assert False, "AssertionError not raised" 233 | -------------------------------------------------------------------------------- /dictionary/opencpop-extension.txt: -------------------------------------------------------------------------------- 1 | a a 2 | ai ai 3 | an an 4 | ang ang 5 | ao ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | be b e 12 | bei b ei 13 | ben b en 14 | beng b eng 15 | ber b er 16 | bi b i 17 | bia b ia 18 | bian b ian 19 | biang b iang 20 | biao b iao 21 | bie b ie 22 | bin b in 23 | bing b ing 24 | biong b iong 25 | biu b iu 26 | bo b o 27 | bong b ong 28 | bou b ou 29 | bu b u 30 | bua b ua 31 | buai b uai 32 | buan b uan 33 | buang b uang 34 | bui b ui 35 | bun b un 36 | bv b v 37 | bve b ve 38 | ca c a 39 | cai c ai 40 | can c an 41 | cang c ang 42 | cao c ao 43 | ce c e 44 | cei c ei 45 | cen c en 46 | ceng c eng 47 | cer c er 48 | cha ch a 49 | chai ch ai 50 | chan ch an 51 | chang ch ang 52 | chao ch ao 53 | che ch e 54 | chei ch ei 55 | chen ch en 56 | cheng ch eng 57 | cher ch er 58 | chi ch ir 59 | chong ch ong 60 | chou ch ou 61 | chu ch u 62 | chua ch ua 63 | chuai ch uai 64 | chuan ch uan 65 | chuang ch uang 66 | chui ch ui 67 | chun ch un 68 | chuo ch uo 69 | chv ch v 70 | chyi ch i 71 | ci c i0 72 | cong c ong 73 | cou c ou 74 | cu c u 75 | cua c ua 76 | cuai c uai 77 | cuan c uan 78 | cuang c uang 79 | cui c ui 80 | cun c un 81 | cuo c uo 82 | cv c v 83 | cyi c i 84 | da d a 85 | dai d ai 86 | dan d an 87 | dang d ang 88 | dao d ao 89 | de d e 90 | dei d ei 91 | den d en 92 | deng d eng 93 | der d er 94 | di d i 95 | dia d ia 96 | dian d ian 97 | diang d iang 98 | diao d iao 99 | die d ie 100 | din d in 101 | ding d ing 102 | diong d iong 103 | diu d iu 104 | dong d ong 105 | dou d ou 106 | du d u 107 | dua d ua 108 | duai d uai 109 | duan d uan 110 | duang d uang 111 | dui d ui 112 | dun d un 113 | duo d uo 114 | dv d v 115 | dve d ve 116 | e e 117 | ei ei 118 | en en 119 | eng eng 120 | er er 121 | fa f a 122 | fai f ai 123 | fan f an 124 | fang f ang 125 | fao f ao 
126 | fe f e 127 | fei f ei 128 | fen f en 129 | feng f eng 130 | fer f er 131 | fi f i 132 | fia f ia 133 | fian f ian 134 | fiang f iang 135 | fiao f iao 136 | fie f ie 137 | fin f in 138 | fing f ing 139 | fiong f iong 140 | fiu f iu 141 | fo f o 142 | fong f ong 143 | fou f ou 144 | fu f u 145 | fua f ua 146 | fuai f uai 147 | fuan f uan 148 | fuang f uang 149 | fui f ui 150 | fun f un 151 | fv f v 152 | fve f ve 153 | ga g a 154 | gai g ai 155 | gan g an 156 | gang g ang 157 | gao g ao 158 | ge g e 159 | gei g ei 160 | gen g en 161 | geng g eng 162 | ger g er 163 | gi g i 164 | gia g ia 165 | gian g ian 166 | giang g iang 167 | giao g iao 168 | gie g ie 169 | gin g in 170 | ging g ing 171 | giong g iong 172 | giu g iu 173 | gong g ong 174 | gou g ou 175 | gu g u 176 | gua g ua 177 | guai g uai 178 | guan g uan 179 | guang g uang 180 | gui g ui 181 | gun g un 182 | guo g uo 183 | gv g v 184 | gve g ve 185 | ha h a 186 | hai h ai 187 | han h an 188 | hang h ang 189 | hao h ao 190 | he h e 191 | hei h ei 192 | hen h en 193 | heng h eng 194 | her h er 195 | hi h i 196 | hia h ia 197 | hian h ian 198 | hiang h iang 199 | hiao h iao 200 | hie h ie 201 | hin h in 202 | hing h ing 203 | hiong h iong 204 | hiu h iu 205 | hong h ong 206 | hou h ou 207 | hu h u 208 | hua h ua 209 | huai h uai 210 | huan h uan 211 | huang h uang 212 | hui h ui 213 | hun h un 214 | huo h uo 215 | hv h v 216 | hve h ve 217 | ji j i 218 | jia j ia 219 | jian j ian 220 | jiang j iang 221 | jiao j iao 222 | jie j ie 223 | jin j in 224 | jing j ing 225 | jiong j iong 226 | jiu j iu 227 | ju j v 228 | juan j van 229 | jue j ve 230 | jun j vn 231 | ka k a 232 | kai k ai 233 | kan k an 234 | kang k ang 235 | kao k ao 236 | ke k e 237 | kei k ei 238 | ken k en 239 | keng k eng 240 | ker k er 241 | ki k i 242 | kia k ia 243 | kian k ian 244 | kiang k iang 245 | kiao k iao 246 | kie k ie 247 | kin k in 248 | king k ing 249 | kiong k iong 250 | kiu k iu 251 | kong k ong 252 | kou k ou 253 | ku k u 254 | kua k ua 255 | kuai k uai 256 | kuan k uan 257 | kuang k uang 258 | kui k ui 259 | kun k un 260 | kuo k uo 261 | kv k v 262 | kve k ve 263 | la l a 264 | lai l ai 265 | lan l an 266 | lang l ang 267 | lao l ao 268 | le l e 269 | lei l ei 270 | len l en 271 | leng l eng 272 | ler l er 273 | li l i 274 | lia l ia 275 | lian l ian 276 | liang l iang 277 | liao l iao 278 | lie l ie 279 | lin l in 280 | ling l ing 281 | liong l iong 282 | liu l iu 283 | lo l o 284 | long l ong 285 | lou l ou 286 | lu l u 287 | lua l ua 288 | luai l uai 289 | luan l uan 290 | luang l uang 291 | lui l ui 292 | lun l un 293 | luo l uo 294 | lv l v 295 | lve l ve 296 | ma m a 297 | mai m ai 298 | man m an 299 | mang m ang 300 | mao m ao 301 | me m e 302 | mei m ei 303 | men m en 304 | meng m eng 305 | mer m er 306 | mi m i 307 | mia m ia 308 | mian m ian 309 | miang m iang 310 | miao m iao 311 | mie m ie 312 | min m in 313 | ming m ing 314 | miong m iong 315 | miu m iu 316 | mo m o 317 | mong m ong 318 | mou m ou 319 | mu m u 320 | mua m ua 321 | muai m uai 322 | muan m uan 323 | muang m uang 324 | mui m ui 325 | mun m un 326 | mv m v 327 | mve m ve 328 | na n a 329 | nai n ai 330 | nan n an 331 | nang n ang 332 | nao n ao 333 | ne n e 334 | nei n ei 335 | nen n en 336 | neng n eng 337 | ner n er 338 | ni n i 339 | nia n ia 340 | nian n ian 341 | niang n iang 342 | niao n iao 343 | nie n ie 344 | nin n in 345 | ning n ing 346 | niong n iong 347 | niu n iu 348 | nong n ong 349 | nou n ou 350 | nu n u 351 | nua n ua 352 | nuai n uai 353 | nuan n uan 354 
| nuang n uang 355 | nui n ui 356 | nun n un 357 | nuo n uo 358 | nv n v 359 | nve n ve 360 | o o 361 | ong ong 362 | ou ou 363 | pa p a 364 | pai p ai 365 | pan p an 366 | pang p ang 367 | pao p ao 368 | pe p e 369 | pei p ei 370 | pen p en 371 | peng p eng 372 | per p er 373 | pi p i 374 | pia p ia 375 | pian p ian 376 | piang p iang 377 | piao p iao 378 | pie p ie 379 | pin p in 380 | ping p ing 381 | piong p iong 382 | piu p iu 383 | po p o 384 | pong p ong 385 | pou p ou 386 | pu p u 387 | pua p ua 388 | puai p uai 389 | puan p uan 390 | puang p uang 391 | pui p ui 392 | pun p un 393 | pv p v 394 | pve p ve 395 | qi q i 396 | qia q ia 397 | qian q ian 398 | qiang q iang 399 | qiao q iao 400 | qie q ie 401 | qin q in 402 | qing q ing 403 | qiong q iong 404 | qiu q iu 405 | qu q v 406 | quan q van 407 | que q ve 408 | qun q vn 409 | ra r a 410 | rai r ai 411 | ran r an 412 | rang r ang 413 | rao r ao 414 | re r e 415 | rei r ei 416 | ren r en 417 | reng r eng 418 | rer r er 419 | ri r ir 420 | rong r ong 421 | rou r ou 422 | ru r u 423 | rua r ua 424 | ruai r uai 425 | ruan r uan 426 | ruang r uang 427 | rui r ui 428 | run r un 429 | ruo r uo 430 | rv r v 431 | ryi r i 432 | sa s a 433 | sai s ai 434 | san s an 435 | sang s ang 436 | sao s ao 437 | se s e 438 | sei s ei 439 | sen s en 440 | seng s eng 441 | ser s er 442 | sha sh a 443 | shai sh ai 444 | shan sh an 445 | shang sh ang 446 | shao sh ao 447 | she sh e 448 | shei sh ei 449 | shen sh en 450 | sheng sh eng 451 | sher sh er 452 | shi sh ir 453 | shong sh ong 454 | shou sh ou 455 | shu sh u 456 | shua sh ua 457 | shuai sh uai 458 | shuan sh uan 459 | shuang sh uang 460 | shui sh ui 461 | shun sh un 462 | shuo sh uo 463 | shv sh v 464 | shyi sh i 465 | si s i0 466 | song s ong 467 | sou s ou 468 | su s u 469 | sua s ua 470 | suai s uai 471 | suan s uan 472 | suang s uang 473 | sui s ui 474 | sun s un 475 | suo s uo 476 | sv s v 477 | syi s i 478 | ta t a 479 | tai t ai 480 | tan t an 481 | tang t ang 482 | tao t ao 483 | te t e 484 | tei t ei 485 | ten t en 486 | teng t eng 487 | ter t er 488 | ti t i 489 | tia t ia 490 | tian t ian 491 | tiang t iang 492 | tiao t iao 493 | tie t ie 494 | tin t in 495 | ting t ing 496 | tiong t iong 497 | tong t ong 498 | tou t ou 499 | tu t u 500 | tua t ua 501 | tuai t uai 502 | tuan t uan 503 | tuang t uang 504 | tui t ui 505 | tun t un 506 | tuo t uo 507 | tv t v 508 | tve t ve 509 | wa w a 510 | wai w ai 511 | wan w an 512 | wang w ang 513 | wao w ao 514 | we w e 515 | wei w ei 516 | wen w en 517 | weng w eng 518 | wer w er 519 | wi w i 520 | wo w o 521 | wong w ong 522 | wou w ou 523 | wu w u 524 | xi x i 525 | xia x ia 526 | xian x ian 527 | xiang x iang 528 | xiao x iao 529 | xie x ie 530 | xin x in 531 | xing x ing 532 | xiong x iong 533 | xiu x iu 534 | xu x v 535 | xuan x van 536 | xue x ve 537 | xun x vn 538 | ya y a 539 | yai y ai 540 | yan y En 541 | yang y ang 542 | yao y ao 543 | ye y E 544 | yei y ei 545 | yi y i 546 | yin y in 547 | ying y ing 548 | yo y o 549 | yong y ong 550 | you y ou 551 | yu y v 552 | yuan y van 553 | yue y ve 554 | yun y vn 555 | ywu y u 556 | za z a 557 | zai z ai 558 | zan z an 559 | zang z ang 560 | zao z ao 561 | ze z e 562 | zei z ei 563 | zen z en 564 | zeng z eng 565 | zer z er 566 | zha zh a 567 | zhai zh ai 568 | zhan zh an 569 | zhang zh ang 570 | zhao zh ao 571 | zhe zh e 572 | zhei zh ei 573 | zhen zh en 574 | zheng zh eng 575 | zher zh er 576 | zhi zh ir 577 | zhong zh ong 578 | zhou zh ou 579 | zhu zh u 580 | zhua zh ua 581 | zhuai zh uai 582 
| zhuan zh uan 583 | zhuang zh uang 584 | zhui zh ui 585 | zhun zh un 586 | zhuo zh uo 587 | zhv zh v 588 | zhyi zh i 589 | zi z i0 590 | zong z ong 591 | zou z ou 592 | zu z u 593 | zua z ua 594 | zuai z uai 595 | zuan z uan 596 | zuang z uang 597 | zui z ui 598 | zun z un 599 | zuo z uo 600 | zv z v 601 | zyi z i 602 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | import h5py 4 | import numpy as np 5 | import pandas as pd 6 | import torch 7 | from einops import rearrange 8 | 9 | 10 | class MixedDataset(torch.utils.data.Dataset): 11 | def __init__( 12 | self, 13 | augmentation_size, 14 | binary_data_folder="data/binary", 15 | prefix="train", 16 | ): 17 | # do not open hdf5 here 18 | self.h5py_file = None 19 | self.label_types = None 20 | self.wav_lengths = None 21 | if augmentation_size > 0: 22 | self.augmentation_indexes = np.arange(augmentation_size + 1) 23 | else: 24 | self.augmentation_indexes = None 25 | 26 | self.binary_data_folder = binary_data_folder 27 | self.prefix = prefix 28 | 29 | def get_label_types(self): 30 | uninitialized = self.label_types is None 31 | if uninitialized: 32 | self._open_h5py_file() 33 | ret = self.label_types 34 | if uninitialized: 35 | self._close_h5py_file() 36 | return ret 37 | 38 | def get_wav_lengths(self): 39 | uninitialized = self.wav_lengths is None 40 | if uninitialized: 41 | self._open_h5py_file() 42 | ret = self.wav_lengths 43 | if uninitialized: 44 | self._close_h5py_file() 45 | return ret 46 | 47 | def _open_h5py_file(self): 48 | self.h5py_file = h5py.File( 49 | str(pathlib.Path(self.binary_data_folder) / (self.prefix + ".h5py")), "r" 50 | ) 51 | self.label_types = np.array(self.h5py_file["meta_data"]["label_types"]) 52 | self.wav_lengths = np.array(self.h5py_file["meta_data"]["wav_lengths"]) 53 | 54 | def _close_h5py_file(self): 55 | self.h5py_file.close() 56 | self.h5py_file = None 57 | 58 | def __len__(self): 59 | uninitialized = self.h5py_file is None 60 | if uninitialized: 61 | self._open_h5py_file() 62 | ret = len(self.h5py_file["items"]) 63 | if uninitialized: 64 | self._close_h5py_file() 65 | return ret 66 | 67 | def __getitem__(self, index): 68 | if self.h5py_file is None: 69 | self._open_h5py_file() 70 | 71 | item = self.h5py_file["items"][str(index)] 72 | 73 | # input_feature 74 | if self.augmentation_indexes is None: 75 | input_feature = np.array(item["input_feature"]) 76 | else: 77 | indexes = np.random.choice(self.augmentation_indexes, 2) 78 | input_feature = np.array(item["input_feature"])[indexes, :, :] 79 | 80 | # label_type 81 | label_type = np.array(item["label_type"]) 82 | 83 | # ph_seq 84 | ph_seq = np.array(item["ph_seq"]) 85 | 86 | # ph_edge 87 | ph_edge = np.array(item["ph_edge"]) 88 | 89 | # ph_frame 90 | ph_frame = np.array(item["ph_frame"]) 91 | 92 | # ph_mask 93 | ph_mask = np.array(item["ph_mask"]) 94 | 95 | input_feature = np.repeat( 96 | input_feature, len(ph_frame) // input_feature.shape[-1], axis=-1 97 | ) 98 | 99 | return input_feature, ph_seq, ph_edge, ph_frame, ph_mask, label_type 100 | 101 | 102 | class WeightedBinningAudioBatchSampler(torch.utils.data.Sampler): 103 | def __init__( 104 | self, 105 | type_ids, 106 | wav_lengths, 107 | oversampling_weights=None, 108 | max_length=100, 109 | binning_length=1000, 110 | drop_last=False, 111 | ): 112 | if oversampling_weights is None: 113 | oversampling_weights = [1] * (max(type_ids) + 1) 114 | 
oversampling_weights = np.array(oversampling_weights).astype(np.float32) 115 | 116 | assert min(oversampling_weights) > 0 117 | assert len(oversampling_weights) >= max(type_ids) + 1 118 | assert min(type_ids) >= 0 119 | assert len(type_ids) == len(wav_lengths) 120 | assert max_length > 0 121 | assert binning_length > 0 122 | 123 | count = np.bincount(type_ids) 124 | count = np.pad(count, (0, len(oversampling_weights) - len(count))) 125 | self.oversampling_weights = oversampling_weights / min( 126 | oversampling_weights[count > 0] 127 | ) 128 | self.max_length = max_length 129 | self.drop_last = drop_last 130 | 131 | # sort by wav_lengths 132 | meta_data = ( 133 | pd.DataFrame( 134 | { 135 | "dataset_index": range(len(type_ids)), 136 | "type_id": type_ids, 137 | "wav_length": wav_lengths, 138 | } 139 | ) 140 | .sort_values(by=["wav_length"], ascending=False) 141 | .reset_index(drop=True) 142 | ) 143 | 144 | # binning and compute oversampling num 145 | self.bins = [] 146 | 147 | curr_bin_start_index = 0 148 | curr_bin_max_item_length = meta_data.loc[0, "wav_length"] 149 | for i in range(len(meta_data)): 150 | if curr_bin_max_item_length * (i - curr_bin_start_index) > binning_length: 151 | bin_data = { 152 | "batch_size": self.max_length // curr_bin_max_item_length, 153 | "num_batches": 0, 154 | "type": [], 155 | } 156 | 157 | item_num = 0 158 | for type_id, weight in enumerate(self.oversampling_weights): 159 | idx_list = ( 160 | meta_data.loc[curr_bin_start_index : i - 1] 161 | .loc[meta_data["type_id"] == type_id] 162 | .to_dict(orient="list")["dataset_index"] 163 | ) 164 | 165 | oversample_num = np.round(len(idx_list) * (weight - 1)) 166 | bin_data["type"].append( 167 | { 168 | "idx_list": idx_list, 169 | "oversample_num": oversample_num, 170 | } 171 | ) 172 | item_num += len(idx_list) + oversample_num 173 | 174 | if bin_data["batch_size"] <= 0: 175 | raise ValueError( 176 | "batch_size <= 0, maybe batch_max_length in training config is too small " 177 | "or max_length in binarizing config is too long." 
178 | ) 179 | num_batches = item_num / bin_data["batch_size"] 180 | if self.drop_last: 181 | bin_data["num_batches"] = int(num_batches) 182 | else: 183 | bin_data["num_batches"] = int(np.ceil(num_batches)) 184 | self.bins.append(bin_data) 185 | 186 | curr_bin_start_index = i 187 | curr_bin_max_item_length = meta_data.loc[i, "wav_length"] 188 | 189 | self.len = None 190 | 191 | def __len__(self): 192 | if self.len is None: 193 | self.len = 0 194 | for bin_data in self.bins: 195 | self.len += bin_data["num_batches"] 196 | return self.len 197 | 198 | def __iter__(self): 199 | np.random.shuffle(self.bins) 200 | 201 | for bin_data in self.bins: 202 | batch_size = bin_data["batch_size"] 203 | num_batches = bin_data["num_batches"] 204 | 205 | idx_list = [] 206 | for type_id, weight in enumerate(self.oversampling_weights): 207 | idx_list_of_type = bin_data["type"][type_id]["idx_list"] 208 | oversample_num = bin_data["type"][type_id]["oversample_num"] 209 | 210 | if len(idx_list_of_type) > 0: 211 | idx_list.extend(idx_list_of_type) 212 | oversample_idx_list = np.random.choice( 213 | idx_list_of_type, int(oversample_num) 214 | ) 215 | idx_list.extend(oversample_idx_list) 216 | 217 | idx_list = np.random.permutation(idx_list) 218 | 219 | if self.drop_last: 220 | num_batches = int(num_batches) 221 | idx_list = idx_list[: num_batches * batch_size] 222 | else: 223 | num_batches = int(np.ceil(num_batches)) 224 | random_idx = np.random.choice( 225 | idx_list, int(num_batches * batch_size - len(idx_list)) 226 | ) 227 | idx_list = np.concatenate([idx_list, random_idx]) 228 | 229 | np.random.shuffle(idx_list) 230 | 231 | for i in range(num_batches): 232 | yield idx_list[int(i * batch_size) : int((i + 1) * batch_size)] 233 | 234 | 235 | def collate_fn(batch): 236 | """_summary_ 237 | 238 | Args: 239 | batch (tuple): input_feature, ph_seq, ph_edge, ph_frame, ph_mask, label_type from MixedDataset 240 | 241 | Returns: 242 | input_feature: (B C T) 243 | input_feature_lengths: (B) 244 | ph_seq: (B S) 245 | ph_seq_lengths: (B) 246 | ph_edge: (B T) 247 | ph_frame: (B T) 248 | ph_mask: (B vocab_size) 249 | label_type: (B) 250 | """ 251 | input_feature_lengths = torch.tensor([i[0].shape[-1] for i in batch]) 252 | max_len = max(input_feature_lengths) 253 | ph_seq_lengths = torch.tensor([len(item[1]) for item in batch]) 254 | max_ph_seq_len = max(ph_seq_lengths) 255 | if batch[0][0].shape[0] > 1: 256 | augmentation_enabled = True 257 | else: 258 | augmentation_enabled = False 259 | 260 | # padding 261 | for i, item in enumerate(batch): 262 | item = list(item) 263 | for param in [0, 2, 3]: 264 | item[param] = torch.nn.functional.pad( 265 | torch.tensor(item[param]), 266 | (0, max_len - item[param].shape[-1]), 267 | "constant", 268 | 0, 269 | ) 270 | item[1] = torch.nn.functional.pad( 271 | torch.tensor(item[1]), 272 | (0, max_ph_seq_len - item[1].shape[-1]), 273 | "constant", 274 | 0, 275 | ) 276 | item[4] = torch.from_numpy(item[4]) 277 | batch[i] = tuple(item) 278 | 279 | input_feature = torch.stack([item[0] for item in batch], dim=1) 280 | input_feature = rearrange(input_feature, "n b c t -> (n b) c t") 281 | ph_seq = torch.stack([item[1] for item in batch]) 282 | ph_edge = torch.stack([item[2] for item in batch]) 283 | ph_frame = torch.stack([item[3] for item in batch]) 284 | ph_mask = torch.stack([item[4] for item in batch]) 285 | 286 | label_type = torch.tensor(np.array([item[5] for item in batch])) 287 | 288 | if augmentation_enabled: 289 | input_feature_lengths = torch.concat( 290 | [input_feature_lengths, 
input_feature_lengths], dim=0 291 | ) 292 | ph_seq = torch.concat([ph_seq, ph_seq], dim=0) 293 | ph_seq_lengths = torch.concat([ph_seq_lengths, ph_seq_lengths], dim=0) 294 | ph_edge = torch.concat([ph_edge, ph_edge], dim=0) 295 | ph_frame = torch.concat([ph_frame, ph_frame], dim=0) 296 | ph_mask = torch.concat([ph_mask, ph_mask], dim=0) 297 | label_type = torch.concat([label_type, label_type], dim=0) 298 | 299 | return ( 300 | input_feature, 301 | input_feature_lengths, 302 | ph_seq, 303 | ph_seq_lengths, 304 | ph_edge, 305 | ph_frame, 306 | ph_mask, 307 | label_type, 308 | ) 309 | 310 | 311 | if __name__ == "__main__": 312 | dataset = MixedDataset(2) 313 | print(dataset[0]) 314 | # sampler = WeightedBinningAudioBatchSampler(dataset.get_label_types(), dataset.get_wav_lengths(), [1, 0.3, 0.4]) 315 | # for i in tqdm(sampler): 316 | # print(len(i)) 317 | -------------------------------------------------------------------------------- /onnx_infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import click 5 | import numpy as np 6 | import onnxruntime as ort 7 | import torchaudio 8 | import yaml 9 | from tqdm import tqdm 10 | 11 | import modules.AP_detector 12 | import modules.g2p 13 | import numba 14 | 15 | from modules.utils.export_tool import Exporter 16 | from modules.utils.post_processing import post_processing 17 | 18 | 19 | def load_config_from_yaml(file_path): 20 | with open(file_path, 'r') as file: 21 | config = yaml.safe_load(file) 22 | return config 23 | 24 | 25 | def run_inference(session, waveform, num_frames, ph_seq_id): 26 | output_names = [output.name for output in session.get_outputs()] 27 | 28 | input_data = { 29 | 'waveform': waveform, 30 | 'num_frames': np.array(num_frames, dtype=np.int64), 31 | 'ph_seq_id': ph_seq_id 32 | } 33 | 34 | # 运行推理 35 | try: 36 | results = session.run(output_names, input_data) 37 | except Exception as e: 38 | print(f"推理过程中发生错误: {e}") 39 | raise 40 | 41 | # 将结果转换为字典形式 42 | output_dict = {name: result for name, result in zip(output_names, results)} 43 | 44 | return output_dict 45 | 46 | 47 | def create_session(onnx_model_path): 48 | providers = ['CUDAExecutionProvider', 'DmlExecutionProvider', 'CPUExecutionProvider' 49 | ] 50 | 51 | session_options = ort.SessionOptions() 52 | session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL 53 | session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL 54 | 55 | try: 56 | session = ort.InferenceSession(onnx_model_path, sess_options=session_options, providers=providers) 57 | except Exception as e: 58 | print(f"An error occurred while creating ONNX Runtime session: {e}") 59 | raise 60 | 61 | return session 62 | 63 | 64 | @numba.jit 65 | def forward_pass(T, S, prob_log, not_edge_prob_log, edge_prob_log, curr_ph_max_prob_log, dp, backtrack_s, ph_seq_id, 66 | prob3_pad_len): 67 | for t in range(1, T): 68 | # [t-1,s] -> [t,s] 69 | prob1 = dp[t - 1, :] + prob_log[t, :] + not_edge_prob_log[t] 70 | 71 | prob2 = np.empty(S, dtype=np.float32) 72 | prob2[0] = -np.inf 73 | for i in range(1, S): 74 | prob2[i] = ( 75 | dp[t - 1, i - 1] 76 | + prob_log[t, i - 1] 77 | + edge_prob_log[t] 78 | + curr_ph_max_prob_log[i - 1] * (T / S) 79 | ) 80 | 81 | # [t-1,s-2] -> [t,s] 82 | prob3 = np.empty(S, dtype=np.float32) 83 | for i in range(prob3_pad_len): 84 | prob3[i] = -np.inf 85 | for i in range(prob3_pad_len, S): 86 | if i - prob3_pad_len + 1 < S - 1 and ph_seq_id[i - prob3_pad_len + 1] != 0: 87 | prob3[i] = -np.inf 88 
| else: 89 | prob3[i] = ( 90 | dp[t - 1, i - prob3_pad_len] 91 | + prob_log[t, i - prob3_pad_len] 92 | + edge_prob_log[t] 93 | + curr_ph_max_prob_log[i - prob3_pad_len] * (T / S) 94 | ) 95 | 96 | stacked_probs = np.empty((3, S), dtype=np.float32) 97 | for i in range(S): 98 | stacked_probs[0, i] = prob1[i] 99 | stacked_probs[1, i] = prob2[i] 100 | stacked_probs[2, i] = prob3[i] 101 | 102 | for i in range(S): 103 | max_idx = 0 104 | max_val = stacked_probs[0, i] 105 | for j in range(1, 3): 106 | if stacked_probs[j, i] > max_val: 107 | max_val = stacked_probs[j, i] 108 | max_idx = j 109 | dp[t, i] = max_val 110 | backtrack_s[t, i] = max_idx 111 | 112 | for i in range(S): 113 | if backtrack_s[t, i] == 0: 114 | curr_ph_max_prob_log[i] = max(curr_ph_max_prob_log[i], prob_log[t, i]) 115 | elif backtrack_s[t, i] > 0: 116 | curr_ph_max_prob_log[i] = prob_log[t, i] 117 | 118 | for i in range(S): 119 | if ph_seq_id[i] == 0: 120 | curr_ph_max_prob_log[i] = 0 121 | 122 | return dp, backtrack_s, curr_ph_max_prob_log 123 | 124 | 125 | def decode(ph_seq_id, ph_prob_log, edge_prob): 126 | # ph_seq_id: (S) 127 | # ph_prob_log: (T, vocab_size) 128 | # edge_prob: (T,2) 129 | T = ph_prob_log.shape[0] 130 | S = len(ph_seq_id) 131 | # not_SP_num = (ph_seq_id > 0).sum() 132 | prob_log = ph_prob_log[:, ph_seq_id] 133 | 134 | edge_prob_log = np.log(edge_prob + 1e-6).astype("float32") 135 | not_edge_prob_log = np.log(1 - edge_prob + 1e-6).astype("float32") 136 | 137 | # init 138 | curr_ph_max_prob_log = np.full(S, -np.inf) 139 | dp = np.full((T, S), -np.inf, dtype="float32") # (T, S) 140 | backtrack_s = np.full_like(dp, -1, dtype="int32") 141 | 142 | dp[0, 0] = prob_log[0, 0] 143 | curr_ph_max_prob_log[0] = prob_log[0, 0] 144 | if ph_seq_id[0] == 0 and prob_log.shape[-1] > 1: 145 | dp[0, 1] = prob_log[0, 1] 146 | curr_ph_max_prob_log[1] = prob_log[0, 1] 147 | 148 | # forward 149 | prob3_pad_len = 2 if S >= 2 else 1 150 | dp, backtrack_s, curr_ph_max_prob_log = forward_pass( 151 | T, S, prob_log, not_edge_prob_log, edge_prob_log, curr_ph_max_prob_log, dp, backtrack_s, ph_seq_id, 152 | prob3_pad_len 153 | ) 154 | 155 | # backward 156 | ph_idx_seq = [] 157 | ph_time_int = [] 158 | frame_confidence = [] 159 | 160 | # 如果mode==forced,只能从最后一个音素或者SP结束 161 | if S >= 2 and dp[-1, -2] > dp[-1, -1] and ph_seq_id[-1] == 0: 162 | s = S - 2 163 | else: 164 | s = S - 1 165 | 166 | for t in np.arange(T - 1, -1, -1): 167 | assert backtrack_s[t, s] >= 0 or t == 0 168 | frame_confidence.append(dp[t, s]) 169 | if backtrack_s[t, s] != 0: 170 | ph_idx_seq.append(s) 171 | ph_time_int.append(t) 172 | s -= backtrack_s[t, s] 173 | ph_idx_seq.reverse() 174 | ph_time_int.reverse() 175 | frame_confidence.reverse() 176 | frame_confidence = np.exp( 177 | np.diff( 178 | np.pad(frame_confidence, (1, 0), "constant", constant_values=0.0), 1 179 | ) 180 | ) 181 | 182 | return ( 183 | np.array(ph_idx_seq), 184 | np.array(ph_time_int), 185 | np.array(frame_confidence), 186 | ) 187 | 188 | 189 | @click.command() 190 | @click.option( 191 | "--onnx", 192 | "-c", 193 | default=None, 194 | required=True, 195 | type=str, 196 | help="path to the onnx", 197 | ) 198 | @click.option( 199 | "--folder", "-f", default="segments", type=str, help="path to the input folder" 200 | ) 201 | @click.option( 202 | "--g2p", "-g", default="Dictionary", type=str, help="name of the g2p class" 203 | ) 204 | @click.option( 205 | "--ap_detector", 206 | "-a", 207 | default="LoudnessSpectralcentroidAPDetector", # "NoneAPDetector", 208 | type=str, 209 | help="name of the AP detector 
class", 210 | ) 211 | @click.option( 212 | "--in_format", 213 | "-if", 214 | default="lab", 215 | required=False, 216 | type=str, 217 | help="File extension of input transcriptions. Default: lab", 218 | ) 219 | @click.option( 220 | "--out_formats", 221 | "-of", 222 | default="textgrid,htk,trans", 223 | required=False, 224 | type=str, 225 | help="Types of output file, separated by comma. Supported types:" 226 | "textgrid(praat)," 227 | " htk(lab,nnsvs,sinsy)," 228 | " transcriptions.csv(diffsinger,trans,transcription,transcriptions)", 229 | ) 230 | @click.option( 231 | "--save_confidence", 232 | "-sc", 233 | is_flag=True, 234 | default=False, 235 | show_default=True, 236 | help="save confidence.csv", 237 | ) 238 | @click.option( 239 | "--dictionary", 240 | "-d", 241 | default="dictionary/opencpop-extension.txt", 242 | type=str, 243 | help="(only used when --g2p=='Dictionary') path to the dictionary", 244 | ) 245 | def infer(onnx, 246 | folder, 247 | g2p, 248 | ap_detector, 249 | in_format, 250 | out_formats, 251 | save_confidence, 252 | **kwargs, ): 253 | config_file = pathlib.Path(onnx).with_name('config.yaml') 254 | assert os.path.exists(onnx), f"Onnx file does not exist: {onnx}" 255 | assert config_file.exists(), f"Config file does not exist: {config_file}" 256 | 257 | config = load_config_from_yaml(config_file) 258 | melspec_config = config['melspec_config'] 259 | session = create_session(onnx) 260 | 261 | if not g2p.endswith("G2P"): 262 | g2p += "G2P" 263 | g2p_class = getattr(modules.g2p, g2p) 264 | grapheme_to_phoneme = g2p_class(**kwargs) 265 | out_formats = [i.strip().lower() for i in out_formats.split(",")] 266 | 267 | if not ap_detector.endswith("APDetector"): 268 | ap_detector += "APDetector" 269 | AP_detector_class = getattr(modules.AP_detector, ap_detector) 270 | get_AP = AP_detector_class(**kwargs) 271 | 272 | grapheme_to_phoneme.set_in_format(in_format) 273 | dataset = grapheme_to_phoneme.get_dataset(pathlib.Path(folder).rglob("*.wav")) 274 | predictions = [] 275 | 276 | for i in tqdm(range(len(dataset)), desc="Processing", unit="sample"): 277 | wav_path, ph_seq, word_seq, ph_idx_to_word_idx = dataset[i] 278 | 279 | waveform, sr = torchaudio.load(wav_path) 280 | waveform = waveform[0][None, :][0] 281 | if sr != melspec_config['sample_rate']: 282 | waveform = torchaudio.transforms.Resample(sr, melspec_config['sample_rate'])(waveform) 283 | 284 | wav_length = waveform.shape[0] / melspec_config["sample_rate"] 285 | ph_seq_id = np.array([config['vocab'][ph] for ph in ph_seq], dtype=np.int64) 286 | num_frames = int( 287 | (wav_length * melspec_config["scale_factor"] * melspec_config["sample_rate"] + 0.5) / melspec_config[ 288 | "hop_length"] 289 | ) 290 | results = run_inference(session, [waveform.numpy()], num_frames, [ph_seq_id]) 291 | 292 | edge_diff = results['edge_diff'] 293 | edge_prob = results['edge_prob'] 294 | ph_prob_log = results['ph_prob_log'] 295 | # ctc_logits = results['ctc_logits'] 296 | T = results['T'] 297 | 298 | ph_idx_seq, ph_time_int_pred, frame_confidence = decode(ph_seq_id, ph_prob_log, edge_prob, ) 299 | total_confidence = np.exp(np.mean(np.log(frame_confidence + 1e-6)) / 3) 300 | 301 | # postprocess 302 | frame_length = melspec_config["hop_length"] / ( 303 | melspec_config["sample_rate"] * melspec_config["scale_factor"] 304 | ) 305 | ph_time_fractional = (edge_diff[ph_time_int_pred] / 2).clip(-0.5, 0.5) 306 | ph_time_pred = frame_length * ( 307 | np.concatenate( 308 | [ 309 | ph_time_int_pred.astype("float32") + ph_time_fractional, 310 | [T], 311 | ] 
312 | ) 313 | ) 314 | ph_intervals = np.stack([ph_time_pred[:-1], ph_time_pred[1:]], axis=1) 315 | 316 | ph_seq_pred = [] 317 | ph_intervals_pred = [] 318 | word_seq_pred = [] 319 | word_intervals_pred = [] 320 | 321 | word_idx_last = -1 322 | for j, ph_idx in enumerate(ph_idx_seq): 323 | # ph_idx只能用于两种情况:ph_seq和ph_idx_to_word_idx 324 | if ph_seq[ph_idx] == "SP": 325 | continue 326 | ph_seq_pred.append(ph_seq[ph_idx]) 327 | ph_intervals_pred.append(ph_intervals[j, :]) 328 | 329 | word_idx = ph_idx_to_word_idx[ph_idx] 330 | if word_idx == word_idx_last: 331 | word_intervals_pred[-1][1] = ph_intervals[j, 1] 332 | else: 333 | word_seq_pred.append(word_seq[word_idx]) 334 | word_intervals_pred.append([ph_intervals[j, 0], ph_intervals[j, 1]]) 335 | word_idx_last = word_idx 336 | ph_seq_pred = np.array(ph_seq_pred) 337 | ph_intervals_pred = np.array(ph_intervals_pred).clip(min=0, max=None) 338 | word_seq_pred = np.array(word_seq_pred) 339 | word_intervals_pred = np.array(word_intervals_pred).clip(min=0, max=None) 340 | 341 | predictions.append((wav_path, 342 | wav_length, 343 | total_confidence, 344 | ph_seq_pred, 345 | ph_intervals_pred, 346 | word_seq_pred, 347 | word_intervals_pred)) 348 | 349 | predictions = get_AP.process(predictions) 350 | predictions, log = post_processing(predictions) 351 | exporter = Exporter(predictions, log) 352 | 353 | if save_confidence: 354 | out_formats.append('confidence') 355 | 356 | exporter.export(out_formats) 357 | 358 | 359 | if __name__ == '__main__': 360 | infer() 361 | -------------------------------------------------------------------------------- /modules/loss/GHMLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def update_ema(ema, alpha, num_bins, hist): 6 | hist = hist / (torch.sum(hist) + 1e-10) * num_bins 7 | ema = ema * alpha + (1 - alpha) * hist 8 | ema = ema / (torch.sum(ema) + 1e-10) * num_bins 9 | return ema 10 | 11 | 12 | class CTCGHMLoss(torch.nn.Module): 13 | def __init__(self, num_bins=10, alpha=0.999): 14 | super().__init__() 15 | self.ctc_loss_fn = nn.CTCLoss(reduction="none") 16 | self.ctc_loss_fn_cpu = nn.CTCLoss(reduction="none").cpu() 17 | self.num_bins = num_bins 18 | self.register_buffer("ema", torch.ones(num_bins)) 19 | self.alpha = alpha 20 | 21 | def forward(self, log_probs, targets, input_lengths, target_lengths, valid=False): 22 | if len(log_probs) <= 0: 23 | return torch.tensor(0.0).to(log_probs.device) 24 | try: 25 | raw_loss = self.ctc_loss_fn( 26 | log_probs, targets, input_lengths, target_lengths 27 | ) 28 | except RuntimeError: 29 | raw_loss = self.ctc_loss_fn_cpu( 30 | log_probs.cpu(), 31 | targets.cpu(), 32 | input_lengths.cpu(), 33 | target_lengths.cpu(), 34 | ).to(log_probs.device) 35 | loss_for_ema = ( 36 | (-raw_loss / input_lengths).exp().clamp(1e-6, 1 - 1e-6) 37 | ).detach() # 值域为[0, 1] 38 | loss_weighted = ( 39 | raw_loss 40 | / ( 41 | self.ema[ 42 | torch.floor(loss_for_ema * self.num_bins) 43 | .detach() 44 | .long() 45 | .clamp(0, self.num_bins - 1) 46 | ].detach() 47 | + 1e-10 48 | ).detach() 49 | ) 50 | loss_final = loss_weighted.mean() 51 | 52 | if not valid: 53 | hist = torch.histc(loss_for_ema, bins=self.num_bins, min=0, max=1) 54 | self.ema = update_ema(self.ema, self.alpha, self.num_bins, hist) 55 | 56 | return loss_final 57 | 58 | 59 | class BCEGHMLoss(torch.nn.Module): 60 | def __init__(self, num_bins=10, alpha=1 - 1e-6, label_smoothing=0.0): 61 | super().__init__() 62 | self.loss_fn = 
nn.BCELoss(reduction="none") 63 | self.num_bins = num_bins 64 | self.register_buffer("GD_stat_ema", torch.ones(num_bins)) 65 | self.alpha = alpha 66 | self.label_smoothing = label_smoothing 67 | 68 | def forward(self, pred_porb, target_porb, mask=None, valid=False): 69 | if len(pred_porb) <= 0: 70 | return torch.tensor(0.0).to(pred_porb.device) 71 | if mask is None: 72 | mask = torch.ones_like(pred_porb).to(pred_porb.device) 73 | assert ( 74 | pred_porb.shape == target_porb.shape 75 | and pred_porb.shape[:2] == mask.shape[:2] 76 | ) 77 | if len(mask.shape) < len(pred_porb.shape): 78 | mask = mask.unsqueeze(-1) 79 | mask = mask.expand_as(pred_porb)  # expand the mask to the full prediction shape 80 | assert pred_porb.max() <= 1 and pred_porb.min() >= 0 81 | assert target_porb.max() <= 1 and target_porb.min() >= 0 82 | 83 | target_porb = target_porb.clamp(self.label_smoothing, 1 - self.label_smoothing) 84 | 85 | raw_loss = self.loss_fn(pred_porb, target_porb) 86 | 87 | gradient_magnitudes = (pred_porb - target_porb).abs() 88 | gradient_magnitudes_index = ( 89 | torch.floor(gradient_magnitudes * self.num_bins).long().clamp(0, self.num_bins - 1) 90 | ) 91 | weights = 1 / self.GD_stat_ema[gradient_magnitudes_index] + 1e-3 92 | loss_weighted = raw_loss * weights 93 | mask_weights = mask.float() 94 | loss_weighted = loss_weighted * mask_weights 95 | loss_final = torch.sum(loss_weighted) / torch.sum(mask_weights) 96 | 97 | if not valid: 98 | # update ema 99 | # "Elements lower than min and higher than max and NaN elements are ignored." 100 | gradient_magnitudes_index = gradient_magnitudes_index.flatten() 101 | mask_weights = mask_weights.flatten() 102 | gradient_magnitudes_index_hist = torch.bincount( 103 | input=gradient_magnitudes_index, 104 | weights=mask_weights, 105 | minlength=self.num_bins, 106 | ) 107 | self.GD_stat_ema = update_ema( 108 | self.GD_stat_ema, 109 | self.alpha, 110 | self.num_bins, 111 | gradient_magnitudes_index_hist, 112 | ) 113 | 114 | return loss_final 115 | 116 | 117 | class MultiLabelGHMLoss(torch.nn.Module): 118 | def __init__(self, num_classes, num_bins=10, alpha=(1 - 1e-6), label_smoothing=0.0): 119 | super().__init__() 120 | self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") 121 | self.num_bins = num_bins 122 | # handles imbalance between easy and hard samples 123 | self.register_buffer("GD_stat_ema", torch.ones(num_bins)) 124 | self.num_classes = num_classes 125 | # handles class imbalance and positive/negative imbalance: each class is split into positive, negative and neutral bins 126 | self.register_buffer("label_stat_ema_each_class", torch.ones([num_classes * 3])) 127 | self.alpha = alpha 128 | self.label_smoothing = label_smoothing 129 | 130 | def forward(self, pred_logits, target_porb, mask=None, valid=False): 131 | """Multi-label GHM-weighted BCE loss. 132 | 133 | Args: 134 | pred_logits (torch.Tensor): predicted logits, shape: (*, C) 135 | target_porb (torch.Tensor): target probability, shape: same as pred_logits. 136 | mask (torch.Tensor, optional): mask tensor, ignore loss when mask==0. 137 | shape: same as pred_logits. Defaults to None. 138 | valid (bool, optional): validation mode; when True, skip the EMA update. Defaults to False.
139 | 140 | Returns: 141 | loss_final (torch.Tensor): loss value, shape: () 142 | """ 143 | if len(pred_logits) <= 0: 144 | return torch.tensor(0.0).to(pred_logits.device) 145 | if mask is None: 146 | mask = torch.ones_like(pred_logits).to(pred_logits.device) 147 | assert ( 148 | pred_logits.shape == target_porb.shape 149 | and pred_logits.shape[:2] == mask.shape[:2] 150 | ) 151 | if len(mask.shape) < len(pred_logits.shape): 152 | mask = mask.unsqueeze(-1) 153 | assert pred_logits.shape[-1] == self.num_classes 154 | assert target_porb.max() <= 1 and target_porb.min() >= 0 155 | 156 | pred_logits = pred_logits.reshape(-1, self.num_classes) 157 | target_porb = target_porb.reshape(-1, self.num_classes) 158 | mask = mask.reshape(target_porb.shape[0], -1) 159 | if mask.shape[-1] == 1 and target_porb.shape[-1] > 1: 160 | mask = mask.repeat(1, target_porb.shape[-1]) 161 | target_porb = target_porb.clamp(self.label_smoothing, 1 - self.label_smoothing) 162 | 163 | raw_loss = self.loss_fn(pred_logits, target_porb) 164 | 165 | pred_porb = torch.nn.functional.sigmoid(pred_logits) 166 | gradient_magnitudes_index = ( 167 | torch.floor((pred_porb - target_porb).abs() * self.num_bins) 168 | .long() 169 | .clamp(0, self.num_bins - 1) 170 | ) 171 | GD_weights = 1 / self.GD_stat_ema[gradient_magnitudes_index] + 1e-3 172 | target_porb_index = torch.floor(target_porb * 3).long().clamp( 173 | 0, 2 174 | ) + 3 * torch.arange(self.num_classes).to(target_porb.device).unsqueeze(0) 175 | classes_weights = 1 / self.label_stat_ema_each_class[target_porb_index] + 1e-3 176 | weights = torch.sqrt(GD_weights * classes_weights) 177 | loss_weighted = raw_loss * weights 178 | loss_weighted = loss_weighted * mask 179 | loss_final = torch.sum(loss_weighted) / torch.sum(mask) 180 | 181 | if not valid: 182 | # update ema 183 | # TODO: to include the mask in these statistics, the mask's shape must match the input's 184 | mask = mask.flatten() 185 | gradient_magnitudes_index = gradient_magnitudes_index.flatten() 186 | gradient_magnitudes_index_hist = torch.bincount( 187 | input=gradient_magnitudes_index, 188 | weights=mask, 189 | minlength=self.num_bins, 190 | ) 191 | self.GD_stat_ema = update_ema( 192 | self.GD_stat_ema, 193 | self.alpha, 194 | self.num_bins, 195 | gradient_magnitudes_index_hist, 196 | ) 197 | 198 | target_porb_index = target_porb_index.flatten() 199 | target_porb_index_hist = torch.bincount( 200 | input=target_porb_index, 201 | weights=mask, 202 | minlength=self.num_classes * 3, 203 | ) 204 | self.label_stat_ema_each_class = update_ema( 205 | self.label_stat_ema_each_class, 206 | self.alpha, 207 | self.num_classes * 3, 208 | target_porb_index_hist, 209 | ) 210 | 211 | return loss_final 212 | 213 | 214 | if __name__ == "__main__": 215 | loss_fn = MultiLabelGHMLoss(10, alpha=0.9) 216 | input = torch.nn.functional.sigmoid(torch.randn(3, 3, 10) * 10) 217 | target = (torch.nn.functional.sigmoid(torch.randn(3, 3, 10)) > 0.5).float() 218 | print(loss_fn(input, target)) 219 | 220 | 221 | class GHMLoss(torch.nn.Module): 222 | def __init__(self, num_classes, num_bins=10, alpha=1 - 1e-6, label_smoothing=0.0): 223 | super().__init__() 224 | self.num_classes = num_classes 225 | self.register_buffer("class_ema", torch.ones(num_classes)) 226 | self.num_bins = num_bins 227 | self.register_buffer("GD_ema", torch.ones(num_bins)) 228 | self.alpha = alpha 229 | self.loss_fn = nn.CrossEntropyLoss(reduction="none") 230 | self.label_smoothing = label_smoothing 231 | 232 | def forward(self, pred_logits, target_label, mask=None, valid=False): 233 | if len(pred_logits) <= 0: 234 | 
return torch.tensor(0.0).to(pred_logits.device) 235 | 236 | # pred: [B, T, C] 237 | assert len(pred_logits.shape) == 3 and pred_logits.shape[-1] == self.num_classes 238 | 239 | # target: [B, T] 240 | assert len(target_label.shape) == 2 241 | assert target_label.shape[0] == pred_logits.shape[0] 242 | assert target_label.shape[1] == pred_logits.shape[1] 243 | target_label = target_label.long() 244 | 245 | # mask: [B, T] or [B, T, C] 246 | if mask is None: 247 | mask = torch.ones_like(pred_logits).to(pred_logits.device) 248 | if len(mask.shape) == 2: 249 | mask = mask.unsqueeze(-1) 250 | assert mask.shape[0] == target_label.shape[0] 251 | assert mask.shape[1] == target_label.shape[1] 252 | assert mask.shape[-1] == 1 or mask.shape[-1] == self.num_classes 253 | time_mask = mask.any(dim=-1) # [B, T] 254 | 255 | pred_logits = pred_logits - 1e9 * mask.logical_not().float() 256 | target_prob = ( 257 | nn.functional.one_hot(target_label, num_classes=self.num_classes) 258 | .float() 259 | .clamp(self.label_smoothing, 1 - self.label_smoothing) 260 | ) 261 | target_prob = target_prob * mask.float() 262 | raw_loss = self.loss_fn( 263 | pred_logits.transpose(1, 2), target_prob.transpose(1, 2) 264 | ) # [B, T] 265 | pred_probs = torch.softmax(pred_logits, dim=-1).detach() # [B, T, C] 266 | 267 | # calculate weighted loss 268 | GD = (pred_probs - target_prob).abs() 269 | GD = torch.gather(GD, -1, target_label.unsqueeze(-1)).squeeze(-1) # [B, T] 270 | GD_index = torch.floor(GD * self.num_bins).long().clamp(0, self.num_bins - 1) 271 | # GD = GD - 1e9 * time_mask.logical_not().float() 272 | weights = torch.sqrt( 273 | self.class_ema[target_label].detach() * self.GD_ema[GD_index].detach() 274 | ) # [B, T] 275 | loss_weighted = (raw_loss / weights) * time_mask.float() # [B, T] 276 | loss_final = torch.sum(loss_weighted) / torch.sum(time_mask.float()) 277 | 278 | if not valid: 279 | # update ema 280 | # "Elements lower than min and higher than max and NaN elements are ignored." 
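            # Frames excluded by time_mask are shifted past the valid index range and
            # carry zero weight in torch.bincount, so the slices below ([: self.num_classes]
            # and [: self.num_bins]) drop them from the class / gradient-density histograms
            # before the EMA update.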
281 | target_label = ( 282 | target_label + (self.num_classes + 10) * time_mask.logical_not().long() 283 | ) 284 | GD_index = GD_index + (self.num_bins + 10) * time_mask.logical_not().long() 285 | class_hist = torch.bincount( 286 | input=target_label.flatten(), 287 | weights=time_mask.flatten(), 288 | minlength=self.num_classes, 289 | ) 290 | class_hist = class_hist[: self.num_classes] 291 | GD_hist = torch.bincount( 292 | input=GD_index.flatten(), 293 | weights=time_mask.flatten(), 294 | minlength=self.num_bins, 295 | ) 296 | GD_hist = GD_hist[: self.num_bins] 297 | self.GD_ema = update_ema(self.GD_ema, self.alpha, self.num_bins, GD_hist) 298 | self.class_ema = update_ema( 299 | self.class_ema, self.alpha, self.num_classes, class_hist 300 | ) 301 | 302 | return loss_final 303 | -------------------------------------------------------------------------------- /binarize.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import warnings 3 | 4 | import click 5 | import h5py 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | import yaml 10 | from tqdm import tqdm 11 | 12 | from modules.utils.get_melspec import MelSpecExtractor 13 | from modules.utils.load_wav import load_wav 14 | 15 | 16 | class ForcedAlignmentBinarizer: 17 | def __init__( 18 | self, 19 | data_folder, 20 | valid_set_size, 21 | valid_set_preferred_folders, 22 | data_augmentation, 23 | ignored_phonemes, 24 | melspec_config, 25 | max_length, 26 | ): 27 | self.data_folder = pathlib.Path(data_folder) 28 | self.valid_set_size = valid_set_size 29 | self.valid_set_preferred_folders = valid_set_preferred_folders 30 | self.data_augmentation = data_augmentation 31 | self.data_augmentation["key_shift_choices"] = np.array( 32 | self.data_augmentation["key_shift_choices"] 33 | ) 34 | self.ignored_phonemes = ignored_phonemes 35 | self.melspec_config = melspec_config 36 | self.scale_factor = melspec_config["scale_factor"] 37 | self.max_length = max_length 38 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 39 | 40 | self.sample_rate = self.melspec_config["sample_rate"] 41 | self.frame_length = self.melspec_config["hop_length"] / self.sample_rate 42 | 43 | self.get_melspec = MelSpecExtractor(**melspec_config, device=self.device) 44 | 45 | @staticmethod 46 | def get_vocab(data_folder_path, ignored_phonemes): 47 | print("Generating vocab...") 48 | phonemes = [] 49 | trans_path_list = data_folder_path.rglob("transcriptions.csv") 50 | 51 | for trans_path in trans_path_list: 52 | if trans_path.name == "transcriptions.csv": 53 | df = pd.read_csv(trans_path) 54 | ph = list(set(" ".join(df["ph_seq"]).split(" "))) 55 | phonemes.extend(ph) 56 | 57 | phonemes = set(phonemes) 58 | for p in ignored_phonemes: 59 | if p in phonemes: 60 | phonemes.remove(p) 61 | phonemes = sorted(phonemes) 62 | phonemes = ["SP", *phonemes] 63 | 64 | vocab = dict(zip(phonemes, range(len(phonemes)))) 65 | vocab.update(dict(zip(range(len(phonemes)), phonemes))) 66 | vocab.update({i: 0 for i in ignored_phonemes}) 67 | vocab.update({"": len(phonemes)}) 68 | 69 | print(f"vocab_size is {len(phonemes)}") 70 | 71 | return vocab 72 | 73 | def process(self): 74 | vocab = self.get_vocab(self.data_folder, self.ignored_phonemes) 75 | with open(self.data_folder / "binary" / "vocab.yaml", "w") as file: 76 | yaml.dump(vocab, file) 77 | 78 | # load metadata of each item 79 | meta_data_df = self.get_meta_data(self.data_folder, vocab) 80 | 81 | # split train and valid set 82 | valid_set_size = 
int(self.valid_set_size) 83 | meta_data_valid = ( 84 | meta_data_df[meta_data_df["label_type"] != "no_label"] 85 | .sample(frac=1) 86 | .sort_values(by="preferred", ascending=False) 87 | .iloc[:valid_set_size, :] 88 | ) 89 | meta_data_train = meta_data_df.drop(meta_data_valid.index).reset_index( 90 | drop=True 91 | ) 92 | meta_data_valid = meta_data_valid.reset_index(drop=True) 93 | 94 | # binarize valid set 95 | self.binarize( 96 | "valid", 97 | meta_data_valid, 98 | vocab, 99 | self.data_folder / "binary", 100 | False, 101 | ) 102 | 103 | # binarize train set 104 | self.binarize( 105 | "train", 106 | meta_data_train, 107 | vocab, 108 | self.data_folder / "binary", 109 | self.data_augmentation["size"] > 0, 110 | ) 111 | 112 | def binarize( 113 | self, 114 | prefix: str, 115 | meta_data: pd.DataFrame, 116 | vocab: dict, 117 | binary_data_folder: str, 118 | enable_data_augmentation: bool, 119 | ): 120 | print(f"Binarizing {prefix} set...") 121 | 122 | h5py_file_path = pathlib.Path(binary_data_folder) / (prefix + ".h5py") 123 | h5py_file = h5py.File(h5py_file_path, "w") 124 | h5py_meta_data = h5py_file.create_group("meta_data") 125 | items_meta_data = {"label_types": [], "wav_lengths": []} 126 | h5py_items = h5py_file.create_group("items") 127 | 128 | label_type_to_id = {"no_label": 0, "weak_label": 1, "full_label": 2} 129 | 130 | idx = 0 131 | total_time = 0.0 132 | for _, item in tqdm(meta_data.iterrows(), total=meta_data.shape[0]): 133 | try: 134 | # input_feature: [data_augmentation.size+1,input_dim,T] 135 | waveform = load_wav(item.wav_path, self.device, self.sample_rate) 136 | input_feature = self.get_melspec(waveform) 137 | 138 | wav_length = len(waveform) / self.sample_rate 139 | T = input_feature.shape[-1] * self.scale_factor 140 | if wav_length > self.max_length: 141 | print( 142 | f"Item {item.wav_path} has a length of {wav_length}s, which is too long, skip it." 
143 | ) 144 | continue 145 | else: 146 | h5py_item_data = h5py_items.create_group(str(idx)) 147 | items_meta_data["wav_lengths"].append(wav_length) 148 | idx += 1 149 | total_time += wav_length 150 | 151 | if enable_data_augmentation: 152 | input_features = [input_feature] 153 | key_shifts = np.random.choice( 154 | self.data_augmentation["key_shift_choices"], 155 | self.data_augmentation["size"], 156 | replace=False, 157 | ) 158 | for key_shift in key_shifts: 159 | input_features.append( 160 | self.get_melspec(waveform, key_shift=key_shift) 161 | ) 162 | 163 | input_feature = torch.stack(input_features, dim=0) 164 | else: 165 | input_feature = input_feature.unsqueeze(0) 166 | 167 | input_feature = ( 168 | input_feature - input_feature.mean(dim=[1, 2], keepdim=True) 169 | ) / input_feature.std(dim=[1, 2], keepdim=True) 170 | 171 | h5py_item_data["input_feature"] = ( 172 | input_feature.cpu().numpy().astype("float32") 173 | ) 174 | 175 | # label_type: [] 176 | label_type_id = label_type_to_id[item.label_type] 177 | if label_type_id == 2: 178 | if len(item.ph_dur) != len(item.ph_seq): 179 | label_type_id = 1 180 | if len(item.ph_seq) == 0: 181 | label_type_id = 0 182 | h5py_item_data["label_type"] = label_type_id 183 | items_meta_data["label_types"].append(label_type_id) 184 | 185 | if label_type_id == 0: 186 | # ph_seq: [S] 187 | ph_seq = np.array([]).astype("int32") 188 | 189 | # ph_edge: [scale_factor * T] 190 | ph_edge = np.zeros([T], dtype="float32") 191 | 192 | # ph_frame: [scale_factor * T] 193 | ph_frame = np.zeros(T, dtype="int32") 194 | 195 | # ph_mask: [vocab_size] 196 | ph_mask = np.ones(vocab[""], dtype="int32") 197 | elif label_type_id == 1: 198 | # ph_seq: [S] 199 | ph_seq = np.array(item.ph_seq).astype("int32") 200 | ph_seq = ph_seq[ph_seq != 0] 201 | 202 | # ph_edge: [scale_factor * T] 203 | ph_edge = np.zeros([T], dtype="float32") 204 | 205 | # ph_frame: [scale_factor * T] 206 | ph_frame = np.zeros(T, dtype="int32") 207 | 208 | # ph_mask: [vocab_size] 209 | ph_mask = np.zeros(vocab[""], dtype="int32") 210 | ph_mask[ph_seq] = 1 211 | ph_mask[0] = 1 212 | elif label_type_id == 2: 213 | # ph_seq: [S] 214 | ph_seq = np.array(item.ph_seq).astype("int32") 215 | not_sp_idx = ph_seq != 0 216 | ph_seq = ph_seq[not_sp_idx] 217 | 218 | # ph_edge: [scale_factor * T] 219 | ph_dur = np.array(item.ph_dur).astype("float32") 220 | ph_time = np.array(np.concatenate(([0], ph_dur))).cumsum() / ( 221 | self.frame_length / self.scale_factor 222 | ) 223 | ph_interval = np.stack((ph_time[:-1], ph_time[1:])) 224 | 225 | ph_interval = ph_interval[:, not_sp_idx] 226 | ph_seq = ph_seq 227 | ph_time = np.unique(ph_interval.flatten()) 228 | if ph_time[-1] >= T: 229 | ph_time = ph_time[:-1] 230 | 231 | ph_edge = np.zeros([T], dtype="float32") 232 | if len(ph_seq) > 0: 233 | if ph_time[-1] + 0.5 > T: 234 | ph_time = ph_time[:-1] 235 | if ph_time[0] - 0.5 < 0: 236 | ph_time = ph_time[1:] 237 | ph_time_int = np.round(ph_time).astype("int32") 238 | ph_time_fractional = ph_time - ph_time_int 239 | 240 | ph_edge[ph_time_int] = 0.5 + ph_time_fractional 241 | ph_edge[ph_time_int - 1] = 0.5 - ph_time_fractional 242 | ph_edge = ph_edge * 0.8 + 0.1 243 | 244 | # ph_frame: [scale_factor * T] 245 | ph_frame = np.zeros(T, dtype="int32") 246 | if len(ph_seq) > 0: 247 | for ph_id, st, ed in zip( 248 | ph_seq, ph_interval[0], ph_interval[1] 249 | ): 250 | if st < 0: 251 | st = 0 252 | if ed > T: 253 | ed = T 254 | ph_frame[int(np.round(st)) : int(np.round(ed))] = ph_id 255 | 256 | # ph_mask: [vocab_size] 257 | ph_mask = 
np.zeros(vocab[""], dtype="int32") 258 | if len(ph_seq) > 0: 259 | ph_mask[ph_seq] = 1 260 | ph_mask[0] = 1 261 | else: 262 | raise ValueError("Unknown label type.") 263 | 264 | h5py_item_data["ph_seq"] = ph_seq.astype("int32") 265 | h5py_item_data["ph_edge"] = ph_edge.astype("float32") 266 | h5py_item_data["ph_frame"] = ph_frame.astype("int32") 267 | h5py_item_data["ph_mask"] = ph_mask.astype("int32") 268 | 269 | # print( 270 | # h5py_item_data["input_feature"].shape, 271 | # np.array(h5py_item_data["label_type"]), 272 | # h5py_item_data["ph_seq"].shape, 273 | # h5py_item_data["ph_edge"].shape, 274 | # h5py_item_data["ph_frame"].shape, 275 | # h5py_item_data["ph_mask"].shape, 276 | # ) 277 | # print( 278 | # h5py_item_data["input_feature"].shape[-1] * 4, 279 | # h5py_item_data["ph_edge"].shape[0], 280 | # h5py_item_data["ph_frame"].shape[0], 281 | # ) 282 | # assert ( 283 | # h5py_item_data["input_feature"].shape[-1] * 4 284 | # == h5py_item_data["ph_edge"].shape[0] 285 | # ) 286 | # assert ( 287 | # h5py_item_data["input_feature"].shape[-1] * 4 288 | # == h5py_item_data["ph_frame"].shape[0] 289 | # ) 290 | except Exception as e: 291 | e.args += (item.wav_path,) 292 | print(e) 293 | continue 294 | for k, v in items_meta_data.items(): 295 | h5py_meta_data[k] = np.array(v) 296 | h5py_file.close() 297 | full_label_ratio = items_meta_data["label_types"].count(2) / len( 298 | items_meta_data["label_types"] 299 | ) 300 | weak_label_ratio = items_meta_data["label_types"].count(1) / len( 301 | items_meta_data["label_types"] 302 | ) 303 | no_label_ratio = items_meta_data["label_types"].count(0) / len( 304 | items_meta_data["label_types"] 305 | ) 306 | print( 307 | "Label type distribution: \n" 308 | f" full label data: {100 * full_label_ratio:.2f} %,\n" 309 | f" weak label data: {100 * weak_label_ratio:.2f} %,\n" 310 | f" no label data: {100 * no_label_ratio:.2f} %."
311 | ) 312 | print( 313 | f"Successfully binarized {prefix} set, " 314 | f"total time {total_time:.2f}s, saved to {h5py_file_path}" 315 | ) 316 | 317 | def get_meta_data(self, data_folder, vocab): 318 | path = data_folder 319 | trans_path_list = [ 320 | i 321 | for i in path.rglob("transcriptions.csv") 322 | if i.name == "transcriptions.csv" 323 | ] 324 | if len(trans_path_list) <= 0: 325 | warnings.warn(f"No transcriptions.csv found in {data_folder}.") 326 | 327 | print("Loading metadata...") 328 | meta_data_df = pd.DataFrame() 329 | for trans_path in tqdm(trans_path_list): 330 | df = pd.read_csv(trans_path, dtype=str) 331 | df["wav_path"] = df["name"].apply( 332 | lambda name: str(trans_path.parent / "wavs" / (str(name) + ".wav")), 333 | ) 334 | df["preferred"] = df["wav_path"].apply( 335 | lambda path_: ( 336 | True 337 | if any( 338 | [ 339 | i in pathlib.Path(path_).parts 340 | for i in self.valid_set_preferred_folders 341 | ] 342 | ) 343 | else False 344 | ), 345 | ) 346 | df["label_type"] = df["wav_path"].apply( 347 | lambda path_: ( 348 | "full_label" 349 | if "full_label" in path_ 350 | else "weak_label" if "weak_label" in path_ else "no_label" 351 | ), 352 | ) 353 | if len(meta_data_df) >= 1: 354 | meta_data_df = pd.concat([meta_data_df, df]) 355 | else: 356 | meta_data_df = df 357 | 358 | no_label_df = pd.DataFrame( 359 | {"wav_path": [i for i in (path / "no_label").rglob("*.wav")]} 360 | ) 361 | meta_data_df = pd.concat([meta_data_df, no_label_df]) 362 | meta_data_df["label_type"].fillna("no_label", inplace=True) 363 | 364 | meta_data_df.reset_index(drop=True, inplace=True) 365 | 366 | meta_data_df["ph_seq"] = meta_data_df["ph_seq"].apply( 367 | lambda x: ([vocab[i] for i in x.split(" ")] if isinstance(x, str) else []) 368 | ) 369 | if "ph_dur" in meta_data_df.columns: 370 | meta_data_df["ph_dur"] = meta_data_df["ph_dur"].apply( 371 | lambda x: ( 372 | [float(i) for i in x.split(" ")] if isinstance(x, str) else [] 373 | ) 374 | ) 375 | meta_data_df = meta_data_df.sort_values(by="label_type").reset_index(drop=True) 376 | 377 | return meta_data_df 378 | 379 | 380 | @click.command() 381 | @click.option( 382 | "--config_path", 383 | "-c", 384 | type=str, 385 | default="configs/binarize_config.yaml", 386 | show_default=True, 387 | help="binarize config path", 388 | ) 389 | def binarize(config_path: str): 390 | with open(config_path, "r") as f: 391 | config = yaml.safe_load(f) 392 | 393 | global_config = { 394 | "max_length": config["max_length"], 395 | "melspec_config": config["melspec_config"], 396 | "data_augmentation_size": config["data_augmentation"]["size"], 397 | } 398 | with open(pathlib.Path("data/binary/") / "global_config.yaml", "w") as file: 399 | yaml.dump(global_config, file) 400 | 401 | ForcedAlignmentBinarizer(**config).process() 402 | 403 | 404 | if __name__ == "__main__": 405 | binarize() 406 | --------------------------------------------------------------------------------
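The following is a minimal standalone sketch (not part of the repository; all values are illustrative) of the soft phoneme-edge encoding that binarize.py writes into ph_edge, and of how the halving-and-clipping step in the ONNX inference script recovers the sub-frame boundary position from the edge difference.

import numpy as np

T = 20                              # illustrative number of frames
boundary = 12.3                     # illustrative boundary position, in frames

t_int = int(np.round(boundary))     # 12
t_frac = boundary - t_int           # 0.3, always in [-0.5, 0.5]

# Each boundary contributes 0.5 +/- t_frac to the two frames around round(boundary);
# the pair sums to 1 and their difference (2 * t_frac) encodes the sub-frame offset.
ph_edge = np.zeros(T, dtype="float32")
ph_edge[t_int] = 0.5 + t_frac
ph_edge[t_int - 1] = 0.5 - t_frac

# Decoding: half the difference between the two frames gives back the fractional
# offset, which is added to the integer frame index (the inference script applies the
# same halving and clipping to the model's edge_diff output, then scales by frame_length).
edge_diff = ph_edge[t_int] - ph_edge[t_int - 1]
recovered = t_int + np.clip(edge_diff / 2, -0.5, 0.5)
print(recovered)  # -> 12.3

# binarize.py additionally squashes the stored target into [0.1, 0.9] via ph_edge * 0.8 + 0.1.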