├── LICENSE ├── README.md ├── asserts ├── libriphrase.png ├── overview.png └── wenetphrase.png ├── conformer └── conformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── embedding.py │ ├── encoder.py │ ├── feed_forward.py │ ├── model.py │ ├── model_def.py │ └── modules.py ├── data-processing └── wenetspeech │ ├── README.md │ ├── cn_tn.py │ ├── norm_txt.py │ ├── read.py │ └── wenetclip.py ├── dataloaders ├── SPC_N0_ALL.py ├── SPC_N0_TARGET.py ├── SPC_N1_ALL.py ├── SPC_N1_TARGET.py ├── __pycache__ │ ├── SPC_N0_ALL.cpython-310.pyc │ ├── SPC_N0_TARGET.cpython-310.pyc │ ├── SPC_N1_ALL.cpython-310.pyc │ ├── SPC_N1_TARGET.cpython-310.pyc │ ├── libriphrase_test.cpython-310.pyc │ ├── libriphrase_test_18.cpython-310.pyc │ ├── libriphrase_train.cpython-310.pyc │ ├── libriphrase_trainY.cpython-310.pyc │ └── libriphrase_train_18.cpython-310.pyc ├── g2p │ ├── LICENSE.txt │ ├── g2p_en │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── expand.cpython-310.pyc │ │ │ ├── expand.cpython-39.pyc │ │ │ ├── g2p.cpython-310.pyc │ │ │ └── g2p.cpython-39.pyc │ │ ├── checkpoint20.npz │ │ ├── expand.py │ │ ├── g2p.py │ │ └── homographs.en │ └── lightning_logs │ │ ├── version_0 │ │ ├── events.out.tfevents.1703650022.great-server24.1715931.0 │ │ └── hparams.yaml │ │ ├── version_1 │ │ ├── events.out.tfevents.1703650063.great-server24.1718627.0 │ │ └── hparams.yaml │ │ ├── version_2 │ │ ├── events.out.tfevents.1703650148.great-server24.1718627.1 │ │ └── hparams.yaml │ │ ├── version_3 │ │ ├── events.out.tfevents.1703651214.great-server24.1769388.0 │ │ └── hparams.yaml │ │ ├── version_4 │ │ ├── events.out.tfevents.1703651335.great-server24.1769388.1 │ │ └── hparams.yaml │ │ ├── version_5 │ │ ├── events.out.tfevents.1703651958.great-server24.1794497.0 │ │ └── hparams.yaml │ │ └── version_6 │ │ ├── events.out.tfevents.1703652071.great-server24.1794497.1 │ │ └── hparams.yaml ├── libriphrase_test.py ├── libriphrase_test_18.py ├── libriphrase_train.py ├── libriphrase_trainY.py ├── libriphrase_trainY2.py ├── libriphrase_train_18.py ├── wenetphrase_test.py └── wenetphrase_train.py ├── g2p ├── LICENSE.txt └── g2p_en │ ├── __init__.py │ ├── checkpoint20.npz │ ├── expand.py │ ├── g2p.py │ └── homographs.en ├── libriphrase_hardneg.json.zip ├── mdtc.py ├── mm-kws ├── LICENSE ├── README.md ├── conformer │ └── conformer │ │ ├── __init__.py │ │ ├── activation.py │ │ ├── attention.py │ │ ├── convolution.py │ │ ├── embedding.py │ │ ├── encoder.py │ │ ├── feed_forward.py │ │ ├── model.py │ │ ├── model_def.py │ │ └── modules.py ├── dataloaders │ ├── SPC_N0_ALL.py │ ├── SPC_N0_TARGET.py │ ├── SPC_N1_ALL.py │ ├── SPC_N1_TARGET.py │ ├── __pycache__ │ │ ├── SPC_N0_ALL.cpython-310.pyc │ │ ├── SPC_N0_TARGET.cpython-310.pyc │ │ ├── SPC_N1_ALL.cpython-310.pyc │ │ ├── SPC_N1_TARGET.cpython-310.pyc │ │ ├── libriphrase_test.cpython-310.pyc │ │ ├── libriphrase_test_18.cpython-310.pyc │ │ ├── libriphrase_train.cpython-310.pyc │ │ ├── libriphrase_trainY.cpython-310.pyc │ │ └── libriphrase_train_18.cpython-310.pyc │ ├── g2p │ │ ├── LICENSE.txt │ │ └── g2p_en │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── expand.cpython-310.pyc │ │ │ ├── expand.cpython-39.pyc │ │ │ ├── g2p.cpython-310.pyc │ │ │ └── g2p.cpython-39.pyc │ │ │ ├── checkpoint20.npz │ │ │ ├── expand.py │ │ │ ├── g2p.py │ │ │ └── homographs.en │ ├── libriphrase_test.py │ ├── libriphrase_test_18.py │ ├── 
libriphrase_train.py │ ├── libriphrase_trainY.py │ ├── libriphrase_trainY2.py │ ├── libriphrase_train_18.py │ ├── wenetphrase_test.py │ └── wenetphrase_train.py ├── g2p │ ├── LICENSE.txt │ └── g2p_en │ │ ├── __init__.py │ │ ├── checkpoint20.npz │ │ ├── expand.py │ │ ├── g2p.py │ │ └── homographs.en ├── libriphrase_hardneg.json.zip ├── mdtc.py ├── models.py ├── models_tiny.py ├── test.py ├── train.py └── wenetphrase_hardneg.json.zip ├── models.py ├── models_tiny.py ├── test.py ├── train.py └── wenetphrase_hardneg.json.zip /asserts/libriphrase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/libriphrase.png -------------------------------------------------------------------------------- /asserts/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/overview.png -------------------------------------------------------------------------------- /asserts/wenetphrase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/wenetphrase.png -------------------------------------------------------------------------------- /conformer/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .model import Conformer 16 | -------------------------------------------------------------------------------- /conformer/conformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | 18 | 19 | class Swish(nn.Module): 20 | """ 21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied 22 | to a variety of challenging domains such as Image classification and Machine translation. 
23 | """ 24 | def __init__(self): 25 | super(Swish, self).__init__() 26 | 27 | def forward(self, inputs: Tensor) -> Tensor: 28 | return inputs * inputs.sigmoid() 29 | 30 | 31 | class GLU(nn.Module): 32 | """ 33 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing 34 | in the paper “Language Modeling with Gated Convolutional Networks” 35 | """ 36 | def __init__(self, dim: int) -> None: 37 | super(GLU, self).__init__() 38 | self.dim = dim 39 | 40 | def forward(self, inputs: Tensor) -> Tensor: 41 | outputs, gate = inputs.chunk(2, dim=self.dim) 42 | return outputs * gate.sigmoid() 43 | -------------------------------------------------------------------------------- /conformer/conformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch import Tensor 20 | from typing import Optional 21 | 22 | from .embedding import PositionalEncoding 23 | from .modules import Linear 24 | 25 | 26 | class RelativeMultiHeadAttention(nn.Module): 27 | """ 28 | Multi-head attention with relative positional encoding. 29 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" 30 | 31 | Args: 32 | d_model (int): The dimension of model 33 | num_heads (int): The number of attention heads. 34 | dropout_p (float): probability of dropout 35 | 36 | Inputs: query, key, value, pos_embedding, mask 37 | - **query** (batch, time, dim): Tensor containing query vector 38 | - **key** (batch, time, dim): Tensor containing key vector 39 | - **value** (batch, time, dim): Tensor containing value vector 40 | - **pos_embedding** (batch, time, dim): Positional embedding tensor 41 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 42 | 43 | Returns: 44 | - **outputs**: Tensor produces by relative multi head attention module. 45 | """ 46 | def __init__( 47 | self, 48 | d_model: int = 512, 49 | num_heads: int = 16, 50 | dropout_p: float = 0.1, 51 | ): 52 | super(RelativeMultiHeadAttention, self).__init__() 53 | assert d_model % num_heads == 0, "d_model % num_heads should be zero." 
54 | self.d_model = d_model 55 | self.d_head = int(d_model / num_heads) 56 | self.num_heads = num_heads 57 | self.sqrt_dim = math.sqrt(d_model) 58 | 59 | self.query_proj = Linear(d_model, d_model) 60 | self.key_proj = Linear(d_model, d_model) 61 | self.value_proj = Linear(d_model, d_model) 62 | self.pos_proj = Linear(d_model, d_model, bias=False) 63 | 64 | self.dropout = nn.Dropout(p=dropout_p) 65 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 66 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 67 | torch.nn.init.xavier_uniform_(self.u_bias) 68 | torch.nn.init.xavier_uniform_(self.v_bias) 69 | 70 | self.out_proj = Linear(d_model, d_model) 71 | 72 | def forward( 73 | self, 74 | query: Tensor, 75 | key: Tensor, 76 | value: Tensor, 77 | pos_embedding: Tensor, 78 | mask: Optional[Tensor] = None, 79 | ) -> Tensor: 80 | batch_size = value.size(0) 81 | 82 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) 83 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 84 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 85 | pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) 86 | 87 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) 88 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)) 89 | pos_score = self._relative_shift(pos_score) 90 | 91 | score = (content_score + pos_score) / self.sqrt_dim 92 | 93 | if mask is not None: 94 | mask = mask.unsqueeze(1) 95 | score.masked_fill_(mask, -1e9) 96 | 97 | attn = F.softmax(score, -1) 98 | attn = self.dropout(attn) 99 | 100 | context = torch.matmul(attn, value).transpose(1, 2) 101 | context = context.contiguous().view(batch_size, -1, self.d_model) 102 | 103 | return self.out_proj(context) 104 | 105 | def _relative_shift(self, pos_score: Tensor) -> Tensor: 106 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() 107 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1) 108 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1) 109 | 110 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) 111 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) 112 | 113 | return pos_score 114 | 115 | 116 | class MultiHeadedSelfAttentionModule(nn.Module): 117 | """ 118 | Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, 119 | the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention 120 | module to generalize better on different input length and the resulting encoder is more robust to the variance of 121 | the utterance length. Conformer use prenorm residual units with dropout which helps training 122 | and regularizing deeper models. 123 | 124 | Args: 125 | d_model (int): The dimension of model 126 | num_heads (int): The number of attention heads. 127 | dropout_p (float): probability of dropout 128 | 129 | Inputs: inputs, mask 130 | - **inputs** (batch, time, dim): Tensor containing input vector 131 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 132 | 133 | Returns: 134 | - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. 
135 | """ 136 | def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1): 137 | super(MultiHeadedSelfAttentionModule, self).__init__() 138 | self.positional_encoding = PositionalEncoding(d_model) 139 | self.layer_norm = nn.LayerNorm(d_model) 140 | self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p) 141 | self.dropout = nn.Dropout(p=dropout_p) 142 | 143 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None): 144 | batch_size, seq_length, _ = inputs.size() 145 | pos_embedding = self.positional_encoding(seq_length) 146 | pos_embedding = pos_embedding.repeat(batch_size, 1, 1) 147 | 148 | inputs = self.layer_norm(inputs) 149 | outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask) 150 | 151 | return self.dropout(outputs) 152 | -------------------------------------------------------------------------------- /conformer/conformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .activation import Swish, GLU 21 | from .modules import Transpose 22 | 23 | 24 | class DepthwiseConv1d(nn.Module): 25 | """ 26 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, 27 | this operation is termed in literature as depthwise convolution. 28 | 29 | Args: 30 | in_channels (int): Number of channels in the input 31 | out_channels (int): Number of channels produced by the convolution 32 | kernel_size (int or tuple): Size of the convolving kernel 33 | stride (int, optional): Stride of the convolution. Default: 1 34 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 35 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 36 | 37 | Inputs: inputs 38 | - **inputs** (batch, in_channels, time): Tensor containing input vector 39 | 40 | Returns: outputs 41 | - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. 
42 | """ 43 | def __init__( 44 | self, 45 | in_channels: int, 46 | out_channels: int, 47 | kernel_size: int, 48 | stride: int = 1, 49 | padding: int = 0, 50 | bias: bool = False, 51 | ) -> None: 52 | super(DepthwiseConv1d, self).__init__() 53 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels" 54 | self.conv = nn.Conv1d( 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | kernel_size=kernel_size, 58 | groups=in_channels, 59 | stride=stride, 60 | padding=padding, 61 | bias=bias, 62 | ) 63 | 64 | def forward(self, inputs: Tensor) -> Tensor: 65 | return self.conv(inputs) 66 | 67 | 68 | class PointwiseConv1d(nn.Module): 69 | """ 70 | When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution. 71 | This operation often used to match dimensions. 72 | 73 | Args: 74 | in_channels (int): Number of channels in the input 75 | out_channels (int): Number of channels produced by the convolution 76 | stride (int, optional): Stride of the convolution. Default: 1 77 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 78 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 79 | 80 | Inputs: inputs 81 | - **inputs** (batch, in_channels, time): Tensor containing input vector 82 | 83 | Returns: outputs 84 | - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution. 85 | """ 86 | def __init__( 87 | self, 88 | in_channels: int, 89 | out_channels: int, 90 | stride: int = 1, 91 | padding: int = 0, 92 | bias: bool = True, 93 | ) -> None: 94 | super(PointwiseConv1d, self).__init__() 95 | self.conv = nn.Conv1d( 96 | in_channels=in_channels, 97 | out_channels=out_channels, 98 | kernel_size=1, 99 | stride=stride, 100 | padding=padding, 101 | bias=bias, 102 | ) 103 | 104 | def forward(self, inputs: Tensor) -> Tensor: 105 | return self.conv(inputs) 106 | 107 | 108 | class ConformerConvModule(nn.Module): 109 | """ 110 | Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU). 111 | This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution 112 | to aid training deep models. 113 | 114 | Args: 115 | in_channels (int): Number of channels in the input 116 | kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31 117 | dropout_p (float, optional): probability of dropout 118 | 119 | Inputs: inputs 120 | inputs (batch, time, dim): Tensor contains input sequences 121 | 122 | Outputs: outputs 123 | outputs (batch, time, dim): Tensor produces by conformer convolution module. 
124 | """ 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | kernel_size: int = 31, 129 | expansion_factor: int = 2, 130 | dropout_p: float = 0.1, 131 | ) -> None: 132 | super(ConformerConvModule, self).__init__() 133 | assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" 134 | assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2" 135 | 136 | self.sequential = nn.Sequential( 137 | nn.LayerNorm(in_channels), 138 | Transpose(shape=(1, 2)), 139 | PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True), 140 | GLU(dim=1), 141 | DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2), 142 | nn.BatchNorm1d(in_channels), 143 | Swish(), 144 | PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True), 145 | nn.Dropout(p=dropout_p), 146 | ) 147 | 148 | def forward(self, inputs: Tensor) -> Tensor: 149 | return self.sequential(inputs).transpose(1, 2) 150 | 151 | 152 | class Conv2dSubampling(nn.Module): 153 | """ 154 | Convolutional 2D subsampling (to 1/4 length) 155 | 156 | Args: 157 | in_channels (int): Number of channels in the input image 158 | out_channels (int): Number of channels produced by the convolution 159 | 160 | Inputs: inputs 161 | - **inputs** (batch, time, dim): Tensor containing sequence of inputs 162 | 163 | Returns: outputs, output_lengths 164 | - **outputs** (batch, time, dim): Tensor produced by the convolution 165 | - **output_lengths** (batch): list of sequence output lengths 166 | """ 167 | def __init__(self, in_channels: int, out_channels: int) -> None: 168 | super(Conv2dSubampling, self).__init__() 169 | self.sequential = nn.Sequential( 170 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2), 171 | nn.ReLU(), 172 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2), 173 | nn.ReLU(), 174 | ) 175 | 176 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 177 | outputs = self.sequential(inputs.unsqueeze(1)) 178 | batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size() 179 | 180 | outputs = outputs.permute(0, 2, 1, 3) 181 | outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim) 182 | 183 | output_lengths = input_lengths >> 2 184 | output_lengths -= 1 185 | 186 | return outputs, output_lengths 187 | -------------------------------------------------------------------------------- /conformer/conformer/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | from torch import Tensor 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | """ 23 | Positional Encoding proposed in "Attention Is All You Need". 
24 | Since transformer contains no recurrence and no convolution, in order for the model to make 25 | use of the order of the sequence, we must add some positional information. 26 | 27 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 30 | """ 31 | def __init__(self, d_model: int = 512, max_len: int = 10000) -> None: 32 | super(PositionalEncoding, self).__init__() 33 | pe = torch.zeros(max_len, d_model, requires_grad=False) 34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 36 | pe[:, 0::2] = torch.sin(position * div_term) 37 | pe[:, 1::2] = torch.cos(position * div_term) 38 | pe = pe.unsqueeze(0) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, length: int) -> Tensor: 42 | return self.pe[:, :length] -------------------------------------------------------------------------------- /conformer/conformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | 19 | from .activation import Swish 20 | from .modules import Linear 21 | 22 | 23 | class FeedForwardModule(nn.Module): 24 | """ 25 | Conformer Feed Forward Module follow pre-norm residual units and apply layer normalization within the residual unit 26 | and on the input before the first linear layer. This module also apply Swish activation and dropout, which helps 27 | regularizing the network. 28 | 29 | Args: 30 | encoder_dim (int): Dimension of conformer encoder 31 | expansion_factor (int): Expansion factor of feed forward module. 32 | dropout_p (float): Ratio of dropout 33 | 34 | Inputs: inputs 35 | - **inputs** (batch, time, dim): Tensor contains input sequences 36 | 37 | Outputs: outputs 38 | - **outputs** (batch, time, dim): Tensor produces by feed forward module. 39 | """ 40 | def __init__( 41 | self, 42 | encoder_dim: int = 512, 43 | expansion_factor: int = 4, 44 | dropout_p: float = 0.1, 45 | ) -> None: 46 | super(FeedForwardModule, self).__init__() 47 | self.sequential = nn.Sequential( 48 | nn.LayerNorm(encoder_dim), 49 | Linear(encoder_dim, encoder_dim * expansion_factor, bias=True), 50 | Swish(), 51 | nn.Dropout(p=dropout_p), 52 | Linear(encoder_dim * expansion_factor, encoder_dim, bias=True), 53 | nn.Dropout(p=dropout_p), 54 | ) 55 | 56 | def forward(self, inputs: Tensor) -> Tensor: 57 | return self.sequential(inputs) 58 | -------------------------------------------------------------------------------- /conformer/conformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .encoder import ConformerEncoder 21 | from .modules import Linear 22 | 23 | 24 | class Conformer(nn.Module): 25 | """ 26 | Conformer: Convolution-augmented Transformer for Speech Recognition 27 | The paper used a one-lstm Transducer decoder, currently still only implemented 28 | the conformer encoder shown in the paper. 29 | 30 | Args: 31 | num_classes (int): Number of classification classes 32 | input_dim (int, optional): Dimension of input vector 33 | encoder_dim (int, optional): Dimension of conformer encoder 34 | num_encoder_layers (int, optional): Number of conformer blocks 35 | num_attention_heads (int, optional): Number of attention heads 36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 39 | attention_dropout_p (float, optional): Probability of attention module dropout 40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 42 | half_step_residual (bool): Flag indication whether to use half step residual or not 43 | 44 | Inputs: inputs, input_lengths 45 | - **inputs** (batch, time, dim): Tensor containing input vector 46 | - **input_lengths** (batch): list of sequence input lengths 47 | 48 | Returns: outputs, output_lengths 49 | - **outputs** (batch, out_channels, time): Tensor produces by conformer. 
50 | - **output_lengths** (batch): list of sequence output lengths 51 | """ 52 | def __init__( 53 | self, 54 | num_classes: int, 55 | input_dim: int = 80, 56 | encoder_dim: int = 512, 57 | num_encoder_layers: int = 17, 58 | num_attention_heads: int = 8, 59 | feed_forward_expansion_factor: int = 4, 60 | conv_expansion_factor: int = 2, 61 | input_dropout_p: float = 0.1, 62 | feed_forward_dropout_p: float = 0.1, 63 | attention_dropout_p: float = 0.1, 64 | conv_dropout_p: float = 0.1, 65 | conv_kernel_size: int = 31, 66 | half_step_residual: bool = True, 67 | ) -> None: 68 | super(Conformer, self).__init__() 69 | self.encoder = ConformerEncoder( 70 | input_dim=input_dim, 71 | encoder_dim=encoder_dim, 72 | num_layers=num_encoder_layers, 73 | num_attention_heads=num_attention_heads, 74 | feed_forward_expansion_factor=feed_forward_expansion_factor, 75 | conv_expansion_factor=conv_expansion_factor, 76 | input_dropout_p=input_dropout_p, 77 | feed_forward_dropout_p=feed_forward_dropout_p, 78 | attention_dropout_p=attention_dropout_p, 79 | conv_dropout_p=conv_dropout_p, 80 | conv_kernel_size=conv_kernel_size, 81 | half_step_residual=half_step_residual, 82 | ) 83 | self.fc = Linear(encoder_dim, num_classes, bias=False) 84 | 85 | def count_parameters(self) -> int: 86 | """ Count parameters of encoder """ 87 | return self.encoder.count_parameters() 88 | 89 | def update_dropout(self, dropout_p) -> None: 90 | """ Update dropout probability of model """ 91 | self.encoder.update_dropout(dropout_p) 92 | 93 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 94 | """ 95 | Forward propagate a `inputs` and `targets` pair for training. 96 | 97 | Args: 98 | inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 99 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 100 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)`` 101 | 102 | Returns: 103 | * predictions (torch.FloatTensor): Result of model predictions. 104 | """ 105 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 106 | outputs = self.fc(encoder_outputs) 107 | outputs = nn.functional.log_softmax(outputs, dim=-1) 108 | return outputs, encoder_output_lengths 109 | -------------------------------------------------------------------------------- /conformer/conformer/model_def.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .encoder import ConformerEncoder 21 | from .modules import Linear 22 | 23 | 24 | class Conformer(nn.Module): 25 | """ 26 | Conformer: Convolution-augmented Transformer for Speech Recognition 27 | The paper used a one-lstm Transducer decoder, currently still only implemented 28 | the conformer encoder shown in the paper. 29 | 30 | Args: 31 | num_classes (int): Number of classification classes 32 | input_dim (int, optional): Dimension of input vector 33 | encoder_dim (int, optional): Dimension of conformer encoder 34 | num_encoder_layers (int, optional): Number of conformer blocks 35 | num_attention_heads (int, optional): Number of attention heads 36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 39 | attention_dropout_p (float, optional): Probability of attention module dropout 40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 42 | half_step_residual (bool): Flag indication whether to use half step residual or not 43 | 44 | Inputs: inputs, input_lengths 45 | - **inputs** (batch, time, dim): Tensor containing input vector 46 | - **input_lengths** (batch): list of sequence input lengths 47 | 48 | Returns: outputs, output_lengths 49 | - **outputs** (batch, out_channels, time): Tensor produces by conformer. 50 | - **output_lengths** (batch): list of sequence output lengths 51 | """ 52 | def __init__( 53 | self, 54 | input_dim: int = 80, 55 | encoder_dim: int = 512, 56 | num_encoder_layers: int = 17, 57 | num_attention_heads: int = 8, 58 | feed_forward_expansion_factor: int = 4, 59 | conv_expansion_factor: int = 2, 60 | input_dropout_p: float = 0.1, 61 | feed_forward_dropout_p: float = 0.1, 62 | attention_dropout_p: float = 0.1, 63 | conv_dropout_p: float = 0.1, 64 | conv_kernel_size: int = 31, 65 | half_step_residual: bool = True, 66 | ) -> None: 67 | super(Conformer, self).__init__() 68 | self.encoder = ConformerEncoder( 69 | input_dim=input_dim, 70 | encoder_dim=encoder_dim, 71 | num_layers=num_encoder_layers, 72 | num_attention_heads=num_attention_heads, 73 | feed_forward_expansion_factor=feed_forward_expansion_factor, 74 | conv_expansion_factor=conv_expansion_factor, 75 | input_dropout_p=input_dropout_p, 76 | feed_forward_dropout_p=feed_forward_dropout_p, 77 | attention_dropout_p=attention_dropout_p, 78 | conv_dropout_p=conv_dropout_p, 79 | conv_kernel_size=conv_kernel_size, 80 | half_step_residual=half_step_residual, 81 | ) 82 | 83 | def count_parameters(self) -> int: 84 | """ Count parameters of encoder """ 85 | return self.encoder.count_parameters() 86 | 87 | def update_dropout(self, dropout_p) -> None: 88 | """ Update dropout probability of model """ 89 | self.encoder.update_dropout(dropout_p) 90 | 91 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 92 | """ 93 | Forward propagate a `inputs` and `targets` pair for training. 94 | 95 | Args: 96 | inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 97 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 98 | input_lengths (torch.LongTensor): The length of input tensor. 
``(batch)`` 99 | 100 | Returns: 101 | * predictions (torch.FloatTensor): Result of model predictions. 102 | """ 103 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 104 | return encoder_outputs, encoder_output_lengths 105 | -------------------------------------------------------------------------------- /conformer/conformer/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | from torch import Tensor 19 | 20 | 21 | class ResidualConnectionModule(nn.Module): 22 | """ 23 | Residual Connection Module. 24 | outputs = (module(inputs) x module_factor + inputs x input_factor) 25 | """ 26 | def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0): 27 | super(ResidualConnectionModule, self).__init__() 28 | self.module = module 29 | self.module_factor = module_factor 30 | self.input_factor = input_factor 31 | 32 | def forward(self, inputs: Tensor) -> Tensor: 33 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) 34 | 35 | 36 | class Linear(nn.Module): 37 | """ 38 | Wrapper class of torch.nn.Linear 39 | Weight initialize by xavier initialization and bias initialize to zeros. 40 | """ 41 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 42 | super(Linear, self).__init__() 43 | self.linear = nn.Linear(in_features, out_features, bias=bias) 44 | init.xavier_uniform_(self.linear.weight) 45 | if bias: 46 | init.zeros_(self.linear.bias) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | return self.linear(x) 50 | 51 | 52 | class View(nn.Module): 53 | """ Wrapper class of torch.view() for Sequential module. """ 54 | def __init__(self, shape: tuple, contiguous: bool = False): 55 | super(View, self).__init__() 56 | self.shape = shape 57 | self.contiguous = contiguous 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | if self.contiguous: 61 | x = x.contiguous() 62 | 63 | return x.view(*self.shape) 64 | 65 | 66 | class Transpose(nn.Module): 67 | """ Wrapper class of torch.transpose() for Sequential module. 
""" 68 | def __init__(self, shape: tuple): 69 | super(Transpose, self).__init__() 70 | self.shape = shape 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | return x.transpose(*self.shape) 74 | -------------------------------------------------------------------------------- /data-processing/wenetspeech/README.md: -------------------------------------------------------------------------------- 1 | /Path/to/your/wenetspeech -------------------------------------------------------------------------------- /data-processing/wenetspeech/norm_txt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool 3 | from tqdm import tqdm 4 | import json 5 | import os 6 | 7 | def find_files_with_suffix(root_folder, endwiths): 8 | matching_files = [] # 以后缀名为键,文件路径列表为值初始化字典 9 | for root, dirs, files in os.walk(root_folder): 10 | for file in files: 11 | for suffix in endwiths: 12 | if file.endswith(suffix): 13 | matching_files.append(os.path.join(root, file)) 14 | return matching_files 15 | 16 | import os 17 | import multiprocessing 18 | from tqdm import tqdm 19 | import subprocess 20 | def process_file(file): 21 | output_file = file.replace('.txt', '_norm.txt') 22 | subprocess.run(['python', 'data-processing/wenetspeech/cn_tn.py', file, output_file]) 23 | 24 | root_folder = '/path/to/wenetspeech_clips/' 25 | endwiths = ['.txt'] 26 | files = find_files_with_suffix(root_folder, endwiths) 27 | with multiprocessing.Pool(32) as pool: 28 | # 使用 tqdm 并行迭代文件列表,并显示进度条 29 | for _ in tqdm(pool.imap_unordered(process_file, files), total=len(files)): 30 | pass -------------------------------------------------------------------------------- /data-processing/wenetspeech/read.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool 3 | from tqdm import tqdm 4 | import torchaudio 5 | import json 6 | 7 | def process_audio(aidx): 8 | path = os.path.join(data_dir, wenetspeech['audios'][aidx]['path'][:-5] + '.wav') 9 | waveform, sr = torchaudio.load(path) 10 | 11 | for seg in range(len(wenetspeech['audios'][aidx]['segments'])): 12 | sid = wenetspeech['audios'][aidx]['segments'][seg]['sid'] 13 | begin_time = wenetspeech['audios'][aidx]['segments'][seg]['begin_time'] 14 | end_time = wenetspeech['audios'][aidx]['segments'][seg]['end_time'] 15 | subsets = wenetspeech['audios'][aidx]['segments'][seg]['subsets'] 16 | text = wenetspeech['audios'][aidx]['segments'][seg]['text'] 17 | 18 | if 'M' in subsets and 'S' not in subsets: 19 | save_path = os.path.join("/server24/aizq/wenetspeech_clips/M_S", path.split('/')[-3], path.split('/')[-2], sid) + '.wav' 20 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 21 | 22 | if not os.path.exists(save_path): 23 | torchaudio.save(save_path, waveform[:, int(begin_time * sr):int(end_time * sr)], sample_rate=sr) 24 | 25 | save_path_label = os.path.join("/server24/aizq/wenetspeech_clips/M_S", path.split('/')[-3], path.split('/')[-2], sid) + '.txt' 26 | 27 | if not os.path.exists(save_path_label): 28 | with open(save_path_label, "w") as file: 29 | file.write(text) 30 | 31 | if 'S' in subsets: 32 | save_path = os.path.join("/server24/aizq/wenetspeech_clips/S", path.split('/')[-3], path.split('/')[-2], sid) + '.wav' 33 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 34 | 35 | if not os.path.exists(save_path): 36 | torchaudio.save(save_path, waveform[:, int(begin_time * sr):int(end_time * sr)], sample_rate=sr) 37 | 38 | 
save_path_label = os.path.join("/server24/aizq/wenetspeech_clips/S", path.split('/')[-3], path.split('/')[-2], sid) + '.txt' 39 | 40 | if not os.path.exists(save_path_label): 41 | with open(save_path_label, "w") as file: 42 | file.write(text) 43 | 44 | 45 | if __name__ == '__main__': 46 | print("读取 WenetSpeech.json") 47 | with open('/server24/aizq/wenetspeech_UNTAR/WenetSpeech.json', 'r') as f: wenetspeech = json.load(f) 48 | print("读取完成正在处理") 49 | data_dir = "/server24/aizq/wenetspeech_UNTAR" 50 | num_processes = 32 51 | with Pool(num_processes) as pool: 52 | for _ in tqdm(pool.imap_unordered(process_audio, range(len(wenetspeech['audios']))), total=len(wenetspeech['audios'])): 53 | pass 54 | -------------------------------------------------------------------------------- /data-processing/wenetspeech/wenetclip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import multiprocessing 3 | from tqdm import tqdm 4 | 5 | 6 | def find_files_with_suffix(root_folder, endwiths): 7 | matching_files = [] # 以后缀名为键,文件路径列表为值初始化字典 8 | for root, dirs, files in os.walk(root_folder): 9 | for file in files: 10 | for suffix in endwiths: 11 | if file.endswith(suffix): 12 | matching_files.append(os.path.join(root, file)) 13 | return matching_files 14 | 15 | 16 | def read_text_file(file_path): 17 | with open(file_path, 'r') as file: 18 | content = file.read().rstrip('\n') 19 | return content 20 | 21 | 22 | # 创建一个包含汉字数量大于等于2的集合 23 | def select_hanzi_greater_than_or_equal_to_2(input_set): 24 | result_set = set() 25 | for item in input_set: 26 | # 判断是否只包含汉字且汉字数量大于等于2 27 | if all('\u4e00' <= char <= '\u9fff' for char in item) and len(item) >= 2 and len(item) <= 6: 28 | result_set.add(item) 29 | return result_set 30 | 31 | 32 | import re 33 | 34 | def find_all_occurrences(text, pattern): 35 | occurrences = [(match.start(), match.end()) for match in re.finditer(pattern, text)] 36 | return occurrences 37 | 38 | 39 | import torch 40 | import torchaudio 41 | from torchaudio.pipelines import MMS_FA as bundle 42 | from typing import List 43 | from g2pM import G2pM # g2pM # g2pW 论文结果更好,但是是台湾人高的,普通话还需要额外做好多操作emm 44 | import jieba 45 | import numpy as np 46 | 47 | def _score(spans): 48 | return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans) 49 | 50 | 51 | 52 | class TORCHAUDIO_MFA: 53 | def __init__(self, device_id='0', save_dirs="/server24/aizq/mm_kws/datasets/WenetPhrase/M_S"): 54 | self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu") 55 | self.mfa = bundle.get_model() 56 | self.mfa.to(self.device) 57 | self.tokenizer = bundle.get_tokenizer() 58 | self.aligner = bundle.get_aligner() 59 | self.g2pm = G2pM() 60 | self.save_dirs = save_dirs 61 | 62 | 63 | def compute_alignments(self, waveform: torch.Tensor, transcript: List[str]): 64 | with torch.inference_mode(): 65 | emission, _ = self.mfa(waveform.to(self.device)) 66 | token_spans = self.aligner(emission[0], self.tokenizer(transcript)) 67 | return emission, token_spans 68 | 69 | 70 | def make_mfa(self, wav_file: str, sentence: str): 71 | try: 72 | seg_list = select_hanzi_greater_than_or_equal_to_2(list(jieba.cut(sentence, cut_all=True))) 73 | transcript = self.g2pm(sentence, tone=False) 74 | transcript = " ".join(transcript) 75 | # 使用列表推导式替换所有的 "nu" 76 | transcript = transcript.replace('u:', 'v') 77 | transcript = transcript.split() 78 | waveform, sample_rate = torchaudio.load(wav_file) 79 | waveform = waveform[0:1] 80 | emission, token_spans = 
self.compute_alignments(waveform, transcript) 81 | num_frames = emission.size(1) 82 | ratio = waveform.size(1) / num_frames 83 | for word in seg_list: 84 | result = find_all_occurrences(sentence, word) 85 | for i, (s, e) in enumerate(result): 86 | start = s 87 | end = e - 1 88 | x0 = int(ratio * token_spans[start][0].start) 89 | x1 = int(ratio * token_spans[end][-1].end) 90 | score = np.mean([_score(token_spans[i]) for i in range(start, end+1)]) 91 | if score > 0.3: # 设置了平均阈值 92 | save_path = os.path.join(self.save_dirs, word) + "/" + "_".join([wav_file.split('/')[-3], wav_file.split('/')[-2], wav_file.split('/')[-1][:-4], str(i)]) + '.wav' 93 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 94 | torchaudio.save(save_path, waveform[:, x0:x1], sample_rate=sample_rate) 95 | except Exception as e: 96 | print(wav_file) 97 | print(e) 98 | 99 | 100 | 101 | def data_process( 102 | sub_files, 103 | device_id 104 | ): 105 | try: 106 | os.environ['CUDA_VISIBLE_DEVICES'] = f'{device_id}' 107 | model = TORCHAUDIO_MFA(device_id='0') # 使用不同的设备ID 108 | for wav_file in tqdm(sub_files): 109 | txt_file = wav_file.replace('.wav', '_norm.txt') 110 | sentence = read_text_file(txt_file) 111 | model.make_mfa(wav_file, sentence) 112 | del model 113 | except Exception as e: 114 | print(e) 115 | 116 | 117 | 118 | def multiprocess_to_( 119 | todo_list, 120 | num_gpu=[0], # Default to GPU 0 if num_gpu is not provided 121 | num_process=3, 122 | ): 123 | num_available_gpus = len(num_gpu) 124 | with multiprocessing.Pool(processes=num_process) as pool: 125 | for i in range(num_process): 126 | sub_files = todo_list[i::num_process] 127 | device_id = num_gpu[i % num_available_gpus] 128 | pool.apply_async(data_process, args=(sub_files, device_id)) 129 | pool.close() 130 | pool.join() 131 | 132 | 133 | if __name__ == "__main__": 134 | target_dirs = "/server24/aizq/wenetspeech_clips/M_S" # S: 151600 & M_S: 1362900 135 | save_dirs = "/server24/aizq/mm_kws/datasets/WenetPhrase/M_S" 136 | todo_list = find_files_with_suffix(target_dirs, '.wav') 137 | print(len(todo_list)) 138 | num_gpu=[0, 1, 2, 3, 4, 5] 139 | multiprocess_to_( 140 | todo_list, 141 | num_gpu=num_gpu, 142 | num_process=len(num_gpu) * 5, 143 | ) -------------------------------------------------------------------------------- /dataloaders/SPC_N0_ALL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import Levenshtein 15 | import re 16 | import random 17 | 18 | import os 19 | 20 | def get_files(path, endswith): 21 | _files = [] 22 | for root, dirs, files in os.walk(path): 23 | for file in files: 24 | if file.endswith(endswith): 25 | _files.append(os.path.join(root, file)) 26 | return _files 27 | 28 | 29 | class SPC_NO_Dataset(Dataset): 30 | def __init__( 31 | self, 32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_ALL.csv", 35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 36 | preprocess=True, 37 | ): 38 | super().__init__() 39 | if preprocess: 40 | 
target_dict = {} 41 | idx = 0 42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label']) 43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 44 | classes = os.listdir(test_dir) 45 | random.shuffle(classes) 46 | Targets = classes[:10] 47 | for wav_idx in range(len(wav_id)): 48 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 49 | query_text = wav_id[wav_idx].split('_')[0] 50 | if query_text in Targets: 51 | for comparison_text in Targets: 52 | _label = 1 if comparison_text == query_text else 0 53 | target_dict[idx] = { 54 | 'id': wav_id[wav_idx], 55 | 'Query_text': query_text, 56 | 'Query_wav': wav, 57 | 'Support_text': comparison_text, 58 | 'Query_label': Targets.index(query_text), 59 | 'Support_label': Targets.index(comparison_text), 60 | 'label': _label 61 | } 62 | idx += 1 63 | else: 64 | for comparison_text in Targets: 65 | _label = 0 66 | target_dict[idx] = { 67 | 'id': wav_id[wav_idx], 68 | 'Query_text': query_text, 69 | 'Query_wav': wav, 70 | 'Support_text': comparison_text, 71 | 'Query_label': 10, 72 | 'Support_label': Targets.index(comparison_text), 73 | 'label': 10, 74 | } 75 | idx += 1 76 | 77 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 78 | self.data.to_csv(save_path, index=False) 79 | else: 80 | self.data = pd.read_csv(save_path) 81 | 82 | self.data = self.data.values.tolist() 83 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 84 | 85 | 86 | def __getitem__( 87 | self, 88 | index 89 | ): 90 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index] 91 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 92 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 93 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 94 | fbank_feature = fbank( 95 | Query_wav, 96 | num_mel_bins=80 97 | ) 98 | g2p_embed = torch.from_numpy(g2p_embed) 99 | g2p_embed = g2p_embed.type_as(fbank_feature) 100 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 101 | 102 | def __len__( 103 | self 104 | ): 105 | return len(self.data) 106 | 107 | 108 | 109 | 110 | 111 | def collate_fn(batch): 112 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch) 113 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 114 | lengths = [len(seq) for seq in padded_fbank_feature] 115 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 116 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 117 | label_tensor = torch.tensor(label) 118 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor 119 | 120 | 121 | 122 | spc = SPC_NO_Dataset() 123 | spc[0] -------------------------------------------------------------------------------- /dataloaders/SPC_N0_TARGET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | 
import Levenshtein 16 | import re 17 | import random 18 | 19 | import os 20 | 21 | def get_files(path, endswith): 22 | _files = [] 23 | for root, dirs, files in os.walk(path): 24 | for file in files: 25 | if file.endswith(endswith): 26 | _files.append(os.path.join(root, file)) 27 | return _files 28 | 29 | 30 | class SPC_NO_Dataset(Dataset): 31 | def __init__( 32 | self, 33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_TARGET.csv", 36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 37 | preprocess=True, 38 | ): 39 | super().__init__() 40 | if preprocess: 41 | target_dict = {} 42 | idx = 0 43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label']) 44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 45 | classes = os.listdir(test_dir) 46 | random.shuffle(classes) 47 | Targets = classes[:10] 48 | for wav_idx in range(len(wav_id)): 49 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 50 | query_text = wav_id[wav_idx].split('_')[0] 51 | if query_text in Targets: 52 | for comparison_text in Targets: 53 | _label = 1 if comparison_text == query_text else 0 54 | target_dict[idx] = { 55 | 'id': wav_id[wav_idx], 56 | 'Query_text': query_text, 57 | 'Query_wav': wav, 58 | 'Support_text': comparison_text, 59 | 'Query_label': Targets.index(query_text), 60 | 'Support_label': Targets.index(comparison_text), 61 | 'label': _label 62 | } 63 | idx += 1 64 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 65 | self.data.to_csv(save_path, index=False) 66 | else: 67 | self.data = pd.read_csv(save_path) 68 | 69 | self.data = self.data.values.tolist() 70 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 71 | 72 | 73 | def __getitem__( 74 | self, 75 | index 76 | ): 77 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index] 78 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 79 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 80 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 81 | fbank_feature = fbank( 82 | Query_wav, 83 | num_mel_bins=80 84 | ) 85 | g2p_embed = torch.from_numpy(g2p_embed) 86 | g2p_embed = g2p_embed.type_as(fbank_feature) 87 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 88 | 89 | def __len__( 90 | self 91 | ): 92 | return len(self.data) 93 | 94 | 95 | 96 | 97 | 98 | def collate_fn(batch): 99 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch) 100 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 101 | lengths = [len(seq) for seq in padded_fbank_feature] 102 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 103 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 104 | label_tensor = torch.tensor(label) 105 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor 106 | 107 | 108 | 109 | # spc = SPC_NO_Dataset() 110 | -------------------------------------------------------------------------------- /dataloaders/SPC_N1_ALL.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | # import dataaug 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | import Levenshtein 16 | import re 17 | import random 18 | 19 | import os 20 | 21 | def get_files(path, endswith): 22 | _files = [] 23 | for root, dirs, files in os.walk(path): 24 | for file in files: 25 | if file.endswith(endswith): 26 | _files.append(os.path.join(root, file)) 27 | return _files 28 | from tqdm import tqdm 29 | 30 | class SPC_N1_Dataset(Dataset): 31 | def __init__( 32 | self, 33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_ALL.csv", 36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 37 | preprocess=True, 38 | ): 39 | super().__init__() 40 | if preprocess: 41 | target_dict = {} 42 | idx = 0 43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label']) 44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 45 | classes = os.listdir(test_dir) 46 | random.shuffle(classes) 47 | Targets = classes[:10] 48 | supports_wavs = {} 49 | for comparison_text in Targets: 50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy') 51 | random.shuffle(supports_wav) 52 | supports_wavs[comparison_text] = supports_wav 53 | for wav_idx in range(len(wav_id)): 54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 55 | query_text = wav_id[wav_idx].split('_')[0] 56 | if query_text in Targets: 57 | for comparison_text in Targets: 58 | support_wav = random.choices(supports_wavs[comparison_text], k=2) 59 | _label = 1 if comparison_text == query_text else 0 60 | target_dict[idx] = { 61 | 'id': wav_id[wav_idx], 62 | 'Query_text': query_text, 63 | 'Query_wav': wav, 64 | 'Support_text': comparison_text, 65 | 'Support_wav': support_wav[0], 66 | 'Query_label': Targets.index(query_text), 67 | 'Support_label': Targets.index(comparison_text), 68 | 'label': _label 69 | } 70 | idx += 1 71 | else: 72 | for comparison_text in Targets: 73 | support_wav = random.choices(supports_wavs[comparison_text], k=5) 74 | _label = 0 75 | target_dict[idx] = { 76 | 'id': wav_id[wav_idx], 77 | 'Query_text': query_text, 78 | 'Query_wav': wav, 79 | 'Support_text': comparison_text, 80 | 'Support_wav': support_wav[0], 81 | 'Query_label': 10, 82 | 'Support_label': Targets.index(comparison_text), 83 | 'label': 10, 84 | } 85 | idx += 1 86 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 87 | self.data.to_csv(save_path, index=False) 88 | else: 89 | self.data = pd.read_csv(save_path) 90 | 91 | self.data = self.data.values.tolist() 92 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 93 | 94 | 95 | def __getitem__( 96 | self, 97 | index 98 | ): 99 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index] 100 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 101 | 
g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 102 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 103 | audiolm_embed = np.load(Support_wav) 104 | fbank_feature = fbank( 105 | Query_wav, 106 | num_mel_bins=80 107 | ) 108 | g2p_embed = torch.from_numpy(g2p_embed) 109 | g2p_embed = g2p_embed.type_as(fbank_feature) 110 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 111 | 112 | def __len__( 113 | self 114 | ): 115 | return len(self.data) 116 | 117 | 118 | 119 | 120 | 121 | def collate_fn(batch): 122 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch) 123 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 124 | lengths = [len(seq) for seq in padded_fbank_feature] 125 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 126 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 127 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 128 | label_tensor = torch.tensor(label) 129 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor 130 | 131 | 132 | 133 | spc = SPC_N1_Dataset() 134 | # spc[0] -------------------------------------------------------------------------------- /dataloaders/SPC_N1_TARGET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import Levenshtein 15 | import re 16 | import random 17 | 18 | import os 19 | 20 | def get_files(path, endswith): 21 | _files = [] 22 | for root, dirs, files in os.walk(path): 23 | for file in files: 24 | if file.endswith(endswith): 25 | _files.append(os.path.join(root, file)) 26 | return _files 27 | from tqdm import tqdm 28 | 29 | class SPC_N1_Dataset(Dataset): 30 | def __init__( 31 | self, 32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_TARGET.csv", 35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 36 | preprocess=True, 37 | ): 38 | super().__init__() 39 | if preprocess: 40 | target_dict = {} 41 | idx = 0 42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label']) 43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 44 | classes = os.listdir(test_dir) 45 | random.shuffle(classes) 46 | Targets = classes[:10] 47 | # Targets = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'] 48 | supports_wavs = {} 49 | for comparison_text in Targets: 50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy') 51 | random.shuffle(supports_wav) 52 | supports_wavs[comparison_text] = supports_wav 53 | for wav_idx in range(len(wav_id)): 54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 55 | query_text = 
wav_id[wav_idx].split('_')[0] 56 | if query_text in Targets: 57 | for comparison_text in Targets: 58 | _label = 1 if comparison_text == query_text else 0 59 | target_dict[idx] = { 60 | 'id': wav_id[wav_idx], 61 | 'Query_text': query_text, 62 | 'Query_wav': wav, 63 | 'Support_text': comparison_text, 64 | 'Support_wav': supports_wavs[comparison_text][0], 65 | 'Query_label': Targets.index(query_text), 66 | 'Support_label': Targets.index(comparison_text), 67 | 'label': _label 68 | } 69 | idx += 1 70 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 71 | self.data.to_csv(save_path, index=False) 72 | else: 73 | self.data = pd.read_csv(save_path) 74 | 75 | self.data = self.data.values.tolist() 76 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 77 | 78 | 79 | def __getitem__( 80 | self, 81 | index 82 | ): 83 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index] 84 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 85 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 86 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 87 | audiolm_embed = np.load(Support_wav) 88 | fbank_feature = fbank( 89 | Query_wav, 90 | num_mel_bins=80 91 | ) 92 | g2p_embed = torch.from_numpy(g2p_embed) 93 | g2p_embed = g2p_embed.type_as(fbank_feature) 94 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 95 | 96 | def __len__( 97 | self 98 | ): 99 | return len(self.data) 100 | 101 | 102 | 103 | 104 | 105 | def collate_fn(batch): 106 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch) 107 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 108 | lengths = [len(seq) for seq in padded_fbank_feature] 109 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 110 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 111 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 112 | label_tensor = torch.tensor(label) 113 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor 114 | 115 | 116 | 117 | spc = SPC_N1_Dataset() 118 | # spc[0] -------------------------------------------------------------------------------- /dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc 
-------------------------------------------------------------------------------- /dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/libriphrase_test.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/libriphrase_train.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__init__.py: -------------------------------------------------------------------------------- 1 | from .g2p import G2p 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/checkpoint20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/checkpoint20.npz -------------------------------------------------------------------------------- /dataloaders/g2p/g2p_en/expand.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | Borrowed 5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py 6 | By kyubyong park. kbpark.linguist@gmail.com. 
7 | https://www.github.com/kyubyong/g2p 8 | ''' 9 | from __future__ import print_function 10 | import inflect 11 | import re 12 | 13 | 14 | 15 | _inflect = inflect.engine() 16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 21 | _number_re = re.compile(r'[0-9]+') 22 | 23 | 24 | def _remove_commas(m): 25 | return m.group(1).replace(',', '') 26 | 27 | 28 | def _expand_decimal_point(m): 29 | return m.group(1).replace('.', ' point ') 30 | 31 | 32 | def _expand_dollars(m): 33 | match = m.group(1) 34 | parts = match.split('.') 35 | if len(parts) > 2: 36 | return match + ' dollars' # Unexpected format 37 | dollars = int(parts[0]) if parts[0] else 0 38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 39 | if dollars and cents: 40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 41 | cent_unit = 'cent' if cents == 1 else 'cents' 42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 43 | elif dollars: 44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 45 | return '%s %s' % (dollars, dollar_unit) 46 | elif cents: 47 | cent_unit = 'cent' if cents == 1 else 'cents' 48 | return '%s %s' % (cents, cent_unit) 49 | else: 50 | return 'zero dollars' 51 | 52 | 53 | def _expand_ordinal(m): 54 | return _inflect.number_to_words(m.group(0)) 55 | 56 | 57 | def _expand_number(m): 58 | num = int(m.group(0)) 59 | if num > 1000 and num < 3000: 60 | if num == 2000: 61 | return 'two thousand' 62 | elif num > 2000 and num < 2010: 63 | return 'two thousand ' + _inflect.number_to_words(num % 100) 64 | elif num % 100 == 0: 65 | return _inflect.number_to_words(num // 100) + ' hundred' 66 | else: 67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 68 | else: 69 | return _inflect.number_to_words(num, andword='') 70 | 71 | 72 | def normalize_numbers(text): 73 | text = re.sub(_comma_number_re, _remove_commas, text) 74 | text = re.sub(_pounds_re, r'\1 pounds', text) 75 | text = re.sub(_dollars_re, _expand_dollars, text) 76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 77 | text = re.sub(_ordinal_re, _expand_ordinal, text) 78 | text = re.sub(_number_re, _expand_number, text) 79 | return text 80 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_0/events.out.tfevents.1703650022.great-server24.1715931.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_0/events.out.tfevents.1703650022.great-server24.1715931.0 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_0/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_1/events.out.tfevents.1703650063.great-server24.1718627.0: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_1/events.out.tfevents.1703650063.great-server24.1718627.0 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_1/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_2/events.out.tfevents.1703650148.great-server24.1718627.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_2/events.out.tfevents.1703650148.great-server24.1718627.1 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_2/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_3/events.out.tfevents.1703651214.great-server24.1769388.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_3/events.out.tfevents.1703651214.great-server24.1769388.0 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_3/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_4/events.out.tfevents.1703651335.great-server24.1769388.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_4/events.out.tfevents.1703651335.great-server24.1769388.1 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_4/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_5/events.out.tfevents.1703651958.great-server24.1794497.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_5/events.out.tfevents.1703651958.great-server24.1794497.0 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_5/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_6/events.out.tfevents.1703652071.great-server24.1794497.1: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_6/events.out.tfevents.1703652071.great-server24.1794497.1 -------------------------------------------------------------------------------- /dataloaders/g2p/lightning_logs/version_6/hparams.yaml: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /dataloaders/libriphrase_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class LibriPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test", 20 | csv=[ 21 | "libriphrase_diffspk_all_1word.csv", 22 | "libriphrase_diffspk_all_2word.csv", 23 | "libriphrase_diffspk_all_3word.csv", 24 | "libriphrase_diffspk_all_4word.csv" 25 | ], 26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv', 27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle", 28 | preprocess=False, 29 | types='easy' 30 | ): 31 | if preprocess: 32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type']) 33 | for path in csv: 34 | n_word = os.path.join(test_dir, path) 35 | df = pd.read_csv(n_word) 36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']] 37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']] 38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True) 39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True) 40 | self.data.to_csv(save_path, index=False) 41 | else: 42 | self.data = pd.read_csv(save_path) 43 | # print(self.data)/ 44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1) 45 | if types == 'easy': 46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])] 47 | elif types == 'hard': 48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])] 49 | 50 | 51 | self.data = self.data.values.tolist() 52 | # self.data = self.data[:1000] 53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 54 | self.test_dir = test_dir 55 | 56 | def __getitem__( 57 | self, 58 | index 59 | ): 60 | # Query_wav_fbank, phoneme, g2p_embed. 
lm_embed, audiolm_embed, label 61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index] 62 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 64 | phoneme = self.text_embedder[Support_text]['phoneme'] 65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 66 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '.npy') 68 | fbank_feature = fbank( 69 | Query_wav, 70 | num_mel_bins=80 71 | ) 72 | g2p_embed = torch.from_numpy(g2p_embed) 73 | g2p_embed = g2p_embed.type_as(fbank_feature) 74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 75 | 76 | 77 | def __len__( 78 | self 79 | ): 80 | return len(self.data) 81 | 82 | 83 | def collate_fn(batch): 84 | # Group each sample in the batch by field 85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 86 | # Pad each feature 87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 88 | lengths = [len(seq) for seq in padded_fbank_feature] 89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 92 | # Convert the labels to a Tensor 93 | label_tensor = torch.tensor(label) 94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 95 | 96 | 97 | 98 | 99 | 100 | # test_data = LibriPhrase_Test_Dataset() 101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False) 102 | # from tqdm import tqdm 103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data 105 | # print(dist_tensor) 106 | # break 107 | # # pass 108 | # pass -------------------------------------------------------------------------------- /dataloaders/libriphrase_test_18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class LibriPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test", 20 | csv=[ 21 | "libriphrase_diffspk_all_1word.csv", 22 | "libriphrase_diffspk_all_2word.csv", 23 | "libriphrase_diffspk_all_3word.csv", 24 | "libriphrase_diffspk_all_4word.csv" 25 | ], 26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv', 27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle", 28 | preprocess=False, 29 | types='easy' 30 | ): 31 | if preprocess: 32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type']) 33 | for path in csv: 34 | n_word = os.path.join(test_dir, path) 35 | df = pd.read_csv(n_word) 36 |
anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']] 37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']] 38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True) 39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True) 40 | self.data.to_csv(save_path, index=False) 41 | else: 42 | self.data = pd.read_csv(save_path) 43 | # print(self.data)/ 44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1) 45 | if types == 'easy': 46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])] 47 | elif types == 'hard': 48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])] 49 | 50 | 51 | self.data = self.data.values.tolist() 52 | # self.data = self.data[:1000] 53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 54 | self.test_dir = test_dir 55 | 56 | def __getitem__( 57 | self, 58 | index 59 | ): 60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label 61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index] 62 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 64 | phoneme = self.text_embedder[Support_text]['phoneme'] 65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 66 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy') 68 | fbank_feature = fbank( 69 | Query_wav, 70 | num_mel_bins=80 71 | ) 72 | g2p_embed = torch.from_numpy(g2p_embed) 73 | g2p_embed = g2p_embed.type_as(fbank_feature) 74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 75 | 76 | 77 | def __len__( 78 | self 79 | ): 80 | return len(self.data) 81 | 82 | 83 | def collate_fn(batch): 84 | # Group each sample in the batch by field 85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 86 | # Pad each feature 87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 88 | lengths = [len(seq) for seq in padded_fbank_feature] 89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 92 | # Convert the labels to a Tensor 93 | label_tensor = torch.tensor(label) 94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 95 | 96 | 97 | 98 | 99 | 100 | # test_data = LibriPhrase_Test_Dataset() 101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False) 102 | # from tqdm import tqdm 103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data 105 | # print(dist_tensor) 106 | # break 107 | # # pass 108 | # pass
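For reference, a minimal sketch of how these LibriPhrase test dataloaders are intended to be driven, mirroring the commented-out usage at the bottom of the files above; the `dataloaders.libriphrase_test_18` import path and the default dataset/embedding locations hard-coded in the class are assumptions that must be adapted to the local setup:

```python
# Hypothetical usage sketch; assumes the repository root is on PYTHONPATH and that the
# CSV, pickle, and *_18.npy paths baked into the dataset defaults exist locally.
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataloaders.libriphrase_test_18 import LibriPhrase_Test_Dataset, collate_fn

test_data = LibriPhrase_Test_Dataset(types='hard')   # 'easy' or 'hard' negative pairs
loader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn,
                    num_workers=16, shuffle=False)

for fbank_feat, lengths, g2p, lm, audiolm, labels in tqdm(loader, total=len(loader)):
    # fbank_feat: (B, T, 80) padded log-mel filterbanks of the query utterance
    # g2p / lm:   phoneme and text-LM embeddings of the enrolled (support) phrase
    # audiolm:    audio embedding of the support recording (the *_18.npy features)
    # labels:     1 for a matching query/keyword pair, 0 for a negative pair
    break  # one batch is enough for a smoke test
```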
-------------------------------------------------------------------------------- /dataloaders/wenetphrase_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class WenetPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/WenetPhrase2/S", 20 | csv="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/wenetphrase_test.csv", 21 | test_text_embedding="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/zh_test_text_embeddings.pickle", 22 | types='easy' 23 | ): 24 | self.data = pd.read_csv(csv) 25 | if types == 'easy': 26 | self.data = self.data.loc[self.data['type'].isin(['easy', 'pos'])] 27 | elif types == 'hard': 28 | self.data = self.data.loc[self.data['type'].isin(['hard', 'pos'])] 29 | self.data = self.data.values.tolist() 30 | print(len(self.data)) 31 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 32 | self.test_dir = test_dir 33 | 34 | def __getitem__( 35 | self, 36 | index 37 | ): 38 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label 39 | Query_text, Query_wav, Support_text, Support_wav, label, _ = self.data[index] 40 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 41 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 42 | phoneme = self.text_embedder[Support_text]['phoneme'] 43 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 44 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 45 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy') 46 | fbank_feature = fbank( 47 | Query_wav, 48 | num_mel_bins=80 49 | ) 50 | g2p_embed = torch.from_numpy(g2p_embed) 51 | g2p_embed = g2p_embed.type_as(fbank_feature) 52 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 53 | 54 | 55 | def __len__( 56 | self 57 | ): 58 | return len(self.data) 59 | 60 | 61 | def collate_fn(batch): 62 | # Group each sample in the batch by field 63 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 64 | # Pad each feature 65 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 66 | lengths = [len(seq) for seq in padded_fbank_feature] 67 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 68 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 69 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 70 | # Convert the labels to a Tensor 71 | label_tensor = torch.tensor(label) 72 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 73 | 74 | 75 | 76 | 77 | 78 | # test_data = WenetPhrase_Test_Dataset() 79 | # print(test_data[0]) 80 | # # dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn, num_workers=1, shuffle=False) 81 | # # from tqdm import tqdm 82 | # # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 83 | # # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed,
label_tensor, dist_tensor = data 84 | # # print(dist_tensor) 85 | # # break 86 | # # # pass 87 | # # pass -------------------------------------------------------------------------------- /g2p/g2p_en/__init__.py: -------------------------------------------------------------------------------- 1 | from .g2p import G2p 2 | -------------------------------------------------------------------------------- /g2p/g2p_en/checkpoint20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/g2p/g2p_en/checkpoint20.npz -------------------------------------------------------------------------------- /g2p/g2p_en/expand.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | Borrowed 5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py 6 | By kyubyong park. kbpark.linguist@gmail.com. 7 | https://www.github.com/kyubyong/g2p 8 | ''' 9 | from __future__ import print_function 10 | import inflect 11 | import re 12 | 13 | 14 | 15 | _inflect = inflect.engine() 16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 21 | _number_re = re.compile(r'[0-9]+') 22 | 23 | 24 | def _remove_commas(m): 25 | return m.group(1).replace(',', '') 26 | 27 | 28 | def _expand_decimal_point(m): 29 | return m.group(1).replace('.', ' point ') 30 | 31 | 32 | def _expand_dollars(m): 33 | match = m.group(1) 34 | parts = match.split('.') 35 | if len(parts) > 2: 36 | return match + ' dollars' # Unexpected format 37 | dollars = int(parts[0]) if parts[0] else 0 38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 39 | if dollars and cents: 40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 41 | cent_unit = 'cent' if cents == 1 else 'cents' 42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 43 | elif dollars: 44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 45 | return '%s %s' % (dollars, dollar_unit) 46 | elif cents: 47 | cent_unit = 'cent' if cents == 1 else 'cents' 48 | return '%s %s' % (cents, cent_unit) 49 | else: 50 | return 'zero dollars' 51 | 52 | 53 | def _expand_ordinal(m): 54 | return _inflect.number_to_words(m.group(0)) 55 | 56 | 57 | def _expand_number(m): 58 | num = int(m.group(0)) 59 | if num > 1000 and num < 3000: 60 | if num == 2000: 61 | return 'two thousand' 62 | elif num > 2000 and num < 2010: 63 | return 'two thousand ' + _inflect.number_to_words(num % 100) 64 | elif num % 100 == 0: 65 | return _inflect.number_to_words(num // 100) + ' hundred' 66 | else: 67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 68 | else: 69 | return _inflect.number_to_words(num, andword='') 70 | 71 | 72 | def normalize_numbers(text): 73 | text = re.sub(_comma_number_re, _remove_commas, text) 74 | text = re.sub(_pounds_re, r'\1 pounds', text) 75 | text = re.sub(_dollars_re, _expand_dollars, text) 76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 77 | text = re.sub(_ordinal_re, _expand_ordinal, text) 78 | text = re.sub(_number_re, _expand_number, text) 79 | return text 80 | -------------------------------------------------------------------------------- /libriphrase_hardneg.json.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/libriphrase_hardneg.json.zip -------------------------------------------------------------------------------- /mm-kws/README.md: -------------------------------------------------------------------------------- 1 | # MM-KWS: Multi-modal Prompts for Multilingual User-defined Keyword Spotting (Updating) 2 | 3 | ### Note 4 | 1. Code for the paper 'MM-KWS: Multi-modal Prompts for Multilingual User-defined Keyword Spotting', accepted at Interspeech 2024 5 | 2. arXiv: https://arxiv.org/pdf/2406.07310 6 | 3. Code version 1 7 | 4. WenetPhrase hardneg-data & LibriPhrase hardneg-data 8 | 5. DataAug data (todo) 9 | 10 | --- 11 | ### Performance 12 | #### 1.1 Performance on LibriPhrase 13 | 14 | 15 | #### 1.2 Performance on WenetPhrase 16 | 17 | 18 | #### 2. Zero-shot performance on Speech Commands 19 | 20 | 21 | #### 3. Few-shot performance on wake-up word (Snips) 22 | Note: I wrote up the wake-word experiment in my master's thesis; it was removed from the Interspeech manuscript due to the space limit. 23 | 24 | 25 | 26 | --- 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .model import Conformer 16 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | from torch import Tensor 17 | 18 | 19 | class Swish(nn.Module): 20 | """ 21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied 22 | to a variety of challenging domains such as Image classification and Machine translation.
23 | """ 24 | def __init__(self): 25 | super(Swish, self).__init__() 26 | 27 | def forward(self, inputs: Tensor) -> Tensor: 28 | return inputs * inputs.sigmoid() 29 | 30 | 31 | class GLU(nn.Module): 32 | """ 33 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing 34 | in the paper “Language Modeling with Gated Convolutional Networks” 35 | """ 36 | def __init__(self, dim: int) -> None: 37 | super(GLU, self).__init__() 38 | self.dim = dim 39 | 40 | def forward(self, inputs: Tensor) -> Tensor: 41 | outputs, gate = inputs.chunk(2, dim=self.dim) 42 | return outputs * gate.sigmoid() 43 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from torch import Tensor 20 | from typing import Optional 21 | 22 | from .embedding import PositionalEncoding 23 | from .modules import Linear 24 | 25 | 26 | class RelativeMultiHeadAttention(nn.Module): 27 | """ 28 | Multi-head attention with relative positional encoding. 29 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" 30 | 31 | Args: 32 | d_model (int): The dimension of model 33 | num_heads (int): The number of attention heads. 34 | dropout_p (float): probability of dropout 35 | 36 | Inputs: query, key, value, pos_embedding, mask 37 | - **query** (batch, time, dim): Tensor containing query vector 38 | - **key** (batch, time, dim): Tensor containing key vector 39 | - **value** (batch, time, dim): Tensor containing value vector 40 | - **pos_embedding** (batch, time, dim): Positional embedding tensor 41 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 42 | 43 | Returns: 44 | - **outputs**: Tensor produces by relative multi head attention module. 45 | """ 46 | def __init__( 47 | self, 48 | d_model: int = 512, 49 | num_heads: int = 16, 50 | dropout_p: float = 0.1, 51 | ): 52 | super(RelativeMultiHeadAttention, self).__init__() 53 | assert d_model % num_heads == 0, "d_model % num_heads should be zero." 
54 | self.d_model = d_model 55 | self.d_head = int(d_model / num_heads) 56 | self.num_heads = num_heads 57 | self.sqrt_dim = math.sqrt(d_model) 58 | 59 | self.query_proj = Linear(d_model, d_model) 60 | self.key_proj = Linear(d_model, d_model) 61 | self.value_proj = Linear(d_model, d_model) 62 | self.pos_proj = Linear(d_model, d_model, bias=False) 63 | 64 | self.dropout = nn.Dropout(p=dropout_p) 65 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 66 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head)) 67 | torch.nn.init.xavier_uniform_(self.u_bias) 68 | torch.nn.init.xavier_uniform_(self.v_bias) 69 | 70 | self.out_proj = Linear(d_model, d_model) 71 | 72 | def forward( 73 | self, 74 | query: Tensor, 75 | key: Tensor, 76 | value: Tensor, 77 | pos_embedding: Tensor, 78 | mask: Optional[Tensor] = None, 79 | ) -> Tensor: 80 | batch_size = value.size(0) 81 | 82 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) 83 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 84 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) 85 | pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head) 86 | 87 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3)) 88 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1)) 89 | pos_score = self._relative_shift(pos_score) 90 | 91 | score = (content_score + pos_score) / self.sqrt_dim 92 | 93 | if mask is not None: 94 | mask = mask.unsqueeze(1) 95 | score.masked_fill_(mask, -1e9) 96 | 97 | attn = F.softmax(score, -1) 98 | attn = self.dropout(attn) 99 | 100 | context = torch.matmul(attn, value).transpose(1, 2) 101 | context = context.contiguous().view(batch_size, -1, self.d_model) 102 | 103 | return self.out_proj(context) 104 | 105 | def _relative_shift(self, pos_score: Tensor) -> Tensor: 106 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size() 107 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1) 108 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1) 109 | 110 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) 111 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) 112 | 113 | return pos_score 114 | 115 | 116 | class MultiHeadedSelfAttentionModule(nn.Module): 117 | """ 118 | Conformer employ multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL, 119 | the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention 120 | module to generalize better on different input length and the resulting encoder is more robust to the variance of 121 | the utterance length. Conformer use prenorm residual units with dropout which helps training 122 | and regularizing deeper models. 123 | 124 | Args: 125 | d_model (int): The dimension of model 126 | num_heads (int): The number of attention heads. 127 | dropout_p (float): probability of dropout 128 | 129 | Inputs: inputs, mask 130 | - **inputs** (batch, time, dim): Tensor containing input vector 131 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked 132 | 133 | Returns: 134 | - **outputs** (batch, time, dim): Tensor produces by relative multi headed self attention module. 
135 | """ 136 | def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1): 137 | super(MultiHeadedSelfAttentionModule, self).__init__() 138 | self.positional_encoding = PositionalEncoding(d_model) 139 | self.layer_norm = nn.LayerNorm(d_model) 140 | self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p) 141 | self.dropout = nn.Dropout(p=dropout_p) 142 | 143 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None): 144 | batch_size, seq_length, _ = inputs.size() 145 | pos_embedding = self.positional_encoding(seq_length) 146 | pos_embedding = pos_embedding.repeat(batch_size, 1, 1) 147 | 148 | inputs = self.layer_norm(inputs) 149 | outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask) 150 | 151 | return self.dropout(outputs) 152 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .activation import Swish, GLU 21 | from .modules import Transpose 22 | 23 | 24 | class DepthwiseConv1d(nn.Module): 25 | """ 26 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, 27 | this operation is termed in literature as depthwise convolution. 28 | 29 | Args: 30 | in_channels (int): Number of channels in the input 31 | out_channels (int): Number of channels produced by the convolution 32 | kernel_size (int or tuple): Size of the convolving kernel 33 | stride (int, optional): Stride of the convolution. Default: 1 34 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 35 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 36 | 37 | Inputs: inputs 38 | - **inputs** (batch, in_channels, time): Tensor containing input vector 39 | 40 | Returns: outputs 41 | - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. 
42 | """ 43 | def __init__( 44 | self, 45 | in_channels: int, 46 | out_channels: int, 47 | kernel_size: int, 48 | stride: int = 1, 49 | padding: int = 0, 50 | bias: bool = False, 51 | ) -> None: 52 | super(DepthwiseConv1d, self).__init__() 53 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels" 54 | self.conv = nn.Conv1d( 55 | in_channels=in_channels, 56 | out_channels=out_channels, 57 | kernel_size=kernel_size, 58 | groups=in_channels, 59 | stride=stride, 60 | padding=padding, 61 | bias=bias, 62 | ) 63 | 64 | def forward(self, inputs: Tensor) -> Tensor: 65 | return self.conv(inputs) 66 | 67 | 68 | class PointwiseConv1d(nn.Module): 69 | """ 70 | When kernel size == 1 conv1d, this operation is termed in literature as pointwise convolution. 71 | This operation often used to match dimensions. 72 | 73 | Args: 74 | in_channels (int): Number of channels in the input 75 | out_channels (int): Number of channels produced by the convolution 76 | stride (int, optional): Stride of the convolution. Default: 1 77 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 78 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True 79 | 80 | Inputs: inputs 81 | - **inputs** (batch, in_channels, time): Tensor containing input vector 82 | 83 | Returns: outputs 84 | - **outputs** (batch, out_channels, time): Tensor produces by pointwise 1-D convolution. 85 | """ 86 | def __init__( 87 | self, 88 | in_channels: int, 89 | out_channels: int, 90 | stride: int = 1, 91 | padding: int = 0, 92 | bias: bool = True, 93 | ) -> None: 94 | super(PointwiseConv1d, self).__init__() 95 | self.conv = nn.Conv1d( 96 | in_channels=in_channels, 97 | out_channels=out_channels, 98 | kernel_size=1, 99 | stride=stride, 100 | padding=padding, 101 | bias=bias, 102 | ) 103 | 104 | def forward(self, inputs: Tensor) -> Tensor: 105 | return self.conv(inputs) 106 | 107 | 108 | class ConformerConvModule(nn.Module): 109 | """ 110 | Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU). 111 | This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution 112 | to aid training deep models. 113 | 114 | Args: 115 | in_channels (int): Number of channels in the input 116 | kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31 117 | dropout_p (float, optional): probability of dropout 118 | 119 | Inputs: inputs 120 | inputs (batch, time, dim): Tensor contains input sequences 121 | 122 | Outputs: outputs 123 | outputs (batch, time, dim): Tensor produces by conformer convolution module. 
124 | """ 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | kernel_size: int = 31, 129 | expansion_factor: int = 2, 130 | dropout_p: float = 0.1, 131 | ) -> None: 132 | super(ConformerConvModule, self).__init__() 133 | assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" 134 | assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2" 135 | 136 | self.sequential = nn.Sequential( 137 | nn.LayerNorm(in_channels), 138 | Transpose(shape=(1, 2)), 139 | PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True), 140 | GLU(dim=1), 141 | DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2), 142 | nn.BatchNorm1d(in_channels), 143 | Swish(), 144 | PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True), 145 | nn.Dropout(p=dropout_p), 146 | ) 147 | 148 | def forward(self, inputs: Tensor) -> Tensor: 149 | return self.sequential(inputs).transpose(1, 2) 150 | 151 | 152 | class Conv2dSubampling(nn.Module): 153 | """ 154 | Convolutional 2D subsampling (to 1/4 length) 155 | 156 | Args: 157 | in_channels (int): Number of channels in the input image 158 | out_channels (int): Number of channels produced by the convolution 159 | 160 | Inputs: inputs 161 | - **inputs** (batch, time, dim): Tensor containing sequence of inputs 162 | 163 | Returns: outputs, output_lengths 164 | - **outputs** (batch, time, dim): Tensor produced by the convolution 165 | - **output_lengths** (batch): list of sequence output lengths 166 | """ 167 | def __init__(self, in_channels: int, out_channels: int) -> None: 168 | super(Conv2dSubampling, self).__init__() 169 | self.sequential = nn.Sequential( 170 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2), 171 | nn.ReLU(), 172 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2), 173 | nn.ReLU(), 174 | ) 175 | 176 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 177 | outputs = self.sequential(inputs.unsqueeze(1)) 178 | batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size() 179 | 180 | outputs = outputs.permute(0, 2, 1, 3) 181 | outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim) 182 | 183 | output_lengths = input_lengths >> 2 184 | output_lengths -= 1 185 | 186 | return outputs, output_lengths 187 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/embedding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import torch 17 | import torch.nn as nn 18 | from torch import Tensor 19 | 20 | 21 | class PositionalEncoding(nn.Module): 22 | """ 23 | Positional Encoding proposed in "Attention Is All You Need". 
24 | Since transformer contains no recurrence and no convolution, in order for the model to make 25 | use of the order of the sequence, we must add some positional information. 26 | 27 | "Attention Is All You Need" use sine and cosine functions of different frequencies: 28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model)) 29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model)) 30 | """ 31 | def __init__(self, d_model: int = 512, max_len: int = 10000) -> None: 32 | super(PositionalEncoding, self).__init__() 33 | pe = torch.zeros(max_len, d_model, requires_grad=False) 34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) 36 | pe[:, 0::2] = torch.sin(position * div_term) 37 | pe[:, 1::2] = torch.cos(position * div_term) 38 | pe = pe.unsqueeze(0) 39 | self.register_buffer('pe', pe) 40 | 41 | def forward(self, length: int) -> Tensor: 42 | return self.pe[:, :length] -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | 19 | from .activation import Swish 20 | from .modules import Linear 21 | 22 | 23 | class FeedForwardModule(nn.Module): 24 | """ 25 | Conformer Feed Forward Module follow pre-norm residual units and apply layer normalization within the residual unit 26 | and on the input before the first linear layer. This module also apply Swish activation and dropout, which helps 27 | regularizing the network. 28 | 29 | Args: 30 | encoder_dim (int): Dimension of conformer encoder 31 | expansion_factor (int): Expansion factor of feed forward module. 32 | dropout_p (float): Ratio of dropout 33 | 34 | Inputs: inputs 35 | - **inputs** (batch, time, dim): Tensor contains input sequences 36 | 37 | Outputs: outputs 38 | - **outputs** (batch, time, dim): Tensor produces by feed forward module. 39 | """ 40 | def __init__( 41 | self, 42 | encoder_dim: int = 512, 43 | expansion_factor: int = 4, 44 | dropout_p: float = 0.1, 45 | ) -> None: 46 | super(FeedForwardModule, self).__init__() 47 | self.sequential = nn.Sequential( 48 | nn.LayerNorm(encoder_dim), 49 | Linear(encoder_dim, encoder_dim * expansion_factor, bias=True), 50 | Swish(), 51 | nn.Dropout(p=dropout_p), 52 | Linear(encoder_dim * expansion_factor, encoder_dim, bias=True), 53 | nn.Dropout(p=dropout_p), 54 | ) 55 | 56 | def forward(self, inputs: Tensor) -> Tensor: 57 | return self.sequential(inputs) 58 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .encoder import ConformerEncoder 21 | from .modules import Linear 22 | 23 | 24 | class Conformer(nn.Module): 25 | """ 26 | Conformer: Convolution-augmented Transformer for Speech Recognition 27 | The paper used a one-lstm Transducer decoder, currently still only implemented 28 | the conformer encoder shown in the paper. 29 | 30 | Args: 31 | num_classes (int): Number of classification classes 32 | input_dim (int, optional): Dimension of input vector 33 | encoder_dim (int, optional): Dimension of conformer encoder 34 | num_encoder_layers (int, optional): Number of conformer blocks 35 | num_attention_heads (int, optional): Number of attention heads 36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 39 | attention_dropout_p (float, optional): Probability of attention module dropout 40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 42 | half_step_residual (bool): Flag indication whether to use half step residual or not 43 | 44 | Inputs: inputs, input_lengths 45 | - **inputs** (batch, time, dim): Tensor containing input vector 46 | - **input_lengths** (batch): list of sequence input lengths 47 | 48 | Returns: outputs, output_lengths 49 | - **outputs** (batch, out_channels, time): Tensor produces by conformer. 
50 | - **output_lengths** (batch): list of sequence output lengths 51 | """ 52 | def __init__( 53 | self, 54 | num_classes: int, 55 | input_dim: int = 80, 56 | encoder_dim: int = 512, 57 | num_encoder_layers: int = 17, 58 | num_attention_heads: int = 8, 59 | feed_forward_expansion_factor: int = 4, 60 | conv_expansion_factor: int = 2, 61 | input_dropout_p: float = 0.1, 62 | feed_forward_dropout_p: float = 0.1, 63 | attention_dropout_p: float = 0.1, 64 | conv_dropout_p: float = 0.1, 65 | conv_kernel_size: int = 31, 66 | half_step_residual: bool = True, 67 | ) -> None: 68 | super(Conformer, self).__init__() 69 | self.encoder = ConformerEncoder( 70 | input_dim=input_dim, 71 | encoder_dim=encoder_dim, 72 | num_layers=num_encoder_layers, 73 | num_attention_heads=num_attention_heads, 74 | feed_forward_expansion_factor=feed_forward_expansion_factor, 75 | conv_expansion_factor=conv_expansion_factor, 76 | input_dropout_p=input_dropout_p, 77 | feed_forward_dropout_p=feed_forward_dropout_p, 78 | attention_dropout_p=attention_dropout_p, 79 | conv_dropout_p=conv_dropout_p, 80 | conv_kernel_size=conv_kernel_size, 81 | half_step_residual=half_step_residual, 82 | ) 83 | self.fc = Linear(encoder_dim, num_classes, bias=False) 84 | 85 | def count_parameters(self) -> int: 86 | """ Count parameters of encoder """ 87 | return self.encoder.count_parameters() 88 | 89 | def update_dropout(self, dropout_p) -> None: 90 | """ Update dropout probability of model """ 91 | self.encoder.update_dropout(dropout_p) 92 | 93 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 94 | """ 95 | Forward propagate a `inputs` and `targets` pair for training. 96 | 97 | Args: 98 | inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 99 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 100 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)`` 101 | 102 | Returns: 103 | * predictions (torch.FloatTensor): Result of model predictions. 104 | """ 105 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 106 | outputs = self.fc(encoder_outputs) 107 | outputs = nn.functional.log_softmax(outputs, dim=-1) 108 | return outputs, encoder_output_lengths 109 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/model_def.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | from typing import Tuple 19 | 20 | from .encoder import ConformerEncoder 21 | from .modules import Linear 22 | 23 | 24 | class Conformer(nn.Module): 25 | """ 26 | Conformer: Convolution-augmented Transformer for Speech Recognition 27 | The paper used a one-lstm Transducer decoder, currently still only implemented 28 | the conformer encoder shown in the paper. 29 | 30 | Args: 31 | num_classes (int): Number of classification classes 32 | input_dim (int, optional): Dimension of input vector 33 | encoder_dim (int, optional): Dimension of conformer encoder 34 | num_encoder_layers (int, optional): Number of conformer blocks 35 | num_attention_heads (int, optional): Number of attention heads 36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module 37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module 38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout 39 | attention_dropout_p (float, optional): Probability of attention module dropout 40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout 41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel 42 | half_step_residual (bool): Flag indication whether to use half step residual or not 43 | 44 | Inputs: inputs, input_lengths 45 | - **inputs** (batch, time, dim): Tensor containing input vector 46 | - **input_lengths** (batch): list of sequence input lengths 47 | 48 | Returns: outputs, output_lengths 49 | - **outputs** (batch, out_channels, time): Tensor produces by conformer. 50 | - **output_lengths** (batch): list of sequence output lengths 51 | """ 52 | def __init__( 53 | self, 54 | input_dim: int = 80, 55 | encoder_dim: int = 512, 56 | num_encoder_layers: int = 17, 57 | num_attention_heads: int = 8, 58 | feed_forward_expansion_factor: int = 4, 59 | conv_expansion_factor: int = 2, 60 | input_dropout_p: float = 0.1, 61 | feed_forward_dropout_p: float = 0.1, 62 | attention_dropout_p: float = 0.1, 63 | conv_dropout_p: float = 0.1, 64 | conv_kernel_size: int = 31, 65 | half_step_residual: bool = True, 66 | ) -> None: 67 | super(Conformer, self).__init__() 68 | self.encoder = ConformerEncoder( 69 | input_dim=input_dim, 70 | encoder_dim=encoder_dim, 71 | num_layers=num_encoder_layers, 72 | num_attention_heads=num_attention_heads, 73 | feed_forward_expansion_factor=feed_forward_expansion_factor, 74 | conv_expansion_factor=conv_expansion_factor, 75 | input_dropout_p=input_dropout_p, 76 | feed_forward_dropout_p=feed_forward_dropout_p, 77 | attention_dropout_p=attention_dropout_p, 78 | conv_dropout_p=conv_dropout_p, 79 | conv_kernel_size=conv_kernel_size, 80 | half_step_residual=half_step_residual, 81 | ) 82 | 83 | def count_parameters(self) -> int: 84 | """ Count parameters of encoder """ 85 | return self.encoder.count_parameters() 86 | 87 | def update_dropout(self, dropout_p) -> None: 88 | """ Update dropout probability of model """ 89 | self.encoder.update_dropout(dropout_p) 90 | 91 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]: 92 | """ 93 | Forward propagate a `inputs` and `targets` pair for training. 94 | 95 | Args: 96 | inputs (torch.FloatTensor): A input sequence passed to encoder. Typically for inputs this will be a padded 97 | `FloatTensor` of size ``(batch, seq_length, dimension)``. 98 | input_lengths (torch.LongTensor): The length of input tensor. 
``(batch)`` 99 | 100 | Returns: 101 | * predictions (torch.FloatTensor): Result of model predictions. 102 | """ 103 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths) 104 | return encoder_outputs, encoder_output_lengths 105 | -------------------------------------------------------------------------------- /mm-kws/conformer/conformer/modules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | from torch import Tensor 19 | 20 | 21 | class ResidualConnectionModule(nn.Module): 22 | """ 23 | Residual Connection Module. 24 | outputs = (module(inputs) x module_factor + inputs x input_factor) 25 | """ 26 | def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0): 27 | super(ResidualConnectionModule, self).__init__() 28 | self.module = module 29 | self.module_factor = module_factor 30 | self.input_factor = input_factor 31 | 32 | def forward(self, inputs: Tensor) -> Tensor: 33 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor) 34 | 35 | 36 | class Linear(nn.Module): 37 | """ 38 | Wrapper class of torch.nn.Linear 39 | Weight initialize by xavier initialization and bias initialize to zeros. 40 | """ 41 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None: 42 | super(Linear, self).__init__() 43 | self.linear = nn.Linear(in_features, out_features, bias=bias) 44 | init.xavier_uniform_(self.linear.weight) 45 | if bias: 46 | init.zeros_(self.linear.bias) 47 | 48 | def forward(self, x: Tensor) -> Tensor: 49 | return self.linear(x) 50 | 51 | 52 | class View(nn.Module): 53 | """ Wrapper class of torch.view() for Sequential module. """ 54 | def __init__(self, shape: tuple, contiguous: bool = False): 55 | super(View, self).__init__() 56 | self.shape = shape 57 | self.contiguous = contiguous 58 | 59 | def forward(self, x: Tensor) -> Tensor: 60 | if self.contiguous: 61 | x = x.contiguous() 62 | 63 | return x.view(*self.shape) 64 | 65 | 66 | class Transpose(nn.Module): 67 | """ Wrapper class of torch.transpose() for Sequential module. 
""" 68 | def __init__(self, shape: tuple): 69 | super(Transpose, self).__init__() 70 | self.shape = shape 71 | 72 | def forward(self, x: Tensor) -> Tensor: 73 | return x.transpose(*self.shape) 74 | -------------------------------------------------------------------------------- /mm-kws/dataloaders/SPC_N0_ALL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import Levenshtein 15 | import re 16 | import random 17 | 18 | import os 19 | 20 | def get_files(path, endswith): 21 | _files = [] 22 | for root, dirs, files in os.walk(path): 23 | for file in files: 24 | if file.endswith(endswith): 25 | _files.append(os.path.join(root, file)) 26 | return _files 27 | 28 | 29 | class SPC_NO_Dataset(Dataset): 30 | def __init__( 31 | self, 32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_ALL.csv", 35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 36 | preprocess=True, 37 | ): 38 | super().__init__() 39 | if preprocess: 40 | target_dict = {} 41 | idx = 0 42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label']) 43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 44 | classes = os.listdir(test_dir) 45 | random.shuffle(classes) 46 | Targets = classes[:10] 47 | for wav_idx in range(len(wav_id)): 48 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 49 | query_text = wav_id[wav_idx].split('_')[0] 50 | if query_text in Targets: 51 | for comparison_text in Targets: 52 | _label = 1 if comparison_text == query_text else 0 53 | target_dict[idx] = { 54 | 'id': wav_id[wav_idx], 55 | 'Query_text': query_text, 56 | 'Query_wav': wav, 57 | 'Support_text': comparison_text, 58 | 'Query_label': Targets.index(query_text), 59 | 'Support_label': Targets.index(comparison_text), 60 | 'label': _label 61 | } 62 | idx += 1 63 | else: 64 | for comparison_text in Targets: 65 | _label = 0 66 | target_dict[idx] = { 67 | 'id': wav_id[wav_idx], 68 | 'Query_text': query_text, 69 | 'Query_wav': wav, 70 | 'Support_text': comparison_text, 71 | 'Query_label': 10, 72 | 'Support_label': Targets.index(comparison_text), 73 | 'label': 10, 74 | } 75 | idx += 1 76 | 77 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 78 | self.data.to_csv(save_path, index=False) 79 | else: 80 | self.data = pd.read_csv(save_path) 81 | 82 | self.data = self.data.values.tolist() 83 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 84 | 85 | 86 | def __getitem__( 87 | self, 88 | index 89 | ): 90 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index] 91 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 92 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 93 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 94 | fbank_feature = fbank( 95 | Query_wav, 96 | num_mel_bins=80 97 | ) 98 | 
g2p_embed = torch.from_numpy(g2p_embed) 99 | g2p_embed = g2p_embed.type_as(fbank_feature) 100 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 101 | 102 | def __len__( 103 | self 104 | ): 105 | return len(self.data) 106 | 107 | 108 | 109 | 110 | 111 | def collate_fn(batch): 112 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch) 113 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 114 | lengths = [len(seq) for seq in padded_fbank_feature] 115 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 116 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 117 | label_tensor = torch.tensor(label) 118 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor 119 | 120 | 121 | 122 | spc = SPC_NO_Dataset() 123 | spc[0] -------------------------------------------------------------------------------- /mm-kws/dataloaders/SPC_N0_TARGET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | import Levenshtein 16 | import re 17 | import random 18 | 19 | import os 20 | 21 | def get_files(path, endswith): 22 | _files = [] 23 | for root, dirs, files in os.walk(path): 24 | for file in files: 25 | if file.endswith(endswith): 26 | _files.append(os.path.join(root, file)) 27 | return _files 28 | 29 | 30 | class SPC_NO_Dataset(Dataset): 31 | def __init__( 32 | self, 33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_TARGET.csv", 36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 37 | preprocess=True, 38 | ): 39 | super().__init__() 40 | if preprocess: 41 | target_dict = {} 42 | idx = 0 43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label']) 44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 45 | classes = os.listdir(test_dir) 46 | random.shuffle(classes) 47 | Targets = classes[:10] 48 | for wav_idx in range(len(wav_id)): 49 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 50 | query_text = wav_id[wav_idx].split('_')[0] 51 | if query_text in Targets: 52 | for comparison_text in Targets: 53 | _label = 1 if comparison_text == query_text else 0 54 | target_dict[idx] = { 55 | 'id': wav_id[wav_idx], 56 | 'Query_text': query_text, 57 | 'Query_wav': wav, 58 | 'Support_text': comparison_text, 59 | 'Query_label': Targets.index(query_text), 60 | 'Support_label': Targets.index(comparison_text), 61 | 'label': _label 62 | } 63 | idx += 1 64 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 65 | self.data.to_csv(save_path, index=False) 66 | else: 67 | self.data = pd.read_csv(save_path) 68 | 69 | self.data = self.data.values.tolist() 70 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = 
pickle.load(pickle_file) 71 | 72 | 73 | def __getitem__( 74 | self, 75 | index 76 | ): 77 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index] 78 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 79 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 80 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 81 | fbank_feature = fbank( 82 | Query_wav, 83 | num_mel_bins=80 84 | ) 85 | g2p_embed = torch.from_numpy(g2p_embed) 86 | g2p_embed = g2p_embed.type_as(fbank_feature) 87 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 88 | 89 | def __len__( 90 | self 91 | ): 92 | return len(self.data) 93 | 94 | 95 | 96 | 97 | 98 | def collate_fn(batch): 99 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch) 100 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 101 | lengths = [len(seq) for seq in padded_fbank_feature] 102 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 103 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 104 | label_tensor = torch.tensor(label) 105 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor 106 | 107 | 108 | 109 | # spc = SPC_NO_Dataset() 110 | -------------------------------------------------------------------------------- /mm-kws/dataloaders/SPC_N1_ALL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | # import dataaug 13 | import torch.nn.functional as F 14 | from torch.nn.utils.rnn import pad_sequence 15 | import Levenshtein 16 | import re 17 | import random 18 | 19 | import os 20 | 21 | def get_files(path, endswith): 22 | _files = [] 23 | for root, dirs, files in os.walk(path): 24 | for file in files: 25 | if file.endswith(endswith): 26 | _files.append(os.path.join(root, file)) 27 | return _files 28 | from tqdm import tqdm 29 | 30 | class SPC_N1_Dataset(Dataset): 31 | def __init__( 32 | self, 33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_ALL.csv", 36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 37 | preprocess=True, 38 | ): 39 | super().__init__() 40 | if preprocess: 41 | target_dict = {} 42 | idx = 0 43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label']) 44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 45 | classes = os.listdir(test_dir) 46 | random.shuffle(classes) 47 | Targets = classes[:10] 48 | supports_wavs = {} 49 | for comparison_text in Targets: 50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy') 51 | random.shuffle(supports_wav) 52 | supports_wavs[comparison_text] = supports_wav 53 | for wav_idx in range(len(wav_id)): 54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 55 | query_text = wav_id[wav_idx].split('_')[0] 56 | if query_text in Targets: 57 
| for comparison_text in Targets: 58 | support_wav = random.choices(supports_wavs[comparison_text], k=2) 59 | _label = 1 if comparison_text == query_text else 0 60 | target_dict[idx] = { 61 | 'id': wav_id[wav_idx], 62 | 'Query_text': query_text, 63 | 'Query_wav': wav, 64 | 'Support_text': comparison_text, 65 | 'Support_wav': support_wav[0], 66 | 'Query_label': Targets.index(query_text), 67 | 'Support_label': Targets.index(comparison_text), 68 | 'label': _label 69 | } 70 | idx += 1 71 | else: 72 | for comparison_text in Targets: 73 | support_wav = random.choices(supports_wavs[comparison_text], k=5) 74 | _label = 0 75 | target_dict[idx] = { 76 | 'id': wav_id[wav_idx], 77 | 'Query_text': query_text, 78 | 'Query_wav': wav, 79 | 'Support_text': comparison_text, 80 | 'Support_wav': support_wav[0], 81 | 'Query_label': 10, 82 | 'Support_label': Targets.index(comparison_text), 83 | 'label': 10, 84 | } 85 | idx += 1 86 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 87 | self.data.to_csv(save_path, index=False) 88 | else: 89 | self.data = pd.read_csv(save_path) 90 | 91 | self.data = self.data.values.tolist() 92 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 93 | 94 | 95 | def __getitem__( 96 | self, 97 | index 98 | ): 99 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index] 100 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 101 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 102 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 103 | audiolm_embed = np.load(Support_wav) 104 | fbank_feature = fbank( 105 | Query_wav, 106 | num_mel_bins=80 107 | ) 108 | g2p_embed = torch.from_numpy(g2p_embed) 109 | g2p_embed = g2p_embed.type_as(fbank_feature) 110 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 111 | 112 | def __len__( 113 | self 114 | ): 115 | return len(self.data) 116 | 117 | 118 | 119 | 120 | 121 | def collate_fn(batch): 122 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch) 123 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 124 | lengths = [len(seq) for seq in padded_fbank_feature] 125 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 126 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 127 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 128 | label_tensor = torch.tensor(label) 129 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor 130 | 131 | 132 | 133 | spc = SPC_N1_Dataset() 134 | # spc[0] -------------------------------------------------------------------------------- /mm-kws/dataloaders/SPC_N1_TARGET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import Levenshtein 15 | import re 16 | import random 17 | 18 | import os 19 | 20 
| def get_files(path, endswith): 21 | _files = [] 22 | for root, dirs, files in os.walk(path): 23 | for file in files: 24 | if file.endswith(endswith): 25 | _files.append(os.path.join(root, file)) 26 | return _files 27 | from tqdm import tqdm 28 | 29 | class SPC_N1_Dataset(Dataset): 30 | def __init__( 31 | self, 32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data", 33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text", 34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_TARGET.csv", 35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle", 36 | preprocess=True, 37 | ): 38 | super().__init__() 39 | if preprocess: 40 | target_dict = {} 41 | idx = 0 42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label']) 43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list))) 44 | classes = os.listdir(test_dir) 45 | random.shuffle(classes) 46 | Targets = classes[:10] 47 | # Targets = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go'] 48 | supports_wavs = {} 49 | for comparison_text in Targets: 50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy') 51 | random.shuffle(supports_wav) 52 | supports_wavs[comparison_text] = supports_wav 53 | for wav_idx in range(len(wav_id)): 54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav' 55 | query_text = wav_id[wav_idx].split('_')[0] 56 | if query_text in Targets: 57 | for comparison_text in Targets: 58 | _label = 1 if comparison_text == query_text else 0 59 | target_dict[idx] = { 60 | 'id': wav_id[wav_idx], 61 | 'Query_text': query_text, 62 | 'Query_wav': wav, 63 | 'Support_text': comparison_text, 64 | 'Support_wav': supports_wavs[comparison_text][0], 65 | 'Query_label': Targets.index(query_text), 66 | 'Support_label': Targets.index(comparison_text), 67 | 'label': _label 68 | } 69 | idx += 1 70 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True) 71 | self.data.to_csv(save_path, index=False) 72 | else: 73 | self.data = pd.read_csv(save_path) 74 | 75 | self.data = self.data.values.tolist() 76 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 77 | 78 | 79 | def __getitem__( 80 | self, 81 | index 82 | ): 83 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index] 84 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank 85 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 86 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 87 | audiolm_embed = np.load(Support_wav) 88 | fbank_feature = fbank( 89 | Query_wav, 90 | num_mel_bins=80 91 | ) 92 | g2p_embed = torch.from_numpy(g2p_embed) 93 | g2p_embed = g2p_embed.type_as(fbank_feature) 94 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label) 95 | 96 | def __len__( 97 | self 98 | ): 99 | return len(self.data) 100 | 101 | 102 | 103 | 104 | 105 | def collate_fn(batch): 106 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch) 107 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 108 | lengths = [len(seq) for seq in padded_fbank_feature] 109 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 110 | 
padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 111 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 112 | label_tensor = torch.tensor(label) 113 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor 114 | 115 | 116 | 117 | spc = SPC_N1_Dataset() 118 | # spc[0] -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc 
-------------------------------------------------------------------------------- /mm-kws/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__init__.py: -------------------------------------------------------------------------------- 1 | from .g2p import G2p 2 | -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/checkpoint20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/checkpoint20.npz -------------------------------------------------------------------------------- /mm-kws/dataloaders/g2p/g2p_en/expand.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 
#/usr/bin/python2 3 | ''' 4 | Borrowed 5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py 6 | By kyubyong park. kbpark.linguist@gmail.com. 7 | https://www.github.com/kyubyong/g2p 8 | ''' 9 | from __future__ import print_function 10 | import inflect 11 | import re 12 | 13 | 14 | 15 | _inflect = inflect.engine() 16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 21 | _number_re = re.compile(r'[0-9]+') 22 | 23 | 24 | def _remove_commas(m): 25 | return m.group(1).replace(',', '') 26 | 27 | 28 | def _expand_decimal_point(m): 29 | return m.group(1).replace('.', ' point ') 30 | 31 | 32 | def _expand_dollars(m): 33 | match = m.group(1) 34 | parts = match.split('.') 35 | if len(parts) > 2: 36 | return match + ' dollars' # Unexpected format 37 | dollars = int(parts[0]) if parts[0] else 0 38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 39 | if dollars and cents: 40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 41 | cent_unit = 'cent' if cents == 1 else 'cents' 42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 43 | elif dollars: 44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 45 | return '%s %s' % (dollars, dollar_unit) 46 | elif cents: 47 | cent_unit = 'cent' if cents == 1 else 'cents' 48 | return '%s %s' % (cents, cent_unit) 49 | else: 50 | return 'zero dollars' 51 | 52 | 53 | def _expand_ordinal(m): 54 | return _inflect.number_to_words(m.group(0)) 55 | 56 | 57 | def _expand_number(m): 58 | num = int(m.group(0)) 59 | if num > 1000 and num < 3000: 60 | if num == 2000: 61 | return 'two thousand' 62 | elif num > 2000 and num < 2010: 63 | return 'two thousand ' + _inflect.number_to_words(num % 100) 64 | elif num % 100 == 0: 65 | return _inflect.number_to_words(num // 100) + ' hundred' 66 | else: 67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 68 | else: 69 | return _inflect.number_to_words(num, andword='') 70 | 71 | 72 | def normalize_numbers(text): 73 | text = re.sub(_comma_number_re, _remove_commas, text) 74 | text = re.sub(_pounds_re, r'\1 pounds', text) 75 | text = re.sub(_dollars_re, _expand_dollars, text) 76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 77 | text = re.sub(_ordinal_re, _expand_ordinal, text) 78 | text = re.sub(_number_re, _expand_number, text) 79 | return text 80 | -------------------------------------------------------------------------------- /mm-kws/dataloaders/libriphrase_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class LibriPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test", 20 | csv=[ 21 | "libriphrase_diffspk_all_1word.csv", 22 | "libriphrase_diffspk_all_2word.csv", 23 | "libriphrase_diffspk_all_3word.csv", 24 | "libriphrase_diffspk_all_4word.csv" 25 | ], 26 | 
save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv', 27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle", 28 | preprocess=False, 29 | types='easy' 30 | ): 31 | if preprocess: 32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type']) 33 | for path in csv: 34 | n_word = os.path.join(test_dir, path) 35 | df = pd.read_csv(n_word) 36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']] 37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']] 38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True) 39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True) 40 | self.data.to_csv(save_path, index=False) 41 | else: 42 | self.data = pd.read_csv(save_path) 43 | # print(self.data)/ 44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1) 45 | if types == 'easy': 46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])] 47 | elif types == 'hard': 48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])] 49 | 50 | 51 | self.data = self.data.values.tolist() 52 | # self.data = self.data[:1000] 53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 54 | self.test_dir = test_dir 55 | 56 | def __getitem__( 57 | self, 58 | index 59 | ): 60 | # Query_wav_fbank, phoneme, g2p_embed. 
lm_embed, audiolm_embed, label 61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index] 62 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 64 | phoneme = self.text_embedder[Support_text]['phoneme'] 65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 66 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '.npy') 68 | fbank_feature = fbank( 69 | Query_wav, 70 | num_mel_bins=80 71 | ) 72 | g2p_embed = torch.from_numpy(g2p_embed) 73 | g2p_embed = g2p_embed.type_as(fbank_feature) 74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 75 | 76 | 77 | def __len__( 78 | self 79 | ): 80 | return len(self.data) 81 | 82 | 83 | def collate_fn(batch): 84 | # Group the fields of each sample in the batch by type 85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 86 | # Pad each feature 87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 88 | lengths = [len(seq) for seq in padded_fbank_feature] 89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 92 | # Convert the labels to a tensor 93 | label_tensor = torch.tensor(label) 94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 95 | 96 | 97 | 98 | 99 | 100 | # test_data = LibriPhrase_Test_Dataset() 101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False) 102 | # from tqdm import tqdm 103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data 105 | # print(dist_tensor) 106 | # break 107 | # # pass 108 | # pass -------------------------------------------------------------------------------- /mm-kws/dataloaders/libriphrase_test_18.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class LibriPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test", 20 | csv=[ 21 | "libriphrase_diffspk_all_1word.csv", 22 | "libriphrase_diffspk_all_2word.csv", 23 | "libriphrase_diffspk_all_3word.csv", 24 | "libriphrase_diffspk_all_4word.csv" 25 | ], 26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv', 27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle", 28 | preprocess=False, 29 | types='easy' 30 | ): 31 | if preprocess: 32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type']) 33 | for path in csv: 34 | n_word = os.path.join(test_dir, path) 35 | df = pd.read_csv(n_word)
36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']] 37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']] 38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True) 39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True) 40 | self.data.to_csv(save_path, index=False) 41 | else: 42 | self.data = pd.read_csv(save_path) 43 | # print(self.data)/ 44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1) 45 | if types == 'easy': 46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])] 47 | elif types == 'hard': 48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])] 49 | 50 | 51 | self.data = self.data.values.tolist() 52 | # self.data = self.data[:1000] 53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 54 | self.test_dir = test_dir 55 | 56 | def __getitem__( 57 | self, 58 | index 59 | ): 60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label 61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index] 62 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 64 | phoneme = self.text_embedder[Support_text]['phoneme'] 65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 66 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy') 68 | fbank_feature = fbank( 69 | Query_wav, 70 | num_mel_bins=80 71 | ) 72 | g2p_embed = torch.from_numpy(g2p_embed) 73 | g2p_embed = g2p_embed.type_as(fbank_feature) 74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 75 | 76 | 77 | def __len__( 78 | self 79 | ): 80 | return len(self.data) 81 | 82 | 83 | def collate_fn(batch): 84 | # Group the fields of each sample in the batch by type 85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 86 | # Pad each feature 87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 88 | lengths = [len(seq) for seq in padded_fbank_feature] 89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 92 | # Convert the labels to a tensor 93 | label_tensor = torch.tensor(label) 94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 95 | 96 | 97 | 98 | 99 | 100 | # test_data = LibriPhrase_Test_Dataset() 101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False) 102 | # from tqdm import tqdm 103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data 105 | # print(dist_tensor) 106 | # break 107 | # # pass 108 | # pass
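The two LibriPhrase test loaders above appear to differ only in which cached support-audio embedding they read ('.npy' vs '_18.npy'). The snippet below is a minimal, self-contained sketch (not part of the repository) of what their shared collate_fn padding step does: the 80-dim fbank size follows the dataset code, the 256-dim g2p size is assumed to match the g2p projection in models.py further down, and the sequence lengths are invented. Note that collate_fn reads lengths from the already padded tensor, so every entry equals the longest sequence in the batch rather than the true per-sample length.

import torch
from torch.nn.utils.rnn import pad_sequence

# Two fake samples: (time, mel) fbank features and (phonemes, dim) g2p embeddings.
fbank_feats = [torch.randn(98, 80), torch.randn(123, 80)]
g2p_embeds = [torch.randn(7, 256), torch.randn(5, 256)]

padded_fbank = pad_sequence(fbank_feats, batch_first=True)  # shape (2, 123, 80), zero-padded
padded_g2p = pad_sequence(g2p_embeds, batch_first=True)     # shape (2, 7, 256)

# Mirrors the loaders above: lengths come from the padded tensor,
# so both entries are 123 here, not the true lengths 98 and 123.
lengths = [len(seq) for seq in padded_fbank]
print(padded_fbank.shape, padded_g2p.shape, lengths)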
-------------------------------------------------------------------------------- /mm-kws/dataloaders/wenetphrase_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import os 5 | import pandas as pd 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | import torchaudio 9 | import pickle 10 | import numpy as np 11 | from torchaudio.compliance.kaldi import fbank 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pad_sequence 14 | import re 15 | 16 | class WenetPhrase_Test_Dataset(Dataset): 17 | def __init__( 18 | self, 19 | test_dir="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/WenetPhrase2/S", 20 | csv="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/wenetphrase_test.csv", 21 | test_text_embedding="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/zh_test_text_embeddings.pickle", 22 | types='easy' 23 | ): 24 | self.data = pd.read_csv(csv) 25 | if types == 'easy': 26 | self.data = self.data.loc[self.data['type'].isin(['easy', 'pos'])] 27 | elif types == 'hard': 28 | self.data = self.data.loc[self.data['type'].isin(['hard', 'pos'])] 29 | self.data = self.data.values.tolist() 30 | print(len(self.data)) 31 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file) 32 | self.test_dir = test_dir 33 | 34 | def __getitem__( 35 | self, 36 | index 37 | ): 38 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label 39 | Query_text, Query_wav, Support_text, Support_wav, label, _ = self.data[index] 40 | # print(Query_text, Query_wav, Support_text, Support_wav, label) 41 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank 42 | phoneme = self.text_embedder[Support_text]['phoneme'] 43 | g2p_embed = self.text_embedder[Support_text]['g2p_embed'] 44 | lm_embed = self.text_embedder[Support_text]['lm_embed'] 45 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy') 46 | fbank_feature = fbank( 47 | Query_wav, 48 | num_mel_bins=80 49 | ) 50 | g2p_embed = torch.from_numpy(g2p_embed) 51 | g2p_embed = g2p_embed.type_as(fbank_feature) 52 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label) 53 | 54 | 55 | def __len__( 56 | self 57 | ): 58 | return len(self.data) 59 | 60 | 61 | def collate_fn(batch): 62 | # Group the fields of each sample in the batch by type 63 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch) 64 | # Pad each feature 65 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True) 66 | lengths = [len(seq) for seq in padded_fbank_feature] 67 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True) 68 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0) 69 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0) 70 | # Convert the labels to a tensor 71 | label_tensor = torch.tensor(label) 72 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor 73 | 74 | 75 | 76 | 77 | 78 | # test_data = WenetPhrase_Test_Dataset() 79 | # print(test_data[0]) 80 | # # dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn, num_workers=1, shuffle=False) 81 | # # from tqdm import tqdm 82 | # # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)): 83 | # # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed,
label_tensor, dist_tensor = data 84 | # # print(dist_tensor) 85 | # # break 86 | # # # pass 87 | # # pass -------------------------------------------------------------------------------- /mm-kws/g2p/g2p_en/__init__.py: -------------------------------------------------------------------------------- 1 | from .g2p import G2p 2 | -------------------------------------------------------------------------------- /mm-kws/g2p/g2p_en/checkpoint20.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/g2p/g2p_en/checkpoint20.npz -------------------------------------------------------------------------------- /mm-kws/g2p/g2p_en/expand.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #/usr/bin/python2 3 | ''' 4 | Borrowed 5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py 6 | By kyubyong park. kbpark.linguist@gmail.com. 7 | https://www.github.com/kyubyong/g2p 8 | ''' 9 | from __future__ import print_function 10 | import inflect 11 | import re 12 | 13 | 14 | 15 | _inflect = inflect.engine() 16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 21 | _number_re = re.compile(r'[0-9]+') 22 | 23 | 24 | def _remove_commas(m): 25 | return m.group(1).replace(',', '') 26 | 27 | 28 | def _expand_decimal_point(m): 29 | return m.group(1).replace('.', ' point ') 30 | 31 | 32 | def _expand_dollars(m): 33 | match = m.group(1) 34 | parts = match.split('.') 35 | if len(parts) > 2: 36 | return match + ' dollars' # Unexpected format 37 | dollars = int(parts[0]) if parts[0] else 0 38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 39 | if dollars and cents: 40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 41 | cent_unit = 'cent' if cents == 1 else 'cents' 42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 43 | elif dollars: 44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 45 | return '%s %s' % (dollars, dollar_unit) 46 | elif cents: 47 | cent_unit = 'cent' if cents == 1 else 'cents' 48 | return '%s %s' % (cents, cent_unit) 49 | else: 50 | return 'zero dollars' 51 | 52 | 53 | def _expand_ordinal(m): 54 | return _inflect.number_to_words(m.group(0)) 55 | 56 | 57 | def _expand_number(m): 58 | num = int(m.group(0)) 59 | if num > 1000 and num < 3000: 60 | if num == 2000: 61 | return 'two thousand' 62 | elif num > 2000 and num < 2010: 63 | return 'two thousand ' + _inflect.number_to_words(num % 100) 64 | elif num % 100 == 0: 65 | return _inflect.number_to_words(num // 100) + ' hundred' 66 | else: 67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 68 | else: 69 | return _inflect.number_to_words(num, andword='') 70 | 71 | 72 | def normalize_numbers(text): 73 | text = re.sub(_comma_number_re, _remove_commas, text) 74 | text = re.sub(_pounds_re, r'\1 pounds', text) 75 | text = re.sub(_dollars_re, _expand_dollars, text) 76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 77 | text = re.sub(_ordinal_re, _expand_ordinal, text) 78 | text = re.sub(_number_re, _expand_number, text) 79 | return text 80 | -------------------------------------------------------------------------------- 
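For reference, a small usage sketch for the number-normalization helper above (not part of the repository). The sentence is invented, the import assumes the mm-kws/g2p directory is on PYTHONPATH and that the g2p_en package imports cleanly in this environment, and the output shown in the comment is approximate.

from g2p_en.expand import normalize_numbers  # assumed import path within this repo's g2p package

text = "Order 21 cost $3.50 on the 2nd of May"
print(normalize_numbers(text))
# Roughly: "Order twenty-one cost three dollars, fifty cents on the second of May"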
/mm-kws/libriphrase_hardneg.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/libriphrase_hardneg.json.zip -------------------------------------------------------------------------------- /mm-kws/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 4 | import math 5 | 6 | from conformer.conformer.model_def import Conformer 7 | 8 | class PositionalEmbedding(nn.Module): 9 | def __init__(self, d_model=512, max_len=512): 10 | super().__init__() 11 | # Compute the positional encodings once in log space. 12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | position = torch.arange(0, max_len).float().unsqueeze(1) 15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | return self.pe[:, :x.size(1)] 23 | 24 | 25 | class TEXT_Fusion_transformer_encoder(nn.Module): 26 | def __init__( 27 | self, 28 | d_model, 29 | nlayers, 30 | nhead, 31 | dim_feedforward, 32 | dropout=0.1 33 | ): 34 | super().__init__() 35 | self.position_audio = PositionalEmbedding(d_model=128) 36 | self.position_text_g2p = PositionalEmbedding(d_model=128) 37 | self.position_text_lm = PositionalEmbedding(d_model=128) 38 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm 39 | self.dropout = nn.Dropout(p=dropout) 40 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 41 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 42 | 43 | def forward(self,audio_embedding, g2p_embedding, lm_embedding): 44 | position_audio_encoding = self.position_audio(audio_embedding) 45 | position_g2p_encoding = self.position_text_g2p(g2p_embedding) 46 | position_lm_encoding = self.position_text_lm(lm_embedding) 47 | 48 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 49 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device)) 50 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device)) 51 | 52 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio 53 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p 54 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm 55 | 56 | #(3) concat tokens 57 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1) 58 | input_tokens = self.dropout(input_tokens) 59 | 60 | output = self.transformer_encoder(input_tokens) 61 | return output 62 | 63 | 64 | class Audio_Fusion_transformer_encoder(nn.Module): 65 | def __init__( 66 | self, 67 | d_model, 68 | nlayers, 69 | nhead, 70 | dim_feedforward, 71 | dropout=0.1 72 | ): 73 | super().__init__() 74 | self.position_audio = PositionalEmbedding(d_model=128) 75 | self.position_audio_lm = PositionalEmbedding(d_model=128) 76 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 
for audiolm 77 | self.dropout = nn.Dropout(p=dropout) 78 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 79 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 80 | 81 | def forward(self,audio_embedding, audiolm_embedding): 82 | position_audio_encoding = self.position_audio(audio_embedding) 83 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding) 84 | 85 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 86 | modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device)) 87 | 88 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio 89 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm 90 | 91 | #(3) concat tokens 92 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1) 93 | input_tokens = self.dropout(input_tokens) 94 | 95 | output = self.transformer_encoder(input_tokens) 96 | return output 97 | 98 | 99 | class GRUFCModel(nn.Module): 100 | def __init__(self, input_dim, hidden_dim, output_dim): 101 | super(GRUFCModel, self).__init__() 102 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) 103 | self.fc = nn.Linear(hidden_dim, output_dim) 104 | 105 | def forward(self, x): 106 | gru_out, _ = self.gru(x) 107 | gru_last_output = gru_out[:, -1, :] 108 | fc_out = self.fc(gru_last_output) 109 | return fc_out 110 | 111 | 112 | import torch.nn as nn 113 | class Projection(nn.Module): 114 | def __init__(self, input_dim, output_dim): 115 | super(Projection, self).__init__() 116 | layers = [] 117 | layers.append(nn.LayerNorm(input_dim)) 118 | layers.append(nn.Linear(input_dim, output_dim)) 119 | layers.append(nn.SiLU()) 120 | self.projection_block = nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | return self.projection_block(x) 124 | 125 | 126 | class MMKWS(nn.Module): 127 | def __init__(self): 128 | super().__init__() 129 | self.audioencoder = Conformer( 130 | input_dim= 80, 131 | encoder_dim= 128, 132 | num_encoder_layers= 6, 133 | num_attention_heads= 4, 134 | ) 135 | self.g2p_projection = Projection(input_dim=256, output_dim=128) 136 | self.lm_projection = Projection(input_dim=768, output_dim=128) 137 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128) 138 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 139 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 140 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 141 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 142 | self.fc = nn.Linear(64, 1) 143 | self.phoneme_fc = nn.Linear(128, 1) 144 | self.text_fc = nn.Linear(128, 1) 145 | 146 | 147 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed): 148 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0] 149 | g2p_embedding = self.g2p_projection(g2p_embed) 150 | lm_embedding = self.lm_projection(lm_embed) 151 | audiolm_embedding = self.audiolm_projection(audiolm_embed) 152 | 153 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding) 154 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding) 155 | fusion = 
self.gru1(fusion_text)+self.gru2(fusion_audio) 156 | fusion_pred = self.fc(fusion) 157 | 158 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :]) 159 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :]) 160 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred 161 | -------------------------------------------------------------------------------- /mm-kws/test.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pytorch_lightning.utilities.types import STEP_OUTPUT 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from pytorch_lightning import LightningModule, Trainer 6 | import pytorch_lightning as pl 7 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn 8 | import torch.nn as nn 9 | from models import MMKWS 10 | from sklearn.metrics import roc_auc_score 11 | from sklearn.metrics import roc_curve 12 | import sklearn 13 | import numpy as np 14 | import torch.nn as nn 15 | import torch 16 | 17 | def compute_eer(label, pred): 18 | fpr, tpr, threshold = sklearn.metrics.roc_curve(label, pred) 19 | fnr = 1 - tpr 20 | 21 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] 22 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 23 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] 24 | 25 | eer = (eer_1 + eer_2) / 2 26 | return eer 27 | 28 | class EER(nn.Module): 29 | def __init__(self): 30 | super(EER, self).__init__() 31 | self.score = 0.0 32 | self.count = 0.0 33 | 34 | def forward(self, y_true, y_pred): 35 | label_np = y_true.flatten() # Convert to numpy array 36 | pred_np = y_pred.flatten() # Convert to numpy array 37 | 38 | eer_value = compute_eer(label_np, pred_np) 39 | 40 | self.score += eer_value 41 | self.count += 1 42 | 43 | return torch.tensor(self.score / self.count) 44 | 45 | # 1. 定义 LightningModuley 46 | class MMKWS_Wrapper(LightningModule): 47 | def __init__(self): 48 | super().__init__() 49 | self.model = MMKWS() 50 | self.criterion = nn.BCEWithLogitsLoss() 51 | self.test_preds, self.test_labels = [], [] 52 | 53 | 54 | def test_step(self, batch, batch_size): 55 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch 56 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 57 | preds = torch.sigmoid(preds) 58 | preds = preds.squeeze(dim=1) 59 | self.test_preds.append(preds) 60 | self.test_labels.append(label) 61 | 62 | 63 | def on_test_epoch_end(self): 64 | eer_loss = EER() 65 | all_preds = torch.cat(self.test_preds) 66 | all_labels = torch.cat(self.test_labels) 67 | y_true = all_labels.cpu().detach().numpy() 68 | y_scores = all_preds.cpu().detach().numpy() 69 | auc = roc_auc_score(y_true, y_scores) 70 | eer = eer_loss(y_true, y_scores) 71 | self.log('test_auc', auc) 72 | self.log('test_eer', eer) 73 | 74 | def configure_optimizers(self): 75 | optim = torch.optim.Adam(self.model.parameters(), lr=5e-4) 76 | return optim 77 | 78 | 79 | # 3. 
设置 Trainer 和训练 80 | if __name__ == "__main__": 81 | import os 82 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 83 | test_dataset = LibriPhrase_Test_Dataset(types='easy') 84 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False) 85 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt") 86 | model.eval() 87 | trainer = Trainer(devices=1, accelerator='gpu') # 设置训练器参数 88 | trainer.test(model, test_dataloader) 89 | pl.seed_everything(1234) 90 | test_dataset = LibriPhrase_Test_Dataset(types='hard') 91 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False) 92 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt") 93 | trainer = Trainer(devices=1, accelerator='gpu', ) # 设置训练器参数 94 | model.eval() 95 | trainer.test(model, test_dataloader) 96 | -------------------------------------------------------------------------------- /mm-kws/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import ExponentialLR 6 | 7 | import numpy as np 8 | from sklearn.metrics import roc_auc_score 9 | from sklearn.metrics import roc_curve 10 | 11 | import pytorch_lightning as pl 12 | from pytorch_lightning import LightningModule, Trainer 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | from dataloaders.libriphrase_train import LibriPhrase_Train_Dataset, train_collate_fn 15 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn 16 | from models import MMKWS 17 | 18 | def compute_eer(label, pred): 19 | fpr, tpr, threshold = roc_curve(label, pred) 20 | fnr = 1 - tpr 21 | 22 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] 23 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 24 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] 25 | 26 | eer = (eer_1 + eer_2) / 2 27 | return eer 28 | 29 | class EER(nn.Module): 30 | def __init__(self): 31 | super(EER, self).__init__() 32 | self.score = 0.0 33 | self.count = 0.0 34 | 35 | def forward(self, y_true, y_pred): 36 | label_np = y_true.flatten() # Convert to numpy array 37 | pred_np = y_pred.flatten() # Convert to numpy array 38 | 39 | eer_value = compute_eer(label_np, pred_np) 40 | 41 | self.score += eer_value 42 | self.count += 1 43 | 44 | return torch.tensor(self.score / self.count) 45 | 46 | 47 | class MMKWS_Wrapper(LightningModule): 48 | def __init__(self): 49 | super().__init__() 50 | self.model = MMKWS() 51 | self.criterion = nn.BCEWithLogitsLoss() 52 | self.test_preds = [] 53 | self.test_labels = [] 54 | 55 | 56 | def training_step(self, batch, batch_idx): 57 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label, pl, mask_pl, tl, mask_tl = batch 58 | preds, phoneme_preds, text_preds = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 59 | phoneme_preds = phoneme_preds.squeeze(dim=2) 60 | text_preds = text_preds.squeeze(dim=2) 61 | preds = preds.squeeze(dim=1) # Output is [Batch size, 1], but we want [Batch size] 62 | phoneme_loss = self.sequence_bce_loss(phoneme_preds, pl.float(), mask_pl) 63 | text_loss = self.sequence_bce_loss(text_preds, tl.float(), mask_tl) 64 | utt_loss = 
self.criterion(preds, label.float()) 65 | all_loss = utt_loss + phoneme_loss + text_loss 66 | self.log('train/all_loss', all_loss, on_step=True, prog_bar=True) 67 | self.log('train/utt_loss', utt_loss, on_step=True, prog_bar=True) 68 | self.log('train/text_loss', text_loss, on_step=True, prog_bar=True) 69 | self.log('train/phoneme_loss', phoneme_loss, on_step=True, prog_bar=True) 70 | return all_loss 71 | 72 | def validation_step(self, batch, batch_idx): 73 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch 74 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 75 | preds = torch.sigmoid(preds) 76 | preds = preds.squeeze(dim=1) 77 | self.test_preds.append(preds) 78 | self.test_labels.append(label) 79 | 80 | 81 | def on_validation_epoch_end(self): 82 | if self.current_epoch > 0: 83 | eer_loss = EER() 84 | all_preds = torch.cat(self.test_preds) 85 | all_labels = torch.cat(self.test_labels) 86 | y_true = all_labels.cpu().detach().numpy() 87 | y_scores = all_preds.cpu().detach().numpy() 88 | # 计算 AUC 89 | auc = roc_auc_score(y_true, y_scores) 90 | eer = eer_loss(y_true, y_scores) 91 | self.log('test/test_auc', auc) 92 | self.log('test/test_eer', eer) 93 | self.test_preds.clear() 94 | self.test_labels.clear() 95 | 96 | 97 | def configure_optimizers(self): 98 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 99 | return { 100 | "optimizer": optimizer, 101 | "lr_scheduler": { 102 | "scheduler": ExponentialLR(optimizer, gamma=0.95), 103 | "frequency": 1, 104 | "interval": 'epoch', 105 | }, 106 | } 107 | 108 | def sequence_bce_loss(self, preds, labels, mask): 109 | # 将预测值、标签和掩码都展平为一维向量 110 | preds_flat = preds.view(-1) 111 | labels_flat = labels.view(-1) 112 | mask_flat = mask.view(-1) 113 | # 仅考虑掩码为1的位置计算二元交叉熵损失 114 | valid_indices = torch.where(mask_flat == 1)[0] 115 | valid_preds = preds_flat[valid_indices] 116 | valid_labels = labels_flat[valid_indices] 117 | # 使用PyTorch内置的二元交叉熵损失函数 118 | loss = F.binary_cross_entropy_with_logits(valid_preds, valid_labels) 119 | return loss 120 | 121 | 122 | # 3. 
设置 Trainer 和训练 123 | if __name__ == "__main__": 124 | pl.seed_everything(2024) 125 | import os 126 | os.environ['CUDA_VISIBLE_DEVICES']='0, 1, 2, 3' 127 | train_dataset = LibriPhrase_Train_Dataset() 128 | train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=train_collate_fn, shuffle=True, num_workers=16, drop_last=True) 129 | test_dataset = LibriPhrase_Test_Dataset(types='hard') 130 | test_dataloader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn, shuffle=False, num_workers=8, drop_last=True) 131 | model = MMKWS_Wrapper() 132 | model_checkpoint = ModelCheckpoint( 133 | dirpath="/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts", 134 | filename='epoch{epoch:02d}', 135 | save_top_k=-1, 136 | ) 137 | logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/', name='MMKWS+') 138 | trainer = Trainer(devices=4, accelerator='gpu', # strategy='ddp_find_unused_parameters_true', 139 | logger=logger, max_epochs=100, callbacks=[model_checkpoint], accumulate_grad_batches=4, precision='16-mixed') # 设置训练器参数 140 | trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader) 141 | -------------------------------------------------------------------------------- /mm-kws/wenetphrase_hardneg.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/wenetphrase_hardneg.json.zip -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 4 | import math 5 | 6 | from conformer.conformer.model_def import Conformer 7 | 8 | class PositionalEmbedding(nn.Module): 9 | def __init__(self, d_model=512, max_len=512): 10 | super().__init__() 11 | # Compute the positional encodings once in log space. 
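# For position p and even channel 2i this table stores sin(p / 10000^(2i/d_model)), and cos of the same argument at channel 2i+1; div_term below equals 10000^(-2i/d_model), computed as exp(-2i * ln(10000) / d_model).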
12 | pe = torch.zeros(max_len, d_model).float() 13 | pe.require_grad = False 14 | position = torch.arange(0, max_len).float().unsqueeze(1) 15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 16 | pe[:, 0::2] = torch.sin(position * div_term) 17 | pe[:, 1::2] = torch.cos(position * div_term) 18 | pe = pe.unsqueeze(0) 19 | self.register_buffer('pe', pe) 20 | 21 | def forward(self, x): 22 | return self.pe[:, :x.size(1)] 23 | 24 | 25 | class TEXT_Fusion_transformer_encoder(nn.Module): 26 | def __init__( 27 | self, 28 | d_model, 29 | nlayers, 30 | nhead, 31 | dim_feedforward, 32 | dropout=0.1 33 | ): 34 | super().__init__() 35 | self.position_audio = PositionalEmbedding(d_model=128) 36 | self.position_text_g2p = PositionalEmbedding(d_model=128) 37 | self.position_text_lm = PositionalEmbedding(d_model=128) 38 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm 39 | self.dropout = nn.Dropout(p=dropout) 40 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 41 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 42 | 43 | def forward(self,audio_embedding, g2p_embedding, lm_embedding): 44 | position_audio_encoding = self.position_audio(audio_embedding) 45 | position_g2p_encoding = self.position_text_g2p(g2p_embedding) 46 | position_lm_encoding = self.position_text_lm(lm_embedding) 47 | 48 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 49 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device)) 50 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device)) 51 | 52 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio 53 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p 54 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm 55 | 56 | #(3) concat tokens 57 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1) 58 | input_tokens = self.dropout(input_tokens) 59 | 60 | output = self.transformer_encoder(input_tokens) 61 | return output 62 | 63 | 64 | class Audio_Fusion_transformer_encoder(nn.Module): 65 | def __init__( 66 | self, 67 | d_model, 68 | nlayers, 69 | nhead, 70 | dim_feedforward, 71 | dropout=0.1 72 | ): 73 | super().__init__() 74 | self.position_audio = PositionalEmbedding(d_model=128) 75 | self.position_audio_lm = PositionalEmbedding(d_model=128) 76 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 for audiolm 77 | self.dropout = nn.Dropout(p=dropout) 78 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 79 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 80 | 81 | def forward(self,audio_embedding, audiolm_embedding): 82 | position_audio_encoding = self.position_audio(audio_embedding) 83 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding) 84 | 85 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 86 | modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device)) 87 | 88 | audio_tokens = 
audio_embedding + position_audio_encoding + modality_audio 89 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm 90 | 91 | #(3) concat tokens 92 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1) 93 | input_tokens = self.dropout(input_tokens) 94 | 95 | output = self.transformer_encoder(input_tokens) 96 | return output 97 | 98 | 99 | class GRUFCModel(nn.Module): 100 | def __init__(self, input_dim, hidden_dim, output_dim): 101 | super(GRUFCModel, self).__init__() 102 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) 103 | self.fc = nn.Linear(hidden_dim, output_dim) 104 | 105 | def forward(self, x): 106 | gru_out, _ = self.gru(x) 107 | gru_last_output = gru_out[:, -1, :] 108 | fc_out = self.fc(gru_last_output) 109 | return fc_out 110 | 111 | 112 | import torch.nn as nn 113 | class Projection(nn.Module): 114 | def __init__(self, input_dim, output_dim): 115 | super(Projection, self).__init__() 116 | layers = [] 117 | layers.append(nn.LayerNorm(input_dim)) 118 | layers.append(nn.Linear(input_dim, output_dim)) 119 | layers.append(nn.SiLU()) 120 | self.projection_block = nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | return self.projection_block(x) 124 | 125 | 126 | class MMKWS(nn.Module): 127 | def __init__(self): 128 | super().__init__() 129 | self.audioencoder = Conformer( 130 | input_dim= 80, 131 | encoder_dim= 128, 132 | num_encoder_layers= 6, 133 | num_attention_heads= 4, 134 | ) 135 | self.g2p_projection = Projection(input_dim=256, output_dim=128) 136 | self.lm_projection = Projection(input_dim=768, output_dim=128) 137 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128) 138 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 139 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 140 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 141 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 142 | self.fc = nn.Linear(64, 1) 143 | self.phoneme_fc = nn.Linear(128, 1) 144 | self.text_fc = nn.Linear(128, 1) 145 | 146 | 147 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed): 148 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0] 149 | g2p_embedding = self.g2p_projection(g2p_embed) 150 | lm_embedding = self.lm_projection(lm_embed) 151 | audiolm_embedding = self.audiolm_projection(audiolm_embed) 152 | 153 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding) 154 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding) 155 | fusion = self.gru1(fusion_text)+self.gru2(fusion_audio) 156 | fusion_pred = self.fc(fusion) 157 | 158 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :]) 159 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :]) 160 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred 161 | -------------------------------------------------------------------------------- /models_tiny.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 4 | import math 5 | 6 | from conformer.conformer.model_def import Conformer 7 | from mdtc import 
MDTC_KWS 8 | 9 | class PositionalEmbedding(nn.Module): 10 | def __init__(self, d_model=512, max_len=512): 11 | super().__init__() 12 | # Compute the positional encodings once in log space. 13 | pe = torch.zeros(max_len, d_model).float() 14 | pe.require_grad = False 15 | position = torch.arange(0, max_len).float().unsqueeze(1) 16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | pe = pe.unsqueeze(0) 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, x): 23 | return self.pe[:, :x.size(1)] 24 | 25 | 26 | class TEXT_Fusion_transformer_encoder(nn.Module): 27 | def __init__( 28 | self, 29 | d_model, 30 | nlayers, 31 | nhead, 32 | dim_feedforward, 33 | dropout=0.1 34 | ): 35 | super().__init__() 36 | self.position_audio = PositionalEmbedding(d_model=128) 37 | self.position_text_g2p = PositionalEmbedding(d_model=128) 38 | self.position_text_lm = PositionalEmbedding(d_model=128) 39 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm 40 | self.dropout = nn.Dropout(p=dropout) 41 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 42 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 43 | 44 | def forward(self,audio_embedding, g2p_embedding, lm_embedding): 45 | position_audio_encoding = self.position_audio(audio_embedding) 46 | position_g2p_encoding = self.position_text_g2p(g2p_embedding) 47 | position_lm_encoding = self.position_text_lm(lm_embedding) 48 | 49 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 50 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device)) 51 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device)) 52 | 53 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio 54 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p 55 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm 56 | 57 | #(3) concat tokens 58 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1) 59 | input_tokens = self.dropout(input_tokens) 60 | 61 | output = self.transformer_encoder(input_tokens) 62 | return output 63 | 64 | 65 | class Audio_Fusion_transformer_encoder(nn.Module): 66 | def __init__( 67 | self, 68 | d_model, 69 | nlayers, 70 | nhead, 71 | dim_feedforward, 72 | dropout=0.1 73 | ): 74 | super().__init__() 75 | self.position_audio = PositionalEmbedding(d_model=128) 76 | self.position_audio_lm = PositionalEmbedding(d_model=128) 77 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 for audiolm 78 | self.dropout = nn.Dropout(p=dropout) 79 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True) 80 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 81 | 82 | def forward(self,audio_embedding, audiolm_embedding): 83 | position_audio_encoding = self.position_audio(audio_embedding) 84 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding) 85 | 86 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device)) 87 
| modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device)) 88 | 89 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio 90 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm 91 | 92 | #(3) concat tokens 93 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1) 94 | input_tokens = self.dropout(input_tokens) 95 | 96 | output = self.transformer_encoder(input_tokens) 97 | return output 98 | 99 | 100 | class GRUFCModel(nn.Module): 101 | def __init__(self, input_dim, hidden_dim, output_dim): 102 | super(GRUFCModel, self).__init__() 103 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True) 104 | self.fc = nn.Linear(hidden_dim, output_dim) 105 | 106 | def forward(self, x): 107 | gru_out, _ = self.gru(x) 108 | gru_last_output = gru_out[:, -1, :] 109 | fc_out = self.fc(gru_last_output) 110 | return fc_out 111 | 112 | 113 | import torch.nn as nn 114 | class Projection(nn.Module): 115 | def __init__(self, input_dim, output_dim): 116 | super(Projection, self).__init__() 117 | layers = [] 118 | layers.append(nn.LayerNorm(input_dim)) 119 | layers.append(nn.Linear(input_dim, output_dim)) 120 | layers.append(nn.SiLU()) 121 | self.projection_block = nn.Sequential(*layers) 122 | 123 | def forward(self, x): 124 | return self.projection_block(x) 125 | 126 | 127 | class MMKWS(nn.Module): 128 | def __init__(self): 129 | super().__init__() 130 | self.audioencoder = MDTC_KWS() 131 | # self.audioencoder = Conformer( 132 | # input_dim= 80, 133 | # encoder_dim= 128, 134 | # num_encoder_layers= 6, 135 | # num_attention_heads= 4, 136 | # ) 137 | self.g2p_projection = Projection(input_dim=256, output_dim=128) 138 | self.lm_projection = Projection(input_dim=768, output_dim=128) 139 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128) 140 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 141 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1) 142 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 143 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64) 144 | self.fc = nn.Linear(64, 1) 145 | self.phoneme_fc = nn.Linear(128, 1) 146 | self.text_fc = nn.Linear(128, 1) 147 | 148 | 149 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed): 150 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0] 151 | g2p_embedding = self.g2p_projection(g2p_embed) 152 | lm_embedding = self.lm_projection(lm_embed) 153 | audiolm_embedding = self.audiolm_projection(audiolm_embed) 154 | 155 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding) 156 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding) 157 | fusion = self.gru1(fusion_text)+self.gru2(fusion_audio) 158 | fusion_pred = self.fc(fusion) 159 | 160 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :]) 161 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :]) 162 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred 163 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | from typing import Any 2 | from pytorch_lightning.utilities.types import STEP_OUTPUT 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from pytorch_lightning import LightningModule, Trainer 6 | import pytorch_lightning as pl 7 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn 8 | import torch.nn as nn 9 | from models import MMKWS 10 | from sklearn.metrics import roc_auc_score 11 | from sklearn.metrics import roc_curve 12 | import sklearn 13 | import numpy as np 14 | import torch.nn as nn 15 | import torch 16 | 17 | def compute_eer(label, pred): 18 | fpr, tpr, threshold = sklearn.metrics.roc_curve(label, pred) 19 | fnr = 1 - tpr 20 | 21 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] 22 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 23 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] 24 | 25 | eer = (eer_1 + eer_2) / 2 26 | return eer 27 | 28 | class EER(nn.Module): 29 | def __init__(self): 30 | super(EER, self).__init__() 31 | self.score = 0.0 32 | self.count = 0.0 33 | 34 | def forward(self, y_true, y_pred): 35 | label_np = y_true.flatten() # Convert to numpy array 36 | pred_np = y_pred.flatten() # Convert to numpy array 37 | 38 | eer_value = compute_eer(label_np, pred_np) 39 | 40 | self.score += eer_value 41 | self.count += 1 42 | 43 | return torch.tensor(self.score / self.count) 44 | 45 | # 1. 定义 LightningModuley 46 | class MMKWS_Wrapper(LightningModule): 47 | def __init__(self): 48 | super().__init__() 49 | self.model = MMKWS() 50 | self.criterion = nn.BCEWithLogitsLoss() 51 | self.test_preds, self.test_labels = [], [] 52 | 53 | 54 | def test_step(self, batch, batch_size): 55 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch 56 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 57 | preds = torch.sigmoid(preds) 58 | preds = preds.squeeze(dim=1) 59 | self.test_preds.append(preds) 60 | self.test_labels.append(label) 61 | 62 | 63 | def on_test_epoch_end(self): 64 | eer_loss = EER() 65 | all_preds = torch.cat(self.test_preds) 66 | all_labels = torch.cat(self.test_labels) 67 | y_true = all_labels.cpu().detach().numpy() 68 | y_scores = all_preds.cpu().detach().numpy() 69 | auc = roc_auc_score(y_true, y_scores) 70 | eer = eer_loss(y_true, y_scores) 71 | self.log('test_auc', auc) 72 | self.log('test_eer', eer) 73 | 74 | def configure_optimizers(self): 75 | optim = torch.optim.Adam(self.model.parameters(), lr=5e-4) 76 | return optim 77 | 78 | 79 | # 3. 
设置 Trainer 和训练 80 | if __name__ == "__main__": 81 | import os 82 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 83 | test_dataset = LibriPhrase_Test_Dataset(types='easy') 84 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False) 85 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt") 86 | model.eval() 87 | trainer = Trainer(devices=1, accelerator='gpu') # 设置训练器参数 88 | trainer.test(model, test_dataloader) 89 | pl.seed_everything(1234) 90 | test_dataset = LibriPhrase_Test_Dataset(types='hard') 91 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False) 92 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt") 93 | trainer = Trainer(devices=1, accelerator='gpu', ) # 设置训练器参数 94 | model.eval() 95 | trainer.test(model, test_dataloader) 96 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import ExponentialLR 6 | 7 | import numpy as np 8 | from sklearn.metrics import roc_auc_score 9 | from sklearn.metrics import roc_curve 10 | 11 | import pytorch_lightning as pl 12 | from pytorch_lightning import LightningModule, Trainer 13 | from pytorch_lightning.callbacks import ModelCheckpoint 14 | from dataloaders.libriphrase_train import LibriPhrase_Train_Dataset, train_collate_fn 15 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn 16 | from models import MMKWS 17 | 18 | def compute_eer(label, pred): 19 | fpr, tpr, threshold = roc_curve(label, pred) 20 | fnr = 1 - tpr 21 | 22 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))] 23 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))] 24 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))] 25 | 26 | eer = (eer_1 + eer_2) / 2 27 | return eer 28 | 29 | class EER(nn.Module): 30 | def __init__(self): 31 | super(EER, self).__init__() 32 | self.score = 0.0 33 | self.count = 0.0 34 | 35 | def forward(self, y_true, y_pred): 36 | label_np = y_true.flatten() # Convert to numpy array 37 | pred_np = y_pred.flatten() # Convert to numpy array 38 | 39 | eer_value = compute_eer(label_np, pred_np) 40 | 41 | self.score += eer_value 42 | self.count += 1 43 | 44 | return torch.tensor(self.score / self.count) 45 | 46 | 47 | class MMKWS_Wrapper(LightningModule): 48 | def __init__(self): 49 | super().__init__() 50 | self.model = MMKWS() 51 | self.criterion = nn.BCEWithLogitsLoss() 52 | self.test_preds = [] 53 | self.test_labels = [] 54 | 55 | 56 | def training_step(self, batch, batch_idx): 57 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label, pl, mask_pl, tl, mask_tl = batch 58 | preds, phoneme_preds, text_preds = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 59 | phoneme_preds = phoneme_preds.squeeze(dim=2) 60 | text_preds = text_preds.squeeze(dim=2) 61 | preds = preds.squeeze(dim=1) # Output is [Batch size, 1], but we want [Batch size] 62 | phoneme_loss = self.sequence_bce_loss(phoneme_preds, pl.float(), mask_pl) 63 | text_loss = self.sequence_bce_loss(text_preds, tl.float(), mask_tl) 64 | utt_loss = 
self.criterion(preds, label.float()) 65 | all_loss = utt_loss + phoneme_loss + text_loss 66 | self.log('train/all_loss', all_loss, on_step=True, prog_bar=True) 67 | self.log('train/utt_loss', utt_loss, on_step=True, prog_bar=True) 68 | self.log('train/text_loss', text_loss, on_step=True, prog_bar=True) 69 | self.log('train/phoneme_loss', phoneme_loss, on_step=True, prog_bar=True) 70 | return all_loss 71 | 72 | def validation_step(self, batch, batch_idx): 73 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch 74 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed) 75 | preds = torch.sigmoid(preds) 76 | preds = preds.squeeze(dim=1) 77 | self.test_preds.append(preds) 78 | self.test_labels.append(label) 79 | 80 | 81 | def on_validation_epoch_end(self): 82 | if self.current_epoch > 0: 83 | eer_loss = EER() 84 | all_preds = torch.cat(self.test_preds) 85 | all_labels = torch.cat(self.test_labels) 86 | y_true = all_labels.cpu().detach().numpy() 87 | y_scores = all_preds.cpu().detach().numpy() 88 | # compute AUC 89 | auc = roc_auc_score(y_true, y_scores) 90 | eer = eer_loss(y_true, y_scores) 91 | self.log('test/test_auc', auc) 92 | self.log('test/test_eer', eer) 93 | self.test_preds.clear() 94 | self.test_labels.clear() 95 | 96 | 97 | def configure_optimizers(self): 98 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 99 | return { 100 | "optimizer": optimizer, 101 | "lr_scheduler": { 102 | "scheduler": ExponentialLR(optimizer, gamma=0.95), 103 | "frequency": 1, 104 | "interval": 'epoch', 105 | }, 106 | } 107 | 108 | def sequence_bce_loss(self, preds, labels, mask): 109 | # flatten predictions, labels, and mask into 1-D vectors 110 | preds_flat = preds.view(-1) 111 | labels_flat = labels.view(-1) 112 | mask_flat = mask.view(-1) 113 | # compute binary cross-entropy only at positions where the mask is 1 114 | valid_indices = torch.where(mask_flat == 1)[0] 115 | valid_preds = preds_flat[valid_indices] 116 | valid_labels = labels_flat[valid_indices] 117 | # use PyTorch's built-in binary cross-entropy-with-logits loss 118 | loss = F.binary_cross_entropy_with_logits(valid_preds, valid_labels) 119 | return loss 120 | 121 | 122 | # 3.
Set up the Trainer and run training 123 | if __name__ == "__main__": 124 | pl.seed_everything(2024) 125 | import os 126 | os.environ['CUDA_VISIBLE_DEVICES']='0, 1, 2, 3' 127 | train_dataset = LibriPhrase_Train_Dataset() 128 | train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=train_collate_fn, shuffle=True, num_workers=16, drop_last=True) 129 | test_dataset = LibriPhrase_Test_Dataset(types='hard') 130 | test_dataloader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn, shuffle=False, num_workers=8, drop_last=True) 131 | model = MMKWS_Wrapper() 132 | model_checkpoint = ModelCheckpoint( 133 | dirpath="/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts", 134 | filename='epoch{epoch:02d}', 135 | save_top_k=-1, 136 | ) 137 | logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/', name='MMKWS+') 138 | trainer = Trainer(devices=4, accelerator='gpu', # strategy='ddp_find_unused_parameters_true', 139 | logger=logger, max_epochs=100, callbacks=[model_checkpoint], accumulate_grad_batches=4, precision='16-mixed') # set trainer arguments 140 | trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader) 141 | -------------------------------------------------------------------------------- /wenetphrase_hardneg.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/wenetphrase_hardneg.json.zip --------------------------------------------------------------------------------
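For orientation, below is a minimal smoke-test sketch for the MMKWS model defined in models.py; it is not a script shipped with the repository. The feature dimensions (80-dim fbank frames, 256-dim G2P embeddings, 768-dim text-LM embeddings, 1024-dim audio-LM embeddings) follow the projection layers in MMKWS.__init__, while the batch size, sequence lengths, and random inputs are arbitrary placeholders, and it assumes the repository root (with models.py and the bundled conformer package) is importable.

import torch
from models import MMKWS

model = MMKWS().eval()

# Arbitrary batch size and sequence lengths; only the last feature dimension of each
# tensor is fixed by the model's projection layers.
B, T_audio, T_g2p, T_lm, T_alm = 2, 200, 12, 6, 40
fbank_feature = torch.randn(B, T_audio, 80)             # 80-dim fbank frames fed to the acoustic encoder
lengths = torch.full((B,), T_audio, dtype=torch.long)   # valid frame count per utterance
g2p_embed = torch.randn(B, T_g2p, 256)                  # G2P phoneme embeddings
lm_embed = torch.randn(B, T_lm, 768)                    # text-LM token embeddings
audiolm_embed = torch.randn(B, T_alm, 1024)             # audio-LM frame embeddings

with torch.no_grad():
    utt_pred, phoneme_pred, text_pred = model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)

print(utt_pred.shape)      # [B, 1]        utterance-level matching logit
print(phoneme_pred.shape)  # [B, T_g2p, 1] per-phoneme logits
print(text_pred.shape)     # [B, T_lm, 1]  per-token logits

The three outputs correspond to the utterance-, phoneme-, and text-level heads whose BCE losses are combined in train.py's training_step.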