├── LICENSE
├── README.md
├── asserts
│   ├── libriphrase.png
│   ├── overview.png
│   └── wenetphrase.png
├── conformer
│   └── conformer
│       ├── __init__.py
│       ├── activation.py
│       ├── attention.py
│       ├── convolution.py
│       ├── embedding.py
│       ├── encoder.py
│       ├── feed_forward.py
│       ├── model.py
│       ├── model_def.py
│       └── modules.py
├── data-processing
│   └── wenetspeech
│       ├── README.md
│       ├── cn_tn.py
│       ├── norm_txt.py
│       ├── read.py
│       └── wenetclip.py
├── dataloaders
│   ├── SPC_N0_ALL.py
│   ├── SPC_N0_TARGET.py
│   ├── SPC_N1_ALL.py
│   ├── SPC_N1_TARGET.py
│   ├── __pycache__
│   │   ├── SPC_N0_ALL.cpython-310.pyc
│   │   ├── SPC_N0_TARGET.cpython-310.pyc
│   │   ├── SPC_N1_ALL.cpython-310.pyc
│   │   ├── SPC_N1_TARGET.cpython-310.pyc
│   │   ├── libriphrase_test.cpython-310.pyc
│   │   ├── libriphrase_test_18.cpython-310.pyc
│   │   ├── libriphrase_train.cpython-310.pyc
│   │   ├── libriphrase_trainY.cpython-310.pyc
│   │   └── libriphrase_train_18.cpython-310.pyc
│   ├── g2p
│   │   ├── LICENSE.txt
│   │   ├── g2p_en
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-310.pyc
│   │   │   │   ├── __init__.cpython-39.pyc
│   │   │   │   ├── expand.cpython-310.pyc
│   │   │   │   ├── expand.cpython-39.pyc
│   │   │   │   ├── g2p.cpython-310.pyc
│   │   │   │   └── g2p.cpython-39.pyc
│   │   │   ├── checkpoint20.npz
│   │   │   ├── expand.py
│   │   │   ├── g2p.py
│   │   │   └── homographs.en
│   │   └── lightning_logs
│   │       ├── version_0
│   │       │   ├── events.out.tfevents.1703650022.great-server24.1715931.0
│   │       │   └── hparams.yaml
│   │       ├── version_1
│   │       │   ├── events.out.tfevents.1703650063.great-server24.1718627.0
│   │       │   └── hparams.yaml
│   │       ├── version_2
│   │       │   ├── events.out.tfevents.1703650148.great-server24.1718627.1
│   │       │   └── hparams.yaml
│   │       ├── version_3
│   │       │   ├── events.out.tfevents.1703651214.great-server24.1769388.0
│   │       │   └── hparams.yaml
│   │       ├── version_4
│   │       │   ├── events.out.tfevents.1703651335.great-server24.1769388.1
│   │       │   └── hparams.yaml
│   │       ├── version_5
│   │       │   ├── events.out.tfevents.1703651958.great-server24.1794497.0
│   │       │   └── hparams.yaml
│   │       └── version_6
│   │           ├── events.out.tfevents.1703652071.great-server24.1794497.1
│   │           └── hparams.yaml
│   ├── libriphrase_test.py
│   ├── libriphrase_test_18.py
│   ├── libriphrase_train.py
│   ├── libriphrase_trainY.py
│   ├── libriphrase_trainY2.py
│   ├── libriphrase_train_18.py
│   ├── wenetphrase_test.py
│   └── wenetphrase_train.py
├── g2p
│   ├── LICENSE.txt
│   └── g2p_en
│       ├── __init__.py
│       ├── checkpoint20.npz
│       ├── expand.py
│       ├── g2p.py
│       └── homographs.en
├── libriphrase_hardneg.json.zip
├── mdtc.py
├── mm-kws
│   ├── LICENSE
│   ├── README.md
│   ├── conformer
│   │   └── conformer
│   │       ├── __init__.py
│   │       ├── activation.py
│   │       ├── attention.py
│   │       ├── convolution.py
│   │       ├── embedding.py
│   │       ├── encoder.py
│   │       ├── feed_forward.py
│   │       ├── model.py
│   │       ├── model_def.py
│   │       └── modules.py
│   ├── dataloaders
│   │   ├── SPC_N0_ALL.py
│   │   ├── SPC_N0_TARGET.py
│   │   ├── SPC_N1_ALL.py
│   │   ├── SPC_N1_TARGET.py
│   │   ├── __pycache__
│   │   │   ├── SPC_N0_ALL.cpython-310.pyc
│   │   │   ├── SPC_N0_TARGET.cpython-310.pyc
│   │   │   ├── SPC_N1_ALL.cpython-310.pyc
│   │   │   ├── SPC_N1_TARGET.cpython-310.pyc
│   │   │   ├── libriphrase_test.cpython-310.pyc
│   │   │   ├── libriphrase_test_18.cpython-310.pyc
│   │   │   ├── libriphrase_train.cpython-310.pyc
│   │   │   ├── libriphrase_trainY.cpython-310.pyc
│   │   │   └── libriphrase_train_18.cpython-310.pyc
│   │   ├── g2p
│   │   │   ├── LICENSE.txt
│   │   │   └── g2p_en
│   │   │       ├── __init__.py
│   │   │       ├── __pycache__
│   │   │       │   ├── __init__.cpython-310.pyc
│   │   │       │   ├── __init__.cpython-39.pyc
│   │   │       │   ├── expand.cpython-310.pyc
│   │   │       │   ├── expand.cpython-39.pyc
│   │   │       │   ├── g2p.cpython-310.pyc
│   │   │       │   └── g2p.cpython-39.pyc
│   │   │       ├── checkpoint20.npz
│   │   │       ├── expand.py
│   │   │       ├── g2p.py
│   │   │       └── homographs.en
│   │   ├── libriphrase_test.py
│   │   ├── libriphrase_test_18.py
│   │   ├── libriphrase_train.py
│   │   ├── libriphrase_trainY.py
│   │   ├── libriphrase_trainY2.py
│   │   ├── libriphrase_train_18.py
│   │   ├── wenetphrase_test.py
│   │   └── wenetphrase_train.py
│   ├── g2p
│   │   ├── LICENSE.txt
│   │   └── g2p_en
│   │       ├── __init__.py
│   │       ├── checkpoint20.npz
│   │       ├── expand.py
│   │       ├── g2p.py
│   │       └── homographs.en
│   ├── libriphrase_hardneg.json.zip
│   ├── mdtc.py
│   ├── models.py
│   ├── models_tiny.py
│   ├── test.py
│   ├── train.py
│   └── wenetphrase_hardneg.json.zip
├── models.py
├── models_tiny.py
├── test.py
├── train.py
└── wenetphrase_hardneg.json.zip
/asserts/libriphrase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/libriphrase.png
--------------------------------------------------------------------------------
/asserts/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/overview.png
--------------------------------------------------------------------------------
/asserts/wenetphrase.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/asserts/wenetphrase.png
--------------------------------------------------------------------------------
/conformer/conformer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .model import Conformer
16 |
--------------------------------------------------------------------------------
/conformer/conformer/activation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | from torch import Tensor
17 |
18 |
19 | class Swish(nn.Module):
20 | """
21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
22 | to a variety of challenging domains such as Image classification and Machine translation.
23 | """
24 | def __init__(self):
25 | super(Swish, self).__init__()
26 |
27 | def forward(self, inputs: Tensor) -> Tensor:
28 | return inputs * inputs.sigmoid()
29 |
30 |
31 | class GLU(nn.Module):
32 | """
33 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
34 | in the paper “Language Modeling with Gated Convolutional Networks”
35 | """
36 | def __init__(self, dim: int) -> None:
37 | super(GLU, self).__init__()
38 | self.dim = dim
39 |
40 | def forward(self, inputs: Tensor) -> Tensor:
41 | outputs, gate = inputs.chunk(2, dim=self.dim)
42 | return outputs * gate.sigmoid()
43 |
--------------------------------------------------------------------------------
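A minimal shape check for the two activations above (assuming the repository root is on PYTHONPATH so that `conformer.conformer` is importable):

    import torch
    from conformer.conformer.activation import Swish, GLU

    x = torch.randn(2, 10, 16)       # (batch, time, dim)
    print(Swish()(x).shape)          # torch.Size([2, 10, 16]); x * sigmoid(x), element-wise
    print(GLU(dim=-1)(x).shape)      # torch.Size([2, 10, 8]); chunks the last dim in two and gates it
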
/conformer/conformer/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from torch import Tensor
20 | from typing import Optional
21 |
22 | from .embedding import PositionalEncoding
23 | from .modules import Linear
24 |
25 |
26 | class RelativeMultiHeadAttention(nn.Module):
27 | """
28 | Multi-head attention with relative positional encoding.
29 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
30 |
31 | Args:
32 | d_model (int): The dimension of model
33 | num_heads (int): The number of attention heads.
34 | dropout_p (float): probability of dropout
35 |
36 | Inputs: query, key, value, pos_embedding, mask
37 | - **query** (batch, time, dim): Tensor containing query vector
38 | - **key** (batch, time, dim): Tensor containing key vector
39 | - **value** (batch, time, dim): Tensor containing value vector
40 | - **pos_embedding** (batch, time, dim): Positional embedding tensor
41 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
42 |
43 | Returns:
44 | - **outputs**: Tensor produced by the relative multi-head attention module.
45 | """
46 | def __init__(
47 | self,
48 | d_model: int = 512,
49 | num_heads: int = 16,
50 | dropout_p: float = 0.1,
51 | ):
52 | super(RelativeMultiHeadAttention, self).__init__()
53 | assert d_model % num_heads == 0, "d_model % num_heads should be zero."
54 | self.d_model = d_model
55 | self.d_head = int(d_model / num_heads)
56 | self.num_heads = num_heads
57 | self.sqrt_dim = math.sqrt(d_model)
58 |
59 | self.query_proj = Linear(d_model, d_model)
60 | self.key_proj = Linear(d_model, d_model)
61 | self.value_proj = Linear(d_model, d_model)
62 | self.pos_proj = Linear(d_model, d_model, bias=False)
63 |
64 | self.dropout = nn.Dropout(p=dropout_p)
65 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
66 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
67 | torch.nn.init.xavier_uniform_(self.u_bias)
68 | torch.nn.init.xavier_uniform_(self.v_bias)
69 |
70 | self.out_proj = Linear(d_model, d_model)
71 |
72 | def forward(
73 | self,
74 | query: Tensor,
75 | key: Tensor,
76 | value: Tensor,
77 | pos_embedding: Tensor,
78 | mask: Optional[Tensor] = None,
79 | ) -> Tensor:
80 | batch_size = value.size(0)
81 |
82 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
83 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
84 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
85 | pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
86 |
87 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
88 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
89 | pos_score = self._relative_shift(pos_score)
90 |
91 | score = (content_score + pos_score) / self.sqrt_dim
92 |
93 | if mask is not None:
94 | mask = mask.unsqueeze(1)
95 | score.masked_fill_(mask, -1e9)
96 |
97 | attn = F.softmax(score, -1)
98 | attn = self.dropout(attn)
99 |
100 | context = torch.matmul(attn, value).transpose(1, 2)
101 | context = context.contiguous().view(batch_size, -1, self.d_model)
102 |
103 | return self.out_proj(context)
104 |
105 | def _relative_shift(self, pos_score: Tensor) -> Tensor:
106 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
107 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
108 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
109 |
110 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
111 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
112 |
113 | return pos_score
114 |
115 |
116 | class MultiHeadedSelfAttentionModule(nn.Module):
117 | """
118 | Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
119 | the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
120 | module to generalize better to different input lengths, and the resulting encoder is more robust to variance in
121 | the utterance length. Conformer uses pre-norm residual units with dropout, which helps train
122 | and regularize deeper models.
123 |
124 | Args:
125 | d_model (int): The dimension of model
126 | num_heads (int): The number of attention heads.
127 | dropout_p (float): probability of dropout
128 |
129 | Inputs: inputs, mask
130 | - **inputs** (batch, time, dim): Tensor containing input vector
131 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
132 |
133 | Returns:
134 | - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
135 | """
136 | def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
137 | super(MultiHeadedSelfAttentionModule, self).__init__()
138 | self.positional_encoding = PositionalEncoding(d_model)
139 | self.layer_norm = nn.LayerNorm(d_model)
140 | self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
141 | self.dropout = nn.Dropout(p=dropout_p)
142 |
143 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
144 | batch_size, seq_length, _ = inputs.size()
145 | pos_embedding = self.positional_encoding(seq_length)
146 | pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
147 |
148 | inputs = self.layer_norm(inputs)
149 | outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
150 |
151 | return self.dropout(outputs)
152 |
--------------------------------------------------------------------------------
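A small usage sketch for the self-attention module above (dimensions are illustrative; d_model must be divisible by num_heads):

    import torch
    from conformer.conformer.attention import MultiHeadedSelfAttentionModule

    mhsa = MultiHeadedSelfAttentionModule(d_model=144, num_heads=4, dropout_p=0.1)
    inputs = torch.randn(3, 50, 144)   # (batch, time, dim)
    outputs = mhsa(inputs)             # pre-norm, relative-position attention, then dropout
    print(outputs.shape)               # torch.Size([3, 50, 144])

When a mask is supplied, positions where the mask is True are filled with -1e9 before the softmax, so they receive (near-)zero attention weight.
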
/conformer/conformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .activation import Swish, GLU
21 | from .modules import Transpose
22 |
23 |
24 | class DepthwiseConv1d(nn.Module):
25 | """
26 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
27 | this operation is termed in literature as depthwise convolution.
28 |
29 | Args:
30 | in_channels (int): Number of channels in the input
31 | out_channels (int): Number of channels produced by the convolution
32 | kernel_size (int or tuple): Size of the convolving kernel
33 | stride (int, optional): Stride of the convolution. Default: 1
34 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
35 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True
36 |
37 | Inputs: inputs
38 | - **inputs** (batch, in_channels, time): Tensor containing input vector
39 |
40 | Returns: outputs
41 | - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
42 | """
43 | def __init__(
44 | self,
45 | in_channels: int,
46 | out_channels: int,
47 | kernel_size: int,
48 | stride: int = 1,
49 | padding: int = 0,
50 | bias: bool = False,
51 | ) -> None:
52 | super(DepthwiseConv1d, self).__init__()
53 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
54 | self.conv = nn.Conv1d(
55 | in_channels=in_channels,
56 | out_channels=out_channels,
57 | kernel_size=kernel_size,
58 | groups=in_channels,
59 | stride=stride,
60 | padding=padding,
61 | bias=bias,
62 | )
63 |
64 | def forward(self, inputs: Tensor) -> Tensor:
65 | return self.conv(inputs)
66 |
67 |
68 | class PointwiseConv1d(nn.Module):
69 | """
70 | When kernel_size == 1, conv1d is termed in the literature as pointwise convolution.
71 | This operation is often used to match dimensions.
72 |
73 | Args:
74 | in_channels (int): Number of channels in the input
75 | out_channels (int): Number of channels produced by the convolution
76 | stride (int, optional): Stride of the convolution. Default: 1
77 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
78 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True
79 |
80 | Inputs: inputs
81 | - **inputs** (batch, in_channels, time): Tensor containing input vector
82 |
83 | Returns: outputs
84 | - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
85 | """
86 | def __init__(
87 | self,
88 | in_channels: int,
89 | out_channels: int,
90 | stride: int = 1,
91 | padding: int = 0,
92 | bias: bool = True,
93 | ) -> None:
94 | super(PointwiseConv1d, self).__init__()
95 | self.conv = nn.Conv1d(
96 | in_channels=in_channels,
97 | out_channels=out_channels,
98 | kernel_size=1,
99 | stride=stride,
100 | padding=padding,
101 | bias=bias,
102 | )
103 |
104 | def forward(self, inputs: Tensor) -> Tensor:
105 | return self.conv(inputs)
106 |
107 |
108 | class ConformerConvModule(nn.Module):
109 | """
110 | Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
111 | This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
112 | to aid training deep models.
113 |
114 | Args:
115 | in_channels (int): Number of channels in the input
116 | kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
117 | dropout_p (float, optional): probability of dropout
118 |
119 | Inputs: inputs
120 | inputs (batch, time, dim): Tensor contains input sequences
121 |
122 | Outputs: outputs
123 | outputs (batch, time, dim): Tensor produced by the conformer convolution module.
124 | """
125 | def __init__(
126 | self,
127 | in_channels: int,
128 | kernel_size: int = 31,
129 | expansion_factor: int = 2,
130 | dropout_p: float = 0.1,
131 | ) -> None:
132 | super(ConformerConvModule, self).__init__()
133 | assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
134 | assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
135 |
136 | self.sequential = nn.Sequential(
137 | nn.LayerNorm(in_channels),
138 | Transpose(shape=(1, 2)),
139 | PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
140 | GLU(dim=1),
141 | DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
142 | nn.BatchNorm1d(in_channels),
143 | Swish(),
144 | PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
145 | nn.Dropout(p=dropout_p),
146 | )
147 |
148 | def forward(self, inputs: Tensor) -> Tensor:
149 | return self.sequential(inputs).transpose(1, 2)
150 |
151 |
152 | class Conv2dSubampling(nn.Module):
153 | """
154 | Convolutional 2D subsampling (to 1/4 length)
155 |
156 | Args:
157 | in_channels (int): Number of channels in the input image
158 | out_channels (int): Number of channels produced by the convolution
159 |
160 | Inputs: inputs
161 | - **inputs** (batch, time, dim): Tensor containing sequence of inputs
162 |
163 | Returns: outputs, output_lengths
164 | - **outputs** (batch, time, dim): Tensor produced by the convolution
165 | - **output_lengths** (batch): list of sequence output lengths
166 | """
167 | def __init__(self, in_channels: int, out_channels: int) -> None:
168 | super(Conv2dSubampling, self).__init__()
169 | self.sequential = nn.Sequential(
170 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2),
171 | nn.ReLU(),
172 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2),
173 | nn.ReLU(),
174 | )
175 |
176 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
177 | outputs = self.sequential(inputs.unsqueeze(1))
178 | batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size()
179 |
180 | outputs = outputs.permute(0, 2, 1, 3)
181 | outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim)
182 |
183 | output_lengths = input_lengths >> 2
184 | output_lengths -= 1
185 |
186 | return outputs, output_lengths
187 |
--------------------------------------------------------------------------------
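A shape walk-through for the two blocks above (values are illustrative):

    import torch
    from conformer.conformer.convolution import ConformerConvModule, Conv2dSubampling

    conv = ConformerConvModule(in_channels=144, kernel_size=31)
    x = torch.randn(3, 50, 144)                    # (batch, time, dim)
    print(conv(x).shape)                           # torch.Size([3, 50, 144]); 'SAME' padding keeps the time axis

    subsample = Conv2dSubampling(in_channels=1, out_channels=144)
    feats = torch.randn(3, 400, 80)                # (batch, time, n_mels)
    lengths = torch.LongTensor([400, 380, 350])
    out, out_lengths = subsample(feats, lengths)
    print(out.shape)                               # torch.Size([3, 99, 2736]); time/4, dim = 144 * 19
    print(out_lengths)                             # tensor([99, 94, 86])
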
/conformer/conformer/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import torch.nn as nn
18 | from torch import Tensor
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 | """
23 | Positional Encoding proposed in "Attention Is All You Need".
24 | Since transformer contains no recurrence and no convolution, in order for the model to make
25 | use of the order of the sequence, we must add some positional information.
26 |
27 | "Attention Is All You Need" use sine and cosine functions of different frequencies:
28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model))
29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
30 | """
31 | def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
32 | super(PositionalEncoding, self).__init__()
33 | pe = torch.zeros(max_len, d_model, requires_grad=False)
34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
36 | pe[:, 0::2] = torch.sin(position * div_term)
37 | pe[:, 1::2] = torch.cos(position * div_term)
38 | pe = pe.unsqueeze(0)
39 | self.register_buffer('pe', pe)
40 |
41 | def forward(self, length: int) -> Tensor:
42 | return self.pe[:, :length]
--------------------------------------------------------------------------------
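The table is a fixed (non-trainable) sinusoidal buffer; the forward pass simply slices the first `length` positions:

    from conformer.conformer.embedding import PositionalEncoding

    pos_enc = PositionalEncoding(d_model=144, max_len=10000)
    pe = pos_enc(50)
    print(pe.shape)          # torch.Size([1, 50, 144])
    print(pe.requires_grad)  # False
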
/conformer/conformer/feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 |
19 | from .activation import Swish
20 | from .modules import Linear
21 |
22 |
23 | class FeedForwardModule(nn.Module):
24 | """
25 | The Conformer Feed Forward Module follows pre-norm residual units and applies layer normalization within the residual unit
26 | and on the input before the first linear layer. This module also applies Swish activation and dropout, which helps
27 | regularize the network.
28 |
29 | Args:
30 | encoder_dim (int): Dimension of conformer encoder
31 | expansion_factor (int): Expansion factor of feed forward module.
32 | dropout_p (float): Ratio of dropout
33 |
34 | Inputs: inputs
35 | - **inputs** (batch, time, dim): Tensor contains input sequences
36 |
37 | Outputs: outputs
38 | - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
39 | """
40 | def __init__(
41 | self,
42 | encoder_dim: int = 512,
43 | expansion_factor: int = 4,
44 | dropout_p: float = 0.1,
45 | ) -> None:
46 | super(FeedForwardModule, self).__init__()
47 | self.sequential = nn.Sequential(
48 | nn.LayerNorm(encoder_dim),
49 | Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
50 | Swish(),
51 | nn.Dropout(p=dropout_p),
52 | Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
53 | nn.Dropout(p=dropout_p),
54 | )
55 |
56 | def forward(self, inputs: Tensor) -> Tensor:
57 | return self.sequential(inputs)
58 |
--------------------------------------------------------------------------------
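A quick sketch of the inverted-bottleneck shape (encoder_dim -> encoder_dim * expansion_factor -> encoder_dim):

    import torch
    from conformer.conformer.feed_forward import FeedForwardModule

    ffn = FeedForwardModule(encoder_dim=144, expansion_factor=4, dropout_p=0.1)
    x = torch.randn(2, 50, 144)
    print(ffn(x).shape)   # torch.Size([2, 50, 144]); internally expands to 576 and projects back
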
/conformer/conformer/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .encoder import ConformerEncoder
21 | from .modules import Linear
22 |
23 |
24 | class Conformer(nn.Module):
25 | """
26 | Conformer: Convolution-augmented Transformer for Speech Recognition
27 | The paper used a single-LSTM Transducer decoder; currently only the Conformer encoder
28 | shown in the paper is implemented.
29 |
30 | Args:
31 | num_classes (int): Number of classification classes
32 | input_dim (int, optional): Dimension of input vector
33 | encoder_dim (int, optional): Dimension of conformer encoder
34 | num_encoder_layers (int, optional): Number of conformer blocks
35 | num_attention_heads (int, optional): Number of attention heads
36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
39 | attention_dropout_p (float, optional): Probability of attention module dropout
40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout
41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel
42 | half_step_residual (bool): Flag indicating whether to use the half-step residual or not
43 |
44 | Inputs: inputs, input_lengths
45 | - **inputs** (batch, time, dim): Tensor containing input vector
46 | - **input_lengths** (batch): list of sequence input lengths
47 |
48 | Returns: outputs, output_lengths
49 | - **outputs** (batch, out_channels, time): Tensor produced by the Conformer.
50 | - **output_lengths** (batch): list of sequence output lengths
51 | """
52 | def __init__(
53 | self,
54 | num_classes: int,
55 | input_dim: int = 80,
56 | encoder_dim: int = 512,
57 | num_encoder_layers: int = 17,
58 | num_attention_heads: int = 8,
59 | feed_forward_expansion_factor: int = 4,
60 | conv_expansion_factor: int = 2,
61 | input_dropout_p: float = 0.1,
62 | feed_forward_dropout_p: float = 0.1,
63 | attention_dropout_p: float = 0.1,
64 | conv_dropout_p: float = 0.1,
65 | conv_kernel_size: int = 31,
66 | half_step_residual: bool = True,
67 | ) -> None:
68 | super(Conformer, self).__init__()
69 | self.encoder = ConformerEncoder(
70 | input_dim=input_dim,
71 | encoder_dim=encoder_dim,
72 | num_layers=num_encoder_layers,
73 | num_attention_heads=num_attention_heads,
74 | feed_forward_expansion_factor=feed_forward_expansion_factor,
75 | conv_expansion_factor=conv_expansion_factor,
76 | input_dropout_p=input_dropout_p,
77 | feed_forward_dropout_p=feed_forward_dropout_p,
78 | attention_dropout_p=attention_dropout_p,
79 | conv_dropout_p=conv_dropout_p,
80 | conv_kernel_size=conv_kernel_size,
81 | half_step_residual=half_step_residual,
82 | )
83 | self.fc = Linear(encoder_dim, num_classes, bias=False)
84 |
85 | def count_parameters(self) -> int:
86 | """ Count parameters of encoder """
87 | return self.encoder.count_parameters()
88 |
89 | def update_dropout(self, dropout_p) -> None:
90 | """ Update dropout probability of model """
91 | self.encoder.update_dropout(dropout_p)
92 |
93 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
94 | """
95 | Forward propagate an `inputs` and `targets` pair for training.
96 |
97 | Args:
98 | inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
99 | `FloatTensor` of size ``(batch, seq_length, dimension)``.
100 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
101 |
102 | Returns:
103 | * predictions (torch.FloatTensor): Result of model predictions.
104 | """
105 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
106 | outputs = self.fc(encoder_outputs)
107 | outputs = nn.functional.log_softmax(outputs, dim=-1)
108 | return outputs, encoder_output_lengths
109 |
--------------------------------------------------------------------------------
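A usage sketch for the full model. Note that the encoder it wraps lives in encoder.py, which is listed in the tree but not reproduced here; the call below assumes it follows the upstream sooftware/conformer implementation, and the hyperparameters are illustrative:

    import torch
    from conformer.conformer.model import Conformer

    model = Conformer(
        num_classes=29,        # e.g. characters + blank for a CTC head
        input_dim=80,
        encoder_dim=144,
        num_encoder_layers=4,
        num_attention_heads=4,
    )
    feats = torch.randn(2, 400, 80)                 # (batch, time, n_mels)
    lengths = torch.LongTensor([400, 360])
    log_probs, out_lengths = model(feats, lengths)  # log-softmax over classes, time reduced ~4x
    print(log_probs.shape, out_lengths)
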
/conformer/conformer/model_def.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .encoder import ConformerEncoder
21 | from .modules import Linear
22 |
23 |
24 | class Conformer(nn.Module):
25 | """
26 | Conformer: Convolution-augmented Transformer for Speech Recognition
27 | The paper used a single-LSTM Transducer decoder; currently only the Conformer encoder
28 | shown in the paper is implemented.
29 |
30 | Args:
31 | num_classes (int): Number of classification classes
32 | input_dim (int, optional): Dimension of input vector
33 | encoder_dim (int, optional): Dimension of conformer encoder
34 | num_encoder_layers (int, optional): Number of conformer blocks
35 | num_attention_heads (int, optional): Number of attention heads
36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
39 | attention_dropout_p (float, optional): Probability of attention module dropout
40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout
41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel
42 | half_step_residual (bool): Flag indicating whether to use the half-step residual or not
43 |
44 | Inputs: inputs, input_lengths
45 | - **inputs** (batch, time, dim): Tensor containing input vector
46 | - **input_lengths** (batch): list of sequence input lengths
47 |
48 | Returns: outputs, output_lengths
49 | - **outputs** (batch, out_channels, time): Tensor produced by the Conformer.
50 | - **output_lengths** (batch): list of sequence output lengths
51 | """
52 | def __init__(
53 | self,
54 | input_dim: int = 80,
55 | encoder_dim: int = 512,
56 | num_encoder_layers: int = 17,
57 | num_attention_heads: int = 8,
58 | feed_forward_expansion_factor: int = 4,
59 | conv_expansion_factor: int = 2,
60 | input_dropout_p: float = 0.1,
61 | feed_forward_dropout_p: float = 0.1,
62 | attention_dropout_p: float = 0.1,
63 | conv_dropout_p: float = 0.1,
64 | conv_kernel_size: int = 31,
65 | half_step_residual: bool = True,
66 | ) -> None:
67 | super(Conformer, self).__init__()
68 | self.encoder = ConformerEncoder(
69 | input_dim=input_dim,
70 | encoder_dim=encoder_dim,
71 | num_layers=num_encoder_layers,
72 | num_attention_heads=num_attention_heads,
73 | feed_forward_expansion_factor=feed_forward_expansion_factor,
74 | conv_expansion_factor=conv_expansion_factor,
75 | input_dropout_p=input_dropout_p,
76 | feed_forward_dropout_p=feed_forward_dropout_p,
77 | attention_dropout_p=attention_dropout_p,
78 | conv_dropout_p=conv_dropout_p,
79 | conv_kernel_size=conv_kernel_size,
80 | half_step_residual=half_step_residual,
81 | )
82 |
83 | def count_parameters(self) -> int:
84 | """ Count parameters of encoder """
85 | return self.encoder.count_parameters()
86 |
87 | def update_dropout(self, dropout_p) -> None:
88 | """ Update dropout probability of model """
89 | self.encoder.update_dropout(dropout_p)
90 |
91 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
92 | """
93 | Forward propagate an `inputs` and `targets` pair for training.
94 |
95 | Args:
96 | inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
97 | `FloatTensor` of size ``(batch, seq_length, dimension)``.
98 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
99 |
100 | Returns:
101 | * predictions (torch.FloatTensor): Result of model predictions.
102 | """
103 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
104 | return encoder_outputs, encoder_output_lengths
105 |
--------------------------------------------------------------------------------
/conformer/conformer/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.init as init
18 | from torch import Tensor
19 |
20 |
21 | class ResidualConnectionModule(nn.Module):
22 | """
23 | Residual Connection Module.
24 | outputs = (module(inputs) x module_factor + inputs x input_factor)
25 | """
26 | def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
27 | super(ResidualConnectionModule, self).__init__()
28 | self.module = module
29 | self.module_factor = module_factor
30 | self.input_factor = input_factor
31 |
32 | def forward(self, inputs: Tensor) -> Tensor:
33 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
34 |
35 |
36 | class Linear(nn.Module):
37 | """
38 | Wrapper class of torch.nn.Linear
39 | Weight initialize by xavier initialization and bias initialize to zeros.
40 | """
41 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
42 | super(Linear, self).__init__()
43 | self.linear = nn.Linear(in_features, out_features, bias=bias)
44 | init.xavier_uniform_(self.linear.weight)
45 | if bias:
46 | init.zeros_(self.linear.bias)
47 |
48 | def forward(self, x: Tensor) -> Tensor:
49 | return self.linear(x)
50 |
51 |
52 | class View(nn.Module):
53 | """ Wrapper class of torch.view() for Sequential module. """
54 | def __init__(self, shape: tuple, contiguous: bool = False):
55 | super(View, self).__init__()
56 | self.shape = shape
57 | self.contiguous = contiguous
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | if self.contiguous:
61 | x = x.contiguous()
62 |
63 | return x.view(*self.shape)
64 |
65 |
66 | class Transpose(nn.Module):
67 | """ Wrapper class of torch.transpose() for Sequential module. """
68 | def __init__(self, shape: tuple):
69 | super(Transpose, self).__init__()
70 | self.shape = shape
71 |
72 | def forward(self, x: Tensor) -> Tensor:
73 | return x.transpose(*self.shape)
74 |
--------------------------------------------------------------------------------
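ResidualConnectionModule is what assembles the sub-modules above into Conformer blocks; the half-step feed-forward residual from the paper, for example, is simply module_factor=0.5:

    import torch
    from conformer.conformer.modules import ResidualConnectionModule
    from conformer.conformer.feed_forward import FeedForwardModule

    # out = 0.5 * FFN(x) + x, the half-step residual used around the Conformer feed-forward modules
    half_step_ffn = ResidualConnectionModule(
        module=FeedForwardModule(encoder_dim=144),
        module_factor=0.5,
    )
    x = torch.randn(2, 50, 144)
    print(half_step_ffn(x).shape)   # torch.Size([2, 50, 144])
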
/data-processing/wenetspeech/README.md:
--------------------------------------------------------------------------------
1 | /Path/to/your/wenetspeech
--------------------------------------------------------------------------------
/data-processing/wenetspeech/norm_txt.py:
--------------------------------------------------------------------------------
1 | import os
2 | from multiprocessing import Pool
3 | from tqdm import tqdm
4 | import json
5 | import os
6 |
7 | def find_files_with_suffix(root_folder, endwiths):
8 | matching_files = [] # collect the paths of files whose names end with one of the given suffixes
9 | for root, dirs, files in os.walk(root_folder):
10 | for file in files:
11 | for suffix in endwiths:
12 | if file.endswith(suffix):
13 | matching_files.append(os.path.join(root, file))
14 | return matching_files
15 |
16 | import os
17 | import multiprocessing
18 | from tqdm import tqdm
19 | import subprocess
20 | def process_file(file):
21 | output_file = file.replace('.txt', '_norm.txt')
22 | subprocess.run(['python', 'data-processing/wenetspeech/cn_tn.py', file, output_file])
23 |
24 | root_folder = '/path/to/wenetspeech_clips/'
25 | endwiths = ['.txt']
26 | files = find_files_with_suffix(root_folder, endwiths)
27 | with multiprocessing.Pool(32) as pool:
28 | # iterate over the file list in parallel and show a progress bar with tqdm
29 | for _ in tqdm(pool.imap_unordered(process_file, files), total=len(files)):
30 | pass
--------------------------------------------------------------------------------
/data-processing/wenetspeech/read.py:
--------------------------------------------------------------------------------
1 | import os
2 | from multiprocessing import Pool
3 | from tqdm import tqdm
4 | import torchaudio
5 | import json
6 |
7 | def process_audio(aidx):
8 | path = os.path.join(data_dir, wenetspeech['audios'][aidx]['path'][:-5] + '.wav')
9 | waveform, sr = torchaudio.load(path)
10 |
11 | for seg in range(len(wenetspeech['audios'][aidx]['segments'])):
12 | sid = wenetspeech['audios'][aidx]['segments'][seg]['sid']
13 | begin_time = wenetspeech['audios'][aidx]['segments'][seg]['begin_time']
14 | end_time = wenetspeech['audios'][aidx]['segments'][seg]['end_time']
15 | subsets = wenetspeech['audios'][aidx]['segments'][seg]['subsets']
16 | text = wenetspeech['audios'][aidx]['segments'][seg]['text']
17 |
18 | if 'M' in subsets and 'S' not in subsets:
19 | save_path = os.path.join("/server24/aizq/wenetspeech_clips/M_S", path.split('/')[-3], path.split('/')[-2], sid) + '.wav'
20 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
21 |
22 | if not os.path.exists(save_path):
23 | torchaudio.save(save_path, waveform[:, int(begin_time * sr):int(end_time * sr)], sample_rate=sr)
24 |
25 | save_path_label = os.path.join("/server24/aizq/wenetspeech_clips/M_S", path.split('/')[-3], path.split('/')[-2], sid) + '.txt'
26 |
27 | if not os.path.exists(save_path_label):
28 | with open(save_path_label, "w") as file:
29 | file.write(text)
30 |
31 | if 'S' in subsets:
32 | save_path = os.path.join("/server24/aizq/wenetspeech_clips/S", path.split('/')[-3], path.split('/')[-2], sid) + '.wav'
33 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
34 |
35 | if not os.path.exists(save_path):
36 | torchaudio.save(save_path, waveform[:, int(begin_time * sr):int(end_time * sr)], sample_rate=sr)
37 |
38 | save_path_label = os.path.join("/server24/aizq/wenetspeech_clips/S", path.split('/')[-3], path.split('/')[-2], sid) + '.txt'
39 |
40 | if not os.path.exists(save_path_label):
41 | with open(save_path_label, "w") as file:
42 | file.write(text)
43 |
44 |
45 | if __name__ == '__main__':
46 | print("Loading WenetSpeech.json")
47 | with open('/server24/aizq/wenetspeech_UNTAR/WenetSpeech.json', 'r') as f: wenetspeech = json.load(f)
48 | print("Loading finished, processing")
49 | data_dir = "/server24/aizq/wenetspeech_UNTAR"
50 | num_processes = 32
51 | with Pool(num_processes) as pool:
52 | for _ in tqdm(pool.imap_unordered(process_audio, range(len(wenetspeech['audios']))), total=len(wenetspeech['audios'])):
53 | pass
54 |
--------------------------------------------------------------------------------
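For reference, the layout of WenetSpeech.json that read.py expects looks roughly like the sketch below. The key names are taken from the code above; the concrete values are made up, and the '.opus' extension is an assumption inferred from the `path[:-5] + '.wav'` slicing:

    wenetspeech_example = {
        "audios": [
            {
                "path": "audio/train/youtube/B00000/X00000000.opus",  # decoded alongside as X00000000.wav
                "segments": [
                    {
                        "sid": "X00000000_S00000",
                        "begin_time": 12.3,            # seconds
                        "end_time": 15.8,
                        "subsets": ["L", "M", "S"],    # subset tags checked by process_audio
                        "text": "...",                 # transcript written next to the clipped wav
                    },
                ],
            },
        ],
    }
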
/data-processing/wenetspeech/wenetclip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing
3 | from tqdm import tqdm
4 |
5 |
6 | def find_files_with_suffix(root_folder, endwiths):
7 | matching_files = [] # collect the paths of files whose names end with one of the given suffixes
8 | for root, dirs, files in os.walk(root_folder):
9 | for file in files:
10 | for suffix in endwiths:
11 | if file.endswith(suffix):
12 | matching_files.append(os.path.join(root, file))
13 | return matching_files
14 |
15 |
16 | def read_text_file(file_path):
17 | with open(file_path, 'r') as file:
18 | content = file.read().rstrip('\n')
19 | return content
20 |
21 |
22 | # build the set of candidate keywords: items made up only of Chinese characters, 2 to 6 characters long
23 | def select_hanzi_greater_than_or_equal_to_2(input_set):
24 | result_set = set()
25 | for item in input_set:
26 | # keep the item only if it contains nothing but Chinese characters and has 2 to 6 of them
27 | if all('\u4e00' <= char <= '\u9fff' for char in item) and len(item) >= 2 and len(item) <= 6:
28 | result_set.add(item)
29 | return result_set
30 |
31 |
32 | import re
33 |
34 | def find_all_occurrences(text, pattern):
35 | occurrences = [(match.start(), match.end()) for match in re.finditer(pattern, text)]
36 | return occurrences
37 |
38 |
39 | import torch
40 | import torchaudio
41 | from torchaudio.pipelines import MMS_FA as bundle
42 | from typing import List
43 | from g2pM import G2pM # g2pM; g2pW reports better results in its paper, but it comes from Taiwan and still needs quite a bit of extra handling for Mainland Mandarin
44 | import jieba
45 | import numpy as np
46 |
47 | def _score(spans):
48 | return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
49 |
50 |
51 |
52 | class TORCHAUDIO_MFA:
53 | def __init__(self, device_id='0', save_dirs="/server24/aizq/mm_kws/datasets/WenetPhrase/M_S"):
54 | self.device = torch.device(f"cuda:{device_id}" if torch.cuda.is_available() else "cpu")
55 | self.mfa = bundle.get_model()
56 | self.mfa.to(self.device)
57 | self.tokenizer = bundle.get_tokenizer()
58 | self.aligner = bundle.get_aligner()
59 | self.g2pm = G2pM()
60 | self.save_dirs = save_dirs
61 |
62 |
63 | def compute_alignments(self, waveform: torch.Tensor, transcript: List[str]):
64 | with torch.inference_mode():
65 | emission, _ = self.mfa(waveform.to(self.device))
66 | token_spans = self.aligner(emission[0], self.tokenizer(transcript))
67 | return emission, token_spans
68 |
69 |
70 | def make_mfa(self, wav_file: str, sentence: str):
71 | try:
72 | seg_list = select_hanzi_greater_than_or_equal_to_2(list(jieba.cut(sentence, cut_all=True)))
73 | transcript = self.g2pm(sentence, tone=False)
74 | transcript = " ".join(transcript)
75 | # normalize the pinyin: replace 'u:' (ü) with 'v'
76 | transcript = transcript.replace('u:', 'v')
77 | transcript = transcript.split()
78 | waveform, sample_rate = torchaudio.load(wav_file)
79 | waveform = waveform[0:1]
80 | emission, token_spans = self.compute_alignments(waveform, transcript)
81 | num_frames = emission.size(1)
82 | ratio = waveform.size(1) / num_frames
83 | for word in seg_list:
84 | result = find_all_occurrences(sentence, word)
85 | for i, (s, e) in enumerate(result):
86 | start = s
87 | end = e - 1
88 | x0 = int(ratio * token_spans[start][0].start)
89 | x1 = int(ratio * token_spans[end][-1].end)
90 | score = np.mean([_score(token_spans[i]) for i in range(start, end+1)])
91 | if score > 0.3: # average alignment-score threshold
92 | save_path = os.path.join(self.save_dirs, word) + "/" + "_".join([wav_file.split('/')[-3], wav_file.split('/')[-2], wav_file.split('/')[-1][:-4], str(i)]) + '.wav'
93 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
94 | torchaudio.save(save_path, waveform[:, x0:x1], sample_rate=sample_rate)
95 | except Exception as e:
96 | print(wav_file)
97 | print(e)
98 |
99 |
100 |
101 | def data_process(
102 | sub_files,
103 | device_id
104 | ):
105 | try:
106 | os.environ['CUDA_VISIBLE_DEVICES'] = f'{device_id}'
107 | model = TORCHAUDIO_MFA(device_id='0') # each worker is pinned to a different GPU via CUDA_VISIBLE_DEVICES above, so index 0 refers to that GPU
108 | for wav_file in tqdm(sub_files):
109 | txt_file = wav_file.replace('.wav', '_norm.txt')
110 | sentence = read_text_file(txt_file)
111 | model.make_mfa(wav_file, sentence)
112 | del model
113 | except Exception as e:
114 | print(e)
115 |
116 |
117 |
118 | def multiprocess_to_(
119 | todo_list,
120 | num_gpu=[0], # Default to GPU 0 if num_gpu is not provided
121 | num_process=3,
122 | ):
123 | num_available_gpus = len(num_gpu)
124 | with multiprocessing.Pool(processes=num_process) as pool:
125 | for i in range(num_process):
126 | sub_files = todo_list[i::num_process]
127 | device_id = num_gpu[i % num_available_gpus]
128 | pool.apply_async(data_process, args=(sub_files, device_id))
129 | pool.close()
130 | pool.join()
131 |
132 |
133 | if __name__ == "__main__":
134 | target_dirs = "/server24/aizq/wenetspeech_clips/M_S" # S: 151600 & M_S: 1362900
135 | save_dirs = "/server24/aizq/mm_kws/datasets/WenetPhrase/M_S"
136 | todo_list = find_files_with_suffix(target_dirs, ['.wav']) # the suffixes argument is iterated, so pass a list rather than a bare string
137 | print(len(todo_list))
138 | num_gpu=[0, 1, 2, 3, 4, 5]
139 | multiprocess_to_(
140 | todo_list,
141 | num_gpu=num_gpu,
142 | num_process=len(num_gpu) * 5,
143 | )
--------------------------------------------------------------------------------
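A minimal sketch of driving the clipper for a single utterance (the paths and the sentence are placeholders; it assumes the script's dependencies — torchaudio with the MMS_FA pipeline, g2pM and jieba — are installed and that data-processing/wenetspeech is on PYTHONPATH):

    from wenetclip import TORCHAUDIO_MFA

    clipper = TORCHAUDIO_MFA(device_id='0', save_dirs="/path/to/WenetPhrase/M_S")
    # jieba proposes 2-6 character keywords, g2pM converts the sentence to pinyin,
    # the MMS forced aligner locates each keyword, and one wav is saved per occurrence.
    clipper.make_mfa("/path/to/clip.wav", "今天天气真不错")
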
/dataloaders/SPC_N0_ALL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import Levenshtein
15 | import re
16 | import random
17 |
18 | import os
19 |
20 | def get_files(path, endswith):
21 | _files = []
22 | for root, dirs, files in os.walk(path):
23 | for file in files:
24 | if file.endswith(endswith):
25 | _files.append(os.path.join(root, file))
26 | return _files
27 |
28 |
29 | class SPC_NO_Dataset(Dataset):
30 | def __init__(
31 | self,
32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_ALL.csv",
35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
36 | preprocess=True,
37 | ):
38 | super().__init__()
39 | if preprocess:
40 | target_dict = {}
41 | idx = 0
42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label'])
43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
44 | classes = os.listdir(test_dir)
45 | random.shuffle(classes)
46 | Targets = classes[:10]
47 | for wav_idx in range(len(wav_id)):
48 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
49 | query_text = wav_id[wav_idx].split('_')[0]
50 | if query_text in Targets:
51 | for comparison_text in Targets:
52 | _label = 1 if comparison_text == query_text else 0
53 | target_dict[idx] = {
54 | 'id': wav_id[wav_idx],
55 | 'Query_text': query_text,
56 | 'Query_wav': wav,
57 | 'Support_text': comparison_text,
58 | 'Query_label': Targets.index(query_text),
59 | 'Support_label': Targets.index(comparison_text),
60 | 'label': _label
61 | }
62 | idx += 1
63 | else:
64 | for comparison_text in Targets:
65 | _label = 0
66 | target_dict[idx] = {
67 | 'id': wav_id[wav_idx],
68 | 'Query_text': query_text,
69 | 'Query_wav': wav,
70 | 'Support_text': comparison_text,
71 | 'Query_label': 10,
72 | 'Support_label': Targets.index(comparison_text),
73 | 'label': 10,
74 | }
75 | idx += 1
76 |
77 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
78 | self.data.to_csv(save_path, index=False)
79 | else:
80 | self.data = pd.read_csv(save_path)
81 |
82 | self.data = self.data.values.tolist()
83 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
84 |
85 |
86 | def __getitem__(
87 | self,
88 | index
89 | ):
90 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index]
91 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
92 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
93 | lm_embed = self.text_embedder[Support_text]['lm_embed']
94 | fbank_feature = fbank(
95 | Query_wav,
96 | num_mel_bins=80
97 | )
98 | g2p_embed = torch.from_numpy(g2p_embed)
99 | g2p_embed = g2p_embed.type_as(fbank_feature)
100 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
101 |
102 | def __len__(
103 | self
104 | ):
105 | return len(self.data)
106 |
107 |
108 |
109 |
110 |
111 | def collate_fn(batch):
112 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch)
113 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
114 | lengths = [len(seq) for seq in padded_fbank_feature]
115 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
116 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
117 | label_tensor = torch.tensor(label)
118 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor
119 |
120 |
121 |
122 | spc = SPC_NO_Dataset()
123 | spc[0]
--------------------------------------------------------------------------------
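A sketch of wiring this dataset and collate_fn into an evaluation DataLoader. The batch size is illustrative, the constructor's default paths above are machine-specific, and importing the module also runs the `spc = SPC_NO_Dataset()` smoke test at the bottom of the file:

    from torch.utils.data import DataLoader
    from dataloaders.SPC_N0_ALL import SPC_NO_Dataset, collate_fn

    # preprocess=False reuses the CSV written by an earlier preprocess=True run
    dataset = SPC_NO_Dataset(preprocess=False)
    loader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

    ids, fbank, lengths, g2p_embed, lm_embed, q_label, s_label, label = next(iter(loader))
    print(fbank.shape)   # (batch, max_time, 80) padded log-mel filterbank features
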
/dataloaders/SPC_N0_TARGET.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 |
13 | import torch.nn.functional as F
14 | from torch.nn.utils.rnn import pad_sequence
15 | import Levenshtein
16 | import re
17 | import random
18 |
19 | import os
20 |
21 | def get_files(path, endswith):
22 | _files = []
23 | for root, dirs, files in os.walk(path):
24 | for file in files:
25 | if file.endswith(endswith):
26 | _files.append(os.path.join(root, file))
27 | return _files
28 |
29 |
30 | class SPC_NO_Dataset(Dataset):
31 | def __init__(
32 | self,
33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_TARGET.csv",
36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
37 | preprocess=True,
38 | ):
39 | super().__init__()
40 | if preprocess:
41 | target_dict = {}
42 | idx = 0
43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label'])
44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
45 | classes = os.listdir(test_dir)
46 | random.shuffle(classes)
47 | Targets = classes[:10]
48 | for wav_idx in range(len(wav_id)):
49 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
50 | query_text = wav_id[wav_idx].split('_')[0]
51 | if query_text in Targets:
52 | for comparison_text in Targets:
53 | _label = 1 if comparison_text == query_text else 0
54 | target_dict[idx] = {
55 | 'id': wav_id[wav_idx],
56 | 'Query_text': query_text,
57 | 'Query_wav': wav,
58 | 'Support_text': comparison_text,
59 | 'Query_label': Targets.index(query_text),
60 | 'Support_label': Targets.index(comparison_text),
61 | 'label': _label
62 | }
63 | idx += 1
64 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
65 | self.data.to_csv(save_path, index=False)
66 | else:
67 | self.data = pd.read_csv(save_path)
68 |
69 | self.data = self.data.values.tolist()
70 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
71 |
72 |
73 | def __getitem__(
74 | self,
75 | index
76 | ):
77 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index]
78 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
79 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
80 | lm_embed = self.text_embedder[Support_text]['lm_embed']
81 | fbank_feature = fbank(
82 | Query_wav,
83 | num_mel_bins=80
84 | )
85 | g2p_embed = torch.from_numpy(g2p_embed)
86 | g2p_embed = g2p_embed.type_as(fbank_feature)
87 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
88 |
89 | def __len__(
90 | self
91 | ):
92 | return len(self.data)
93 |
94 |
95 |
96 |
97 |
98 | def collate_fn(batch):
99 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch)
100 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
101 | lengths = [len(seq) for seq in padded_fbank_feature]
102 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
103 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
104 | label_tensor = torch.tensor(label)
105 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor
106 |
107 |
108 |
109 | # spc = SPC_NO_Dataset()
110 |
--------------------------------------------------------------------------------
/dataloaders/SPC_N1_ALL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | # import dataaug
13 | import torch.nn.functional as F
14 | from torch.nn.utils.rnn import pad_sequence
15 | import Levenshtein
16 | import re
17 | import random
18 |
19 | import os
20 |
21 | def get_files(path, endswith):
22 | _files = []
23 | for root, dirs, files in os.walk(path):
24 | for file in files:
25 | if file.endswith(endswith):
26 | _files.append(os.path.join(root, file))
27 | return _files
28 | from tqdm import tqdm
29 |
30 | class SPC_N1_Dataset(Dataset):
31 | def __init__(
32 | self,
33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_ALL.csv",
36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
37 | preprocess=True,
38 | ):
39 | super().__init__()
40 | if preprocess:
41 | target_dict = {}
42 | idx = 0
43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label'])
44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
45 | classes = os.listdir(test_dir)
46 | random.shuffle(classes)
47 | Targets = classes[:10]
48 | supports_wavs = {}
49 | for comparison_text in Targets:
50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy')
51 | random.shuffle(supports_wav)
52 | supports_wavs[comparison_text] = supports_wav
53 | for wav_idx in range(len(wav_id)):
54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
55 | query_text = wav_id[wav_idx].split('_')[0]
56 | if query_text in Targets:
57 | for comparison_text in Targets:
58 | support_wav = random.choices(supports_wavs[comparison_text], k=2)
59 | _label = 1 if comparison_text == query_text else 0
60 | target_dict[idx] = {
61 | 'id': wav_id[wav_idx],
62 | 'Query_text': query_text,
63 | 'Query_wav': wav,
64 | 'Support_text': comparison_text,
65 | 'Support_wav': support_wav[0],
66 | 'Query_label': Targets.index(query_text),
67 | 'Support_label': Targets.index(comparison_text),
68 | 'label': _label
69 | }
70 | idx += 1
71 | else:
72 | for comparison_text in Targets:
73 | support_wav = random.choices(supports_wavs[comparison_text], k=5)
74 | _label = 0
75 | target_dict[idx] = {
76 | 'id': wav_id[wav_idx],
77 | 'Query_text': query_text,
78 | 'Query_wav': wav,
79 | 'Support_text': comparison_text,
80 | 'Support_wav': support_wav[0],
81 | 'Query_label': 10,
82 | 'Support_label': Targets.index(comparison_text),
83 | 'label': 10,
84 | }
85 | idx += 1
86 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
87 | self.data.to_csv(save_path, index=False)
88 | else:
89 | self.data = pd.read_csv(save_path)
90 |
91 | self.data = self.data.values.tolist()
92 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
93 |
94 |
95 | def __getitem__(
96 | self,
97 | index
98 | ):
99 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index]
100 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
101 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
102 | lm_embed = self.text_embedder[Support_text]['lm_embed']
103 | audiolm_embed = np.load(Support_wav)
104 | fbank_feature = fbank(
105 | Query_wav,
106 | num_mel_bins=80
107 | )
108 | g2p_embed = torch.from_numpy(g2p_embed)
109 | g2p_embed = g2p_embed.type_as(fbank_feature)
110 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
111 |
112 | def __len__(
113 | self
114 | ):
115 | return len(self.data)
116 |
117 |
118 |
119 |
120 |
121 | def collate_fn(batch):
122 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch)
123 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
124 | lengths = [len(seq) for seq in fbank_feature]  # per-sample frame counts (taken before padding)
125 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
126 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
127 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
128 | label_tensor = torch.tensor(label)
129 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor
130 |
131 |
132 |
133 | spc = SPC_N1_Dataset()
134 | # spc[0]
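135 | 
136 | # Illustrative sketch, not part of the original file: wiring the dataset into a
137 | # DataLoader with the padding collate_fn above (batch size / workers are placeholders).
138 | # loader = DataLoader(spc, batch_size=16, collate_fn=collate_fn, num_workers=4, shuffle=False)
139 | # for ids, fbanks, lengths, g2p, lm, audiolm, q_lab, s_lab, labels in loader:
140 | #     # fbanks: (batch, max_frames, 80); lengths: frame count of each query before padding
141 | #     print(fbanks.shape, labels)
142 | #     break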
--------------------------------------------------------------------------------
/dataloaders/SPC_N1_TARGET.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import Levenshtein
15 | import re
16 | import random
17 |
18 | import os
19 |
20 | def get_files(path, endswith):
21 | _files = []
22 | for root, dirs, files in os.walk(path):
23 | for file in files:
24 | if file.endswith(endswith):
25 | _files.append(os.path.join(root, file))
26 | return _files
27 | from tqdm import tqdm
28 |
29 | class SPC_N1_Dataset(Dataset):
30 | def __init__(
31 | self,
32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_TARGET.csv",
35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
36 | preprocess=True,
37 | ):
38 | super().__init__()
39 | if preprocess:
40 | target_dict = {}
41 | idx = 0
42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label'])
43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
44 | classes = os.listdir(test_dir)
45 | random.shuffle(classes)
46 | Targets = classes[:10]
47 | # Targets = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
48 | supports_wavs = {}
49 | for comparison_text in Targets:
50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy')
51 | random.shuffle(supports_wav)
52 | supports_wavs[comparison_text] = supports_wav
53 | for wav_idx in range(len(wav_id)):
54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
55 | query_text = wav_id[wav_idx].split('_')[0]
56 | if query_text in Targets:
57 | for comparison_text in Targets:
58 | _label = 1 if comparison_text == query_text else 0
59 | target_dict[idx] = {
60 | 'id': wav_id[wav_idx],
61 | 'Query_text': query_text,
62 | 'Query_wav': wav,
63 | 'Support_text': comparison_text,
64 | 'Support_wav': supports_wavs[comparison_text][0],
65 | 'Query_label': Targets.index(query_text),
66 | 'Support_label': Targets.index(comparison_text),
67 | 'label': _label
68 | }
69 | idx += 1
70 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
71 | self.data.to_csv(save_path, index=False)
72 | else:
73 | self.data = pd.read_csv(save_path)
74 |
75 | self.data = self.data.values.tolist()
76 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
77 |
78 |
79 | def __getitem__(
80 | self,
81 | index
82 | ):
83 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index]
84 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
85 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
86 | lm_embed = self.text_embedder[Support_text]['lm_embed']
87 | audiolm_embed = np.load(Support_wav)
88 | fbank_feature = fbank(
89 | Query_wav,
90 | num_mel_bins=80
91 | )
92 | g2p_embed = torch.from_numpy(g2p_embed)
93 | g2p_embed = g2p_embed.type_as(fbank_feature)
94 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
95 |
96 | def __len__(
97 | self
98 | ):
99 | return len(self.data)
100 |
101 |
102 |
103 |
104 |
105 | def collate_fn(batch):
106 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch)
107 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
108 | lengths = [len(seq) for seq in fbank_feature]  # per-sample frame counts (taken before padding)
109 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
110 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
111 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
112 | label_tensor = torch.tensor(label)
113 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor
114 |
115 |
116 |
117 | spc = SPC_N1_Dataset()
118 | # spc[0]
--------------------------------------------------------------------------------
/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__init__.py:
--------------------------------------------------------------------------------
1 | from .g2p import G2p
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/checkpoint20.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/g2p_en/checkpoint20.npz
--------------------------------------------------------------------------------
/dataloaders/g2p/g2p_en/expand.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | Borrowed
5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py
6 | By kyubyong park. kbpark.linguist@gmail.com.
7 | https://www.github.com/kyubyong/g2p
8 | '''
9 | from __future__ import print_function
10 | import inflect
11 | import re
12 |
13 |
14 |
15 | _inflect = inflect.engine()
16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
21 | _number_re = re.compile(r'[0-9]+')
22 |
23 |
24 | def _remove_commas(m):
25 | return m.group(1).replace(',', '')
26 |
27 |
28 | def _expand_decimal_point(m):
29 | return m.group(1).replace('.', ' point ')
30 |
31 |
32 | def _expand_dollars(m):
33 | match = m.group(1)
34 | parts = match.split('.')
35 | if len(parts) > 2:
36 | return match + ' dollars' # Unexpected format
37 | dollars = int(parts[0]) if parts[0] else 0
38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
39 | if dollars and cents:
40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
41 | cent_unit = 'cent' if cents == 1 else 'cents'
42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
43 | elif dollars:
44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
45 | return '%s %s' % (dollars, dollar_unit)
46 | elif cents:
47 | cent_unit = 'cent' if cents == 1 else 'cents'
48 | return '%s %s' % (cents, cent_unit)
49 | else:
50 | return 'zero dollars'
51 |
52 |
53 | def _expand_ordinal(m):
54 | return _inflect.number_to_words(m.group(0))
55 |
56 |
57 | def _expand_number(m):
58 | num = int(m.group(0))
59 | if num > 1000 and num < 3000:
60 | if num == 2000:
61 | return 'two thousand'
62 | elif num > 2000 and num < 2010:
63 | return 'two thousand ' + _inflect.number_to_words(num % 100)
64 | elif num % 100 == 0:
65 | return _inflect.number_to_words(num // 100) + ' hundred'
66 | else:
67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
68 | else:
69 | return _inflect.number_to_words(num, andword='')
70 |
71 |
72 | def normalize_numbers(text):
73 | text = re.sub(_comma_number_re, _remove_commas, text)
74 | text = re.sub(_pounds_re, r'\1 pounds', text)
75 | text = re.sub(_dollars_re, _expand_dollars, text)
76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
77 | text = re.sub(_ordinal_re, _expand_ordinal, text)
78 | text = re.sub(_number_re, _expand_number, text)
79 | return text
80 |
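81 | # Rough usage sketch (added for illustration; outputs are approximate):
82 | # normalize_numbers("I paid $3.50 on the 2nd")  ->  "I paid three dollars, fifty cents on the second"
83 | # normalize_numbers("in 1995")                  ->  "in nineteen ninety-five"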
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_0/events.out.tfevents.1703650022.great-server24.1715931.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_0/events.out.tfevents.1703650022.great-server24.1715931.0
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_0/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_1/events.out.tfevents.1703650063.great-server24.1718627.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_1/events.out.tfevents.1703650063.great-server24.1718627.0
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_1/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_2/events.out.tfevents.1703650148.great-server24.1718627.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_2/events.out.tfevents.1703650148.great-server24.1718627.1
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_2/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_3/events.out.tfevents.1703651214.great-server24.1769388.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_3/events.out.tfevents.1703651214.great-server24.1769388.0
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_3/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_4/events.out.tfevents.1703651335.great-server24.1769388.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_4/events.out.tfevents.1703651335.great-server24.1769388.1
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_4/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_5/events.out.tfevents.1703651958.great-server24.1794497.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_5/events.out.tfevents.1703651958.great-server24.1794497.0
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_5/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_6/events.out.tfevents.1703652071.great-server24.1794497.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/dataloaders/g2p/lightning_logs/version_6/events.out.tfevents.1703652071.great-server24.1794497.1
--------------------------------------------------------------------------------
/dataloaders/g2p/lightning_logs/version_6/hparams.yaml:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/dataloaders/libriphrase_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class LibriPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test",
20 | csv=[
21 | "libriphrase_diffspk_all_1word.csv",
22 | "libriphrase_diffspk_all_2word.csv",
23 | "libriphrase_diffspk_all_3word.csv",
24 | "libriphrase_diffspk_all_4word.csv"
25 | ],
26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv',
27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle",
28 | preprocess=False,
29 | types='easy'
30 | ):
31 | if preprocess:
32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type'])
33 | for path in csv:
34 | n_word = os.path.join(test_dir, path)
35 | df = pd.read_csv(n_word)
36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']]
37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']]
38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True)
39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True)
40 | self.data.to_csv(save_path, index=False)
41 | else:
42 | self.data = pd.read_csv(save_path)
43 | # print(self.data)/
44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1)
45 | if types == 'easy':
46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])]
47 | elif types == 'hard':
48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])]
49 |
50 |
51 | self.data = self.data.values.tolist()
52 | # self.data = self.data[:1000]
53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
54 | self.test_dir = test_dir
55 |
56 | def __getitem__(
57 | self,
58 | index
59 | ):
60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index]
62 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
64 | phoneme = self.text_embedder[Support_text]['phoneme']
65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
66 | lm_embed = self.text_embedder[Support_text]['lm_embed']
67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '.npy')
68 | fbank_feature = fbank(
69 | Query_wav,
70 | num_mel_bins=80
71 | )
72 | g2p_embed = torch.from_numpy(g2p_embed)
73 | g2p_embed = g2p_embed.type_as(fbank_feature)
74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
75 |
76 |
77 | def __len__(
78 | self
79 | ):
80 | return len(self.data)
81 |
82 |
83 | def collate_fn(batch):
84 | # unzip the batch so that each field is grouped across samples
85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
86 | # pad each feature to the longest sequence in the batch
87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
88 | lengths = [len(seq) for seq in fbank_feature]  # per-sample frame counts (taken before padding)
89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
92 | # convert the labels to a tensor
93 | label_tensor = torch.tensor(label)
94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
95 |
96 |
97 |
98 |
99 |
100 | # test_data = LibriPhrase_Test_Dataset()
101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False)
102 | # from tqdm import tqdm
103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
105 | # print(dist_tensor)
106 | # break
107 | # # pass
108 | # pass
--------------------------------------------------------------------------------
/dataloaders/libriphrase_test_18.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class LibriPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test",
20 | csv=[
21 | "libriphrase_diffspk_all_1word.csv",
22 | "libriphrase_diffspk_all_2word.csv",
23 | "libriphrase_diffspk_all_3word.csv",
24 | "libriphrase_diffspk_all_4word.csv"
25 | ],
26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv',
27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle",
28 | preprocess=False,
29 | types='easy'
30 | ):
31 | if preprocess:
32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type'])
33 | for path in csv:
34 | n_word = os.path.join(test_dir, path)
35 | df = pd.read_csv(n_word)
36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']]
37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']]
38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True)
39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True)
40 | self.data.to_csv(save_path, index=False)
41 | else:
42 | self.data = pd.read_csv(save_path)
43 | # print(self.data)/
44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1)
45 | if types == 'easy':
46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])]
47 | elif types == 'hard':
48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])]
49 |
50 |
51 | self.data = self.data.values.tolist()
52 | # self.data = self.data[:1000]
53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
54 | self.test_dir = test_dir
55 |
56 | def __getitem__(
57 | self,
58 | index
59 | ):
60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index]
62 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
64 | phoneme = self.text_embedder[Support_text]['phoneme']
65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
66 | lm_embed = self.text_embedder[Support_text]['lm_embed']
67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy')
68 | fbank_feature = fbank(
69 | Query_wav,
70 | num_mel_bins=80
71 | )
72 | g2p_embed = torch.from_numpy(g2p_embed)
73 | g2p_embed = g2p_embed.type_as(fbank_feature)
74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
75 |
76 |
77 | def __len__(
78 | self
79 | ):
80 | return len(self.data)
81 |
82 |
83 | def collate_fn(batch):
84 | # unzip the batch so that each field is grouped across samples
85 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
86 | # pad each feature to the longest sequence in the batch
87 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
88 | lengths = [len(seq) for seq in fbank_feature]  # per-sample frame counts (taken before padding)
89 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
90 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
91 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
92 | # convert the labels to a tensor
93 | label_tensor = torch.tensor(label)
94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
95 |
96 |
97 |
98 |
99 |
100 | # test_data = LibriPhrase_Test_Dataset()
101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False)
102 | # from tqdm import tqdm
103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
105 | # print(dist_tensor)
106 | # break
107 | # # pass
108 | # pass
--------------------------------------------------------------------------------
/dataloaders/wenetphrase_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class WenetPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/WenetPhrase2/S",
20 | csv="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/wenetphrase_test.csv",
21 | test_text_embedding="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/zh_test_text_embeddings.pickle",
22 | types='easy'
23 | ):
24 | self.data = pd.read_csv(csv)
25 | if types == 'easy':
26 | self.data = self.data.loc[self.data['type'].isin(['easy', 'pos'])]
27 | elif types == 'hard':
28 | self.data = self.data.loc[self.data['type'].isin(['hard', 'pos'])]
29 | self.data = self.data.values.tolist()
30 | print(len(self.data))
31 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
32 | self.test_dir = test_dir
33 |
34 | def __getitem__(
35 | self,
36 | index
37 | ):
38 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
39 | Query_text, Query_wav, Support_text, Support_wav, label, _ = self.data[index]
40 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
41 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
42 | phoneme = self.text_embedder[Support_text]['phoneme']
43 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
44 | lm_embed = self.text_embedder[Support_text]['lm_embed']
45 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy')
46 | fbank_feature = fbank(
47 | Query_wav,
48 | num_mel_bins=80
49 | )
50 | g2p_embed = torch.from_numpy(g2p_embed)
51 | g2p_embed = g2p_embed.type_as(fbank_feature)
52 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
53 |
54 |
55 | def __len__(
56 | self
57 | ):
58 | return len(self.data)
59 |
60 |
61 | def collate_fn(batch):
62 | # unzip the batch so that each field is grouped across samples
63 | fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
64 | # pad each feature to the longest sequence in the batch
65 | padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
66 | lengths = [len(seq) for seq in fbank_feature]  # per-sample frame counts (taken before padding)
67 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
68 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
69 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
70 | # convert the labels to a tensor
71 | label_tensor = torch.tensor(label)
72 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
73 |
74 |
75 |
76 |
77 |
78 | # test_data = WenetPhrase_Test_Dataset()
79 | # print(test_data[0])
80 | # # dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn, num_workers=1, shuffle=False)
81 | # # from tqdm import tqdm
82 | # # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
83 | # # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
84 | # # print(dist_tensor)
85 | # # break
86 | # # # pass
87 | # # pass
--------------------------------------------------------------------------------
/g2p/g2p_en/__init__.py:
--------------------------------------------------------------------------------
1 | from .g2p import G2p
2 |
--------------------------------------------------------------------------------
/g2p/g2p_en/checkpoint20.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/g2p/g2p_en/checkpoint20.npz
--------------------------------------------------------------------------------
/g2p/g2p_en/expand.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | Borrowed
5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py
6 | By kyubyong park. kbpark.linguist@gmail.com.
7 | https://www.github.com/kyubyong/g2p
8 | '''
9 | from __future__ import print_function
10 | import inflect
11 | import re
12 |
13 |
14 |
15 | _inflect = inflect.engine()
16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
21 | _number_re = re.compile(r'[0-9]+')
22 |
23 |
24 | def _remove_commas(m):
25 | return m.group(1).replace(',', '')
26 |
27 |
28 | def _expand_decimal_point(m):
29 | return m.group(1).replace('.', ' point ')
30 |
31 |
32 | def _expand_dollars(m):
33 | match = m.group(1)
34 | parts = match.split('.')
35 | if len(parts) > 2:
36 | return match + ' dollars' # Unexpected format
37 | dollars = int(parts[0]) if parts[0] else 0
38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
39 | if dollars and cents:
40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
41 | cent_unit = 'cent' if cents == 1 else 'cents'
42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
43 | elif dollars:
44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
45 | return '%s %s' % (dollars, dollar_unit)
46 | elif cents:
47 | cent_unit = 'cent' if cents == 1 else 'cents'
48 | return '%s %s' % (cents, cent_unit)
49 | else:
50 | return 'zero dollars'
51 |
52 |
53 | def _expand_ordinal(m):
54 | return _inflect.number_to_words(m.group(0))
55 |
56 |
57 | def _expand_number(m):
58 | num = int(m.group(0))
59 | if num > 1000 and num < 3000:
60 | if num == 2000:
61 | return 'two thousand'
62 | elif num > 2000 and num < 2010:
63 | return 'two thousand ' + _inflect.number_to_words(num % 100)
64 | elif num % 100 == 0:
65 | return _inflect.number_to_words(num // 100) + ' hundred'
66 | else:
67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
68 | else:
69 | return _inflect.number_to_words(num, andword='')
70 |
71 |
72 | def normalize_numbers(text):
73 | text = re.sub(_comma_number_re, _remove_commas, text)
74 | text = re.sub(_pounds_re, r'\1 pounds', text)
75 | text = re.sub(_dollars_re, _expand_dollars, text)
76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
77 | text = re.sub(_ordinal_re, _expand_ordinal, text)
78 | text = re.sub(_number_re, _expand_number, text)
79 | return text
80 |
--------------------------------------------------------------------------------
/libriphrase_hardneg.json.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/libriphrase_hardneg.json.zip
--------------------------------------------------------------------------------
/mm-kws/README.md:
--------------------------------------------------------------------------------
1 | # MM-KWS: Multi-modal Prompts for Multilingual User-defined Keyword Spotting (Updating)
2 |
3 | ### Note
4 | 1. Code for the paper 'MM-KWS: Multi-modal Prompts for Multilingual User-defined Keyword Spotting', accepted at Interspeech 2024
5 | 2. Arxiv: https://arxiv.org/pdf/2406.07310
6 | 3. Code version 1
7 | 4. WenetPhrase hardneg data & LibriPhrase hardneg data
8 | 5. DataAug data (todo)
9 |
10 | ---
11 | ### Performance
12 | #### 1.1 Performance on LibriPhrase
13 |
14 |
15 | #### 1.2 Performance on WenetPhrase
16 |
17 |
18 | #### 2. Zero-shot performance on Speech Commands
19 |
20 |
21 | #### 3. Few-shot performance on wake-up word (Snips)
22 | Note: The wake-word experiment is described in my master's thesis; it was removed from the Interspeech manuscript due to the page limit.
23 |
24 |
25 |
26 | ---
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .model import Conformer
16 |
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/activation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch.nn as nn
16 | from torch import Tensor
17 |
18 |
19 | class Swish(nn.Module):
20 | """
21 | Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
22 | to a variety of challenging domains such as Image classification and Machine translation.
23 | """
24 | def __init__(self):
25 | super(Swish, self).__init__()
26 |
27 | def forward(self, inputs: Tensor) -> Tensor:
28 | return inputs * inputs.sigmoid()
29 |
30 |
31 | class GLU(nn.Module):
32 | """
33 | The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
34 | in the paper “Language Modeling with Gated Convolutional Networks”
35 | """
36 | def __init__(self, dim: int) -> None:
37 | super(GLU, self).__init__()
38 | self.dim = dim
39 |
40 | def forward(self, inputs: Tensor) -> Tensor:
41 | outputs, gate = inputs.chunk(2, dim=self.dim)
42 | return outputs * gate.sigmoid()
43 |
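44 | # Illustrative shape check (not in the original file; needs `import torch`):
45 | # x = torch.randn(2, 8, 10)
46 | # Swish()(x).shape    ->  torch.Size([2, 8, 10])   # element-wise, shape preserved
47 | # GLU(dim=1)(x).shape ->  torch.Size([2, 4, 10])   # chunks dim 1 in half and gates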
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from torch import Tensor
20 | from typing import Optional
21 |
22 | from .embedding import PositionalEncoding
23 | from .modules import Linear
24 |
25 |
26 | class RelativeMultiHeadAttention(nn.Module):
27 | """
28 | Multi-head attention with relative positional encoding.
29 | This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
30 |
31 | Args:
32 | d_model (int): The dimension of model
33 | num_heads (int): The number of attention heads.
34 | dropout_p (float): probability of dropout
35 |
36 | Inputs: query, key, value, pos_embedding, mask
37 | - **query** (batch, time, dim): Tensor containing query vector
38 | - **key** (batch, time, dim): Tensor containing key vector
39 | - **value** (batch, time, dim): Tensor containing value vector
40 | - **pos_embedding** (batch, time, dim): Positional embedding tensor
41 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
42 |
43 | Returns:
44 | - **outputs**: Tensor produced by the relative multi-head attention module.
45 | """
46 | def __init__(
47 | self,
48 | d_model: int = 512,
49 | num_heads: int = 16,
50 | dropout_p: float = 0.1,
51 | ):
52 | super(RelativeMultiHeadAttention, self).__init__()
53 | assert d_model % num_heads == 0, "d_model % num_heads should be zero."
54 | self.d_model = d_model
55 | self.d_head = int(d_model / num_heads)
56 | self.num_heads = num_heads
57 | self.sqrt_dim = math.sqrt(d_model)
58 |
59 | self.query_proj = Linear(d_model, d_model)
60 | self.key_proj = Linear(d_model, d_model)
61 | self.value_proj = Linear(d_model, d_model)
62 | self.pos_proj = Linear(d_model, d_model, bias=False)
63 |
64 | self.dropout = nn.Dropout(p=dropout_p)
65 | self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
66 | self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
67 | torch.nn.init.xavier_uniform_(self.u_bias)
68 | torch.nn.init.xavier_uniform_(self.v_bias)
69 |
70 | self.out_proj = Linear(d_model, d_model)
71 |
72 | def forward(
73 | self,
74 | query: Tensor,
75 | key: Tensor,
76 | value: Tensor,
77 | pos_embedding: Tensor,
78 | mask: Optional[Tensor] = None,
79 | ) -> Tensor:
80 | batch_size = value.size(0)
81 |
82 | query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
83 | key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
84 | value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
85 | pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
86 |
87 | content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
88 | pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
89 | pos_score = self._relative_shift(pos_score)
90 |
91 | score = (content_score + pos_score) / self.sqrt_dim
92 |
93 | if mask is not None:
94 | mask = mask.unsqueeze(1)
95 | score.masked_fill_(mask, -1e9)
96 |
97 | attn = F.softmax(score, -1)
98 | attn = self.dropout(attn)
99 |
100 | context = torch.matmul(attn, value).transpose(1, 2)
101 | context = context.contiguous().view(batch_size, -1, self.d_model)
102 |
103 | return self.out_proj(context)
104 |
105 | def _relative_shift(self, pos_score: Tensor) -> Tensor:
106 | batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
107 | zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
108 | padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
109 |
110 | padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
111 | pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
112 |
113 | return pos_score
114 |
115 |
116 | class MultiHeadedSelfAttentionModule(nn.Module):
117 | """
118 | Conformer employs multi-headed self-attention (MHSA) while integrating an important technique from Transformer-XL,
119 | the relative sinusoidal positional encoding scheme. The relative positional encoding allows the self-attention
120 | module to generalize better on different input lengths, and the resulting encoder is more robust to the variance of
121 | the utterance length. Conformer uses pre-norm residual units with dropout, which helps training
122 | and regularizing deeper models.
123 |
124 | Args:
125 | d_model (int): The dimension of model
126 | num_heads (int): The number of attention heads.
127 | dropout_p (float): probability of dropout
128 |
129 | Inputs: inputs, mask
130 | - **inputs** (batch, time, dim): Tensor containing input vector
131 | - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
132 |
133 | Returns:
134 | - **outputs** (batch, time, dim): Tensor produced by the relative multi-headed self-attention module.
135 | """
136 | def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
137 | super(MultiHeadedSelfAttentionModule, self).__init__()
138 | self.positional_encoding = PositionalEncoding(d_model)
139 | self.layer_norm = nn.LayerNorm(d_model)
140 | self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
141 | self.dropout = nn.Dropout(p=dropout_p)
142 |
143 | def forward(self, inputs: Tensor, mask: Optional[Tensor] = None):
144 | batch_size, seq_length, _ = inputs.size()
145 | pos_embedding = self.positional_encoding(seq_length)
146 | pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
147 |
148 | inputs = self.layer_norm(inputs)
149 | outputs = self.attention(inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask)
150 |
151 | return self.dropout(outputs)
152 |
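153 | # Illustrative usage sketch (dimensions are placeholders, not the repo's config):
154 | # mhsa = MultiHeadedSelfAttentionModule(d_model=144, num_heads=4)
155 | # x = torch.randn(2, 50, 144)          # (batch, time, dim)
156 | # mhsa(x).shape  ->  torch.Size([2, 50, 144])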
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .activation import Swish, GLU
21 | from .modules import Transpose
22 |
23 |
24 | class DepthwiseConv1d(nn.Module):
25 | """
26 | When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
27 | this operation is termed in the literature as depthwise convolution.
28 |
29 | Args:
30 | in_channels (int): Number of channels in the input
31 | out_channels (int): Number of channels produced by the convolution
32 | kernel_size (int or tuple): Size of the convolving kernel
33 | stride (int, optional): Stride of the convolution. Default: 1
34 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
35 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True
36 |
37 | Inputs: inputs
38 | - **inputs** (batch, in_channels, time): Tensor containing input vector
39 |
40 | Returns: outputs
41 | - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
42 | """
43 | def __init__(
44 | self,
45 | in_channels: int,
46 | out_channels: int,
47 | kernel_size: int,
48 | stride: int = 1,
49 | padding: int = 0,
50 | bias: bool = False,
51 | ) -> None:
52 | super(DepthwiseConv1d, self).__init__()
53 | assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
54 | self.conv = nn.Conv1d(
55 | in_channels=in_channels,
56 | out_channels=out_channels,
57 | kernel_size=kernel_size,
58 | groups=in_channels,
59 | stride=stride,
60 | padding=padding,
61 | bias=bias,
62 | )
63 |
64 | def forward(self, inputs: Tensor) -> Tensor:
65 | return self.conv(inputs)
66 |
67 |
68 | class PointwiseConv1d(nn.Module):
69 | """
70 | When kernel_size == 1, a conv1d is termed in the literature as pointwise convolution.
71 | This operation is often used to match dimensions.
72 |
73 | Args:
74 | in_channels (int): Number of channels in the input
75 | out_channels (int): Number of channels produced by the convolution
76 | stride (int, optional): Stride of the convolution. Default: 1
77 | padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
78 | bias (bool, optional): If True, adds a learnable bias to the output. Default: True
79 |
80 | Inputs: inputs
81 | - **inputs** (batch, in_channels, time): Tensor containing input vector
82 |
83 | Returns: outputs
84 | - **outputs** (batch, out_channels, time): Tensor produced by pointwise 1-D convolution.
85 | """
86 | def __init__(
87 | self,
88 | in_channels: int,
89 | out_channels: int,
90 | stride: int = 1,
91 | padding: int = 0,
92 | bias: bool = True,
93 | ) -> None:
94 | super(PointwiseConv1d, self).__init__()
95 | self.conv = nn.Conv1d(
96 | in_channels=in_channels,
97 | out_channels=out_channels,
98 | kernel_size=1,
99 | stride=stride,
100 | padding=padding,
101 | bias=bias,
102 | )
103 |
104 | def forward(self, inputs: Tensor) -> Tensor:
105 | return self.conv(inputs)
106 |
107 |
108 | class ConformerConvModule(nn.Module):
109 | """
110 | Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
111 | This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
112 | to aid training deep models.
113 |
114 | Args:
115 | in_channels (int): Number of channels in the input
116 | kernel_size (int or tuple, optional): Size of the convolving kernel Default: 31
117 | dropout_p (float, optional): probability of dropout
118 |
119 | Inputs: inputs
120 | inputs (batch, time, dim): Tensor contains input sequences
121 |
122 | Outputs: outputs
123 | outputs (batch, time, dim): Tensor produced by the conformer convolution module.
124 | """
125 | def __init__(
126 | self,
127 | in_channels: int,
128 | kernel_size: int = 31,
129 | expansion_factor: int = 2,
130 | dropout_p: float = 0.1,
131 | ) -> None:
132 | super(ConformerConvModule, self).__init__()
133 | assert (kernel_size - 1) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
134 | assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"
135 |
136 | self.sequential = nn.Sequential(
137 | nn.LayerNorm(in_channels),
138 | Transpose(shape=(1, 2)),
139 | PointwiseConv1d(in_channels, in_channels * expansion_factor, stride=1, padding=0, bias=True),
140 | GLU(dim=1),
141 | DepthwiseConv1d(in_channels, in_channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2),
142 | nn.BatchNorm1d(in_channels),
143 | Swish(),
144 | PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
145 | nn.Dropout(p=dropout_p),
146 | )
147 |
148 | def forward(self, inputs: Tensor) -> Tensor:
149 | return self.sequential(inputs).transpose(1, 2)
150 |
151 |
152 | class Conv2dSubampling(nn.Module):
153 | """
154 | Convolutional 2D subsampling (to 1/4 length)
155 |
156 | Args:
157 | in_channels (int): Number of channels in the input image
158 | out_channels (int): Number of channels produced by the convolution
159 |
160 | Inputs: inputs
161 | - **inputs** (batch, time, dim): Tensor containing sequence of inputs
162 |
163 | Returns: outputs, output_lengths
164 | - **outputs** (batch, time, dim): Tensor produced by the convolution
165 | - **output_lengths** (batch): list of sequence output lengths
166 | """
167 | def __init__(self, in_channels: int, out_channels: int) -> None:
168 | super(Conv2dSubampling, self).__init__()
169 | self.sequential = nn.Sequential(
170 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2),
171 | nn.ReLU(),
172 | nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2),
173 | nn.ReLU(),
174 | )
175 |
176 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
177 | outputs = self.sequential(inputs.unsqueeze(1))
178 | batch_size, channels, subsampled_lengths, sumsampled_dim = outputs.size()
179 |
180 | outputs = outputs.permute(0, 2, 1, 3)
181 | outputs = outputs.contiguous().view(batch_size, subsampled_lengths, channels * sumsampled_dim)
182 |
183 | output_lengths = input_lengths >> 2
184 | output_lengths -= 1
185 |
186 | return outputs, output_lengths
187 |
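188 | # Illustrative sketch (placeholder sizes): Conv2dSubampling reduces the time axis
189 | # roughly 4x and flattens channels into the feature axis.
190 | # sub = Conv2dSubampling(in_channels=1, out_channels=144)
191 | # x = torch.randn(2, 100, 80)                       # (batch, time, mel bins)
192 | # y, y_lens = sub(x, torch.tensor([100, 100]))      # y: (2, 24, 144 * 19), y_lens: [24, 24]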
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/embedding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import math
16 | import torch
17 | import torch.nn as nn
18 | from torch import Tensor
19 |
20 |
21 | class PositionalEncoding(nn.Module):
22 | """
23 | Positional Encoding proposed in "Attention Is All You Need".
24 |     Since the transformer contains no recurrence and no convolution, in order for the model to make
25 | use of the order of the sequence, we must add some positional information.
26 |
27 |     "Attention Is All You Need" uses sine and cosine functions of different frequencies:
28 | PE_(pos, 2i) = sin(pos / power(10000, 2i / d_model))
29 | PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
30 | """
31 | def __init__(self, d_model: int = 512, max_len: int = 10000) -> None:
32 | super(PositionalEncoding, self).__init__()
33 | pe = torch.zeros(max_len, d_model, requires_grad=False)
34 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
35 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
36 | pe[:, 0::2] = torch.sin(position * div_term)
37 | pe[:, 1::2] = torch.cos(position * div_term)
38 | pe = pe.unsqueeze(0)
39 | self.register_buffer('pe', pe)
40 |
41 | def forward(self, length: int) -> Tensor:
42 | return self.pe[:, :length]
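
A small usage sketch (illustrative): the encoding is returned for a requested length and broadcast-added to a (batch, time, d_model) input.

import torch

pos_enc = PositionalEncoding(d_model=512, max_len=10000)
x = torch.randn(2, 75, 512)          # (batch, time, dim)
x = x + pos_enc(x.size(1))           # pe slice has shape (1, 75, 512) and broadcasts over the batch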
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 |
19 | from .activation import Swish
20 | from .modules import Linear
21 |
22 |
23 | class FeedForwardModule(nn.Module):
24 | """
25 |     The Conformer feed forward module follows pre-norm residual units and applies layer normalization within the residual unit
26 |     and on the input before the first linear layer. This module also applies Swish activation and dropout, which help
27 |     regularize the network.
28 |
29 | Args:
30 | encoder_dim (int): Dimension of conformer encoder
31 | expansion_factor (int): Expansion factor of feed forward module.
32 | dropout_p (float): Ratio of dropout
33 |
34 | Inputs: inputs
35 |         - **inputs** (batch, time, dim): Tensor containing input sequences
36 | 
37 |     Outputs: outputs
38 |         - **outputs** (batch, time, dim): Tensor produced by the feed forward module.
39 | """
40 | def __init__(
41 | self,
42 | encoder_dim: int = 512,
43 | expansion_factor: int = 4,
44 | dropout_p: float = 0.1,
45 | ) -> None:
46 | super(FeedForwardModule, self).__init__()
47 | self.sequential = nn.Sequential(
48 | nn.LayerNorm(encoder_dim),
49 | Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
50 | Swish(),
51 | nn.Dropout(p=dropout_p),
52 | Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
53 | nn.Dropout(p=dropout_p),
54 | )
55 |
56 | def forward(self, inputs: Tensor) -> Tensor:
57 | return self.sequential(inputs)
58 |
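An illustrative sketch of this module used as a half-step residual branch; the 0.5 factor mirrors the half_step_residual option used elsewhere in this package and is an assumption here, not part of this file:

import torch

ffn = FeedForwardModule(encoder_dim=512, expansion_factor=4, dropout_p=0.1)
x = torch.randn(2, 75, 512)          # (batch, time, dim)
out = x + 0.5 * ffn(x)               # shape stays (2, 75, 512)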
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .encoder import ConformerEncoder
21 | from .modules import Linear
22 |
23 |
24 | class Conformer(nn.Module):
25 | """
26 | Conformer: Convolution-augmented Transformer for Speech Recognition
27 |     The paper used a single-LSTM Transducer decoder; currently only the Conformer encoder described
28 |     in the paper is implemented.
29 |
30 | Args:
31 | num_classes (int): Number of classification classes
32 | input_dim (int, optional): Dimension of input vector
33 | encoder_dim (int, optional): Dimension of conformer encoder
34 | num_encoder_layers (int, optional): Number of conformer blocks
35 | num_attention_heads (int, optional): Number of attention heads
36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
39 | attention_dropout_p (float, optional): Probability of attention module dropout
40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout
41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel
42 |         half_step_residual (bool): Flag indicating whether to use half-step residual connections
43 |
44 | Inputs: inputs, input_lengths
45 | - **inputs** (batch, time, dim): Tensor containing input vector
46 | - **input_lengths** (batch): list of sequence input lengths
47 |
48 | Returns: outputs, output_lengths
49 |         - **outputs** (batch, out_channels, time): Tensor produced by the conformer.
50 | - **output_lengths** (batch): list of sequence output lengths
51 | """
52 | def __init__(
53 | self,
54 | num_classes: int,
55 | input_dim: int = 80,
56 | encoder_dim: int = 512,
57 | num_encoder_layers: int = 17,
58 | num_attention_heads: int = 8,
59 | feed_forward_expansion_factor: int = 4,
60 | conv_expansion_factor: int = 2,
61 | input_dropout_p: float = 0.1,
62 | feed_forward_dropout_p: float = 0.1,
63 | attention_dropout_p: float = 0.1,
64 | conv_dropout_p: float = 0.1,
65 | conv_kernel_size: int = 31,
66 | half_step_residual: bool = True,
67 | ) -> None:
68 | super(Conformer, self).__init__()
69 | self.encoder = ConformerEncoder(
70 | input_dim=input_dim,
71 | encoder_dim=encoder_dim,
72 | num_layers=num_encoder_layers,
73 | num_attention_heads=num_attention_heads,
74 | feed_forward_expansion_factor=feed_forward_expansion_factor,
75 | conv_expansion_factor=conv_expansion_factor,
76 | input_dropout_p=input_dropout_p,
77 | feed_forward_dropout_p=feed_forward_dropout_p,
78 | attention_dropout_p=attention_dropout_p,
79 | conv_dropout_p=conv_dropout_p,
80 | conv_kernel_size=conv_kernel_size,
81 | half_step_residual=half_step_residual,
82 | )
83 | self.fc = Linear(encoder_dim, num_classes, bias=False)
84 |
85 | def count_parameters(self) -> int:
86 | """ Count parameters of encoder """
87 | return self.encoder.count_parameters()
88 |
89 | def update_dropout(self, dropout_p) -> None:
90 | """ Update dropout probability of model """
91 | self.encoder.update_dropout(dropout_p)
92 |
93 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
94 | """
95 |         Forward propagate `inputs` and `input_lengths` through the model.
96 |
97 | Args:
98 |             inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
99 | `FloatTensor` of size ``(batch, seq_length, dimension)``.
100 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
101 |
102 | Returns:
103 | * predictions (torch.FloatTensor): Result of model predictions.
104 | """
105 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
106 | outputs = self.fc(encoder_outputs)
107 | outputs = nn.functional.log_softmax(outputs, dim=-1)
108 | return outputs, encoder_output_lengths
109 |
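A minimal end-to-end sketch of the ASR-style wrapper above (illustrative hyperparameters, not the settings used by MM-KWS):

import torch

model = Conformer(num_classes=10, input_dim=80, encoder_dim=144,
                  num_encoder_layers=4, num_attention_heads=4)
inputs = torch.randn(2, 200, 80)                  # (batch, time, n_mels)
input_lengths = torch.tensor([200, 160])
log_probs, output_lengths = model(inputs, input_lengths)
# log_probs: (batch, subsampled_time, num_classes) log-probabilities; output_lengths: (batch,)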
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/model_def.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch import Tensor
18 | from typing import Tuple
19 |
20 | from .encoder import ConformerEncoder
21 | from .modules import Linear
22 |
23 |
24 | class Conformer(nn.Module):
25 | """
26 | Conformer: Convolution-augmented Transformer for Speech Recognition
27 |     The paper used a single-LSTM Transducer decoder; currently only the Conformer encoder described
28 |     in the paper is implemented.
29 |
30 | Args:
31 |         (no ``num_classes`` here: this variant has no classification head and returns raw encoder outputs)
32 | input_dim (int, optional): Dimension of input vector
33 | encoder_dim (int, optional): Dimension of conformer encoder
34 | num_encoder_layers (int, optional): Number of conformer blocks
35 | num_attention_heads (int, optional): Number of attention heads
36 | feed_forward_expansion_factor (int, optional): Expansion factor of feed forward module
37 | conv_expansion_factor (int, optional): Expansion factor of conformer convolution module
38 | feed_forward_dropout_p (float, optional): Probability of feed forward module dropout
39 | attention_dropout_p (float, optional): Probability of attention module dropout
40 | conv_dropout_p (float, optional): Probability of conformer convolution module dropout
41 | conv_kernel_size (int or tuple, optional): Size of the convolving kernel
42 |         half_step_residual (bool): Flag indicating whether to use half-step residual connections
43 |
44 | Inputs: inputs, input_lengths
45 | - **inputs** (batch, time, dim): Tensor containing input vector
46 | - **input_lengths** (batch): list of sequence input lengths
47 |
48 | Returns: outputs, output_lengths
49 |         - **outputs** (batch, out_channels, time): Tensor produced by the conformer.
50 | - **output_lengths** (batch): list of sequence output lengths
51 | """
52 | def __init__(
53 | self,
54 | input_dim: int = 80,
55 | encoder_dim: int = 512,
56 | num_encoder_layers: int = 17,
57 | num_attention_heads: int = 8,
58 | feed_forward_expansion_factor: int = 4,
59 | conv_expansion_factor: int = 2,
60 | input_dropout_p: float = 0.1,
61 | feed_forward_dropout_p: float = 0.1,
62 | attention_dropout_p: float = 0.1,
63 | conv_dropout_p: float = 0.1,
64 | conv_kernel_size: int = 31,
65 | half_step_residual: bool = True,
66 | ) -> None:
67 | super(Conformer, self).__init__()
68 | self.encoder = ConformerEncoder(
69 | input_dim=input_dim,
70 | encoder_dim=encoder_dim,
71 | num_layers=num_encoder_layers,
72 | num_attention_heads=num_attention_heads,
73 | feed_forward_expansion_factor=feed_forward_expansion_factor,
74 | conv_expansion_factor=conv_expansion_factor,
75 | input_dropout_p=input_dropout_p,
76 | feed_forward_dropout_p=feed_forward_dropout_p,
77 | attention_dropout_p=attention_dropout_p,
78 | conv_dropout_p=conv_dropout_p,
79 | conv_kernel_size=conv_kernel_size,
80 | half_step_residual=half_step_residual,
81 | )
82 |
83 | def count_parameters(self) -> int:
84 | """ Count parameters of encoder """
85 | return self.encoder.count_parameters()
86 |
87 | def update_dropout(self, dropout_p) -> None:
88 | """ Update dropout probability of model """
89 | self.encoder.update_dropout(dropout_p)
90 |
91 | def forward(self, inputs: Tensor, input_lengths: Tensor) -> Tuple[Tensor, Tensor]:
92 | """
93 |         Forward propagate `inputs` and `input_lengths` through the model.
94 |
95 | Args:
96 |             inputs (torch.FloatTensor): An input sequence passed to the encoder. Typically this will be a padded
97 | `FloatTensor` of size ``(batch, seq_length, dimension)``.
98 | input_lengths (torch.LongTensor): The length of input tensor. ``(batch)``
99 |
100 | Returns:
101 | * predictions (torch.FloatTensor): Result of model predictions.
102 | """
103 | encoder_outputs, encoder_output_lengths = self.encoder(inputs, input_lengths)
104 | return encoder_outputs, encoder_output_lengths
105 |
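The only behavioural difference from model.py is that this variant has no classification head: forward returns the raw encoder features, which can be consumed directly as acoustic embeddings. A tiny sketch (illustrative hyperparameters):

import torch

encoder = Conformer(input_dim=80, encoder_dim=144, num_encoder_layers=4, num_attention_heads=4)
feats, feat_lengths = encoder(torch.randn(2, 200, 80), torch.tensor([200, 160]))
# feats: (batch, subsampled_time, encoder_dim) encoder outputs, no log-softmax applied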
--------------------------------------------------------------------------------
/mm-kws/conformer/conformer/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, Soohwan Kim. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.init as init
18 | from torch import Tensor
19 |
20 |
21 | class ResidualConnectionModule(nn.Module):
22 | """
23 | Residual Connection Module.
24 | outputs = (module(inputs) x module_factor + inputs x input_factor)
25 | """
26 | def __init__(self, module: nn.Module, module_factor: float = 1.0, input_factor: float = 1.0):
27 | super(ResidualConnectionModule, self).__init__()
28 | self.module = module
29 | self.module_factor = module_factor
30 | self.input_factor = input_factor
31 |
32 | def forward(self, inputs: Tensor) -> Tensor:
33 | return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
34 |
35 |
36 | class Linear(nn.Module):
37 | """
38 | Wrapper class of torch.nn.Linear
39 |     Weights are initialized with Xavier initialization and biases are initialized to zeros.
40 | """
41 | def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
42 | super(Linear, self).__init__()
43 | self.linear = nn.Linear(in_features, out_features, bias=bias)
44 | init.xavier_uniform_(self.linear.weight)
45 | if bias:
46 | init.zeros_(self.linear.bias)
47 |
48 | def forward(self, x: Tensor) -> Tensor:
49 | return self.linear(x)
50 |
51 |
52 | class View(nn.Module):
53 | """ Wrapper class of torch.view() for Sequential module. """
54 | def __init__(self, shape: tuple, contiguous: bool = False):
55 | super(View, self).__init__()
56 | self.shape = shape
57 | self.contiguous = contiguous
58 |
59 | def forward(self, x: Tensor) -> Tensor:
60 | if self.contiguous:
61 | x = x.contiguous()
62 |
63 | return x.view(*self.shape)
64 |
65 |
66 | class Transpose(nn.Module):
67 | """ Wrapper class of torch.transpose() for Sequential module. """
68 | def __init__(self, shape: tuple):
69 | super(Transpose, self).__init__()
70 | self.shape = shape
71 |
72 | def forward(self, x: Tensor) -> Tensor:
73 | return x.transpose(*self.shape)
74 |
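An illustrative use of ResidualConnectionModule: wrapping a sub-layer with module_factor=0.5 gives the half-step residual used by the Conformer block (the wrapped Linear here is just a stand-in):

import torch

half_step_branch = ResidualConnectionModule(module=Linear(256, 256), module_factor=0.5)
x = torch.randn(2, 50, 256)
out = half_step_branch(x)            # equals Linear(x) * 0.5 + x, shape unchanged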
--------------------------------------------------------------------------------
/mm-kws/dataloaders/SPC_N0_ALL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import Levenshtein
15 | import re
16 | import random
17 |
18 | import os
19 |
20 | def get_files(path, endswith):
21 | _files = []
22 | for root, dirs, files in os.walk(path):
23 | for file in files:
24 | if file.endswith(endswith):
25 | _files.append(os.path.join(root, file))
26 | return _files
27 |
28 |
29 | class SPC_NO_Dataset(Dataset):
30 | def __init__(
31 | self,
32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_ALL.csv",
35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
36 | preprocess=True,
37 | ):
38 | super().__init__()
39 | if preprocess:
40 | target_dict = {}
41 | idx = 0
42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label'])
43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
44 | classes = os.listdir(test_dir)
45 | random.shuffle(classes)
46 | Targets = classes[:10]
47 | for wav_idx in range(len(wav_id)):
48 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
49 | query_text = wav_id[wav_idx].split('_')[0]
50 | if query_text in Targets:
51 | for comparison_text in Targets:
52 | _label = 1 if comparison_text == query_text else 0
53 | target_dict[idx] = {
54 | 'id': wav_id[wav_idx],
55 | 'Query_text': query_text,
56 | 'Query_wav': wav,
57 | 'Support_text': comparison_text,
58 | 'Query_label': Targets.index(query_text),
59 | 'Support_label': Targets.index(comparison_text),
60 | 'label': _label
61 | }
62 | idx += 1
63 | else:
64 | for comparison_text in Targets:
65 | _label = 0
66 | target_dict[idx] = {
67 | 'id': wav_id[wav_idx],
68 | 'Query_text': query_text,
69 | 'Query_wav': wav,
70 | 'Support_text': comparison_text,
71 | 'Query_label': 10,
72 | 'Support_label': Targets.index(comparison_text),
73 | 'label': 10,
74 | }
75 | idx += 1
76 |
77 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
78 | self.data.to_csv(save_path, index=False)
79 | else:
80 | self.data = pd.read_csv(save_path)
81 |
82 | self.data = self.data.values.tolist()
83 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
84 |
85 |
86 | def __getitem__(
87 | self,
88 | index
89 | ):
90 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index]
91 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
92 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
93 | lm_embed = self.text_embedder[Support_text]['lm_embed']
94 | fbank_feature = fbank(
95 | Query_wav,
96 | num_mel_bins=80
97 | )
98 | g2p_embed = torch.from_numpy(g2p_embed)
99 | g2p_embed = g2p_embed.type_as(fbank_feature)
100 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
101 |
102 | def __len__(
103 | self
104 | ):
105 | return len(self.data)
106 |
107 |
108 |
109 |
110 |
111 | def collate_fn(batch):
112 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch)
113 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
114 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
115 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
116 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
117 | label_tensor = torch.tensor(label)
118 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor
119 |
120 |
121 |
122 | spc = SPC_NO_Dataset()
123 | spc[0]
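
A sketch of how this dataset is typically consumed (the default paths above are machine-specific, so this assumes the CSV and embedding pickle already exist; preprocess=False reuses the saved CSV):

from torch.utils.data import DataLoader

spc = SPC_NO_Dataset(preprocess=False)
loader = DataLoader(spc, batch_size=64, collate_fn=collate_fn, num_workers=8, shuffle=False)
ids, feats, lengths, g2p, lm, query_labels, support_labels, labels = next(iter(loader))
# feats: padded fbank features (batch, max_time, 80); g2p/lm: padded text embeddings for the support phrase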
--------------------------------------------------------------------------------
/mm-kws/dataloaders/SPC_N0_TARGET.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 |
13 | import torch.nn.functional as F
14 | from torch.nn.utils.rnn import pad_sequence
15 | import Levenshtein
16 | import re
17 | import random
18 |
19 | import os
20 |
21 | def get_files(path, endswith):
22 | _files = []
23 | for root, dirs, files in os.walk(path):
24 | for file in files:
25 | if file.endswith(endswith):
26 | _files.append(os.path.join(root, file))
27 | return _files
28 |
29 |
30 | class SPC_NO_Dataset(Dataset):
31 | def __init__(
32 | self,
33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N0_TARGET.csv",
36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
37 | preprocess=True,
38 | ):
39 | super().__init__()
40 | if preprocess:
41 | target_dict = {}
42 | idx = 0
43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Query_label', 'Support_label', 'label'])
44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
45 | classes = os.listdir(test_dir)
46 | random.shuffle(classes)
47 | Targets = classes[:10]
48 | for wav_idx in range(len(wav_id)):
49 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
50 | query_text = wav_id[wav_idx].split('_')[0]
51 | if query_text in Targets:
52 | for comparison_text in Targets:
53 | _label = 1 if comparison_text == query_text else 0
54 | target_dict[idx] = {
55 | 'id': wav_id[wav_idx],
56 | 'Query_text': query_text,
57 | 'Query_wav': wav,
58 | 'Support_text': comparison_text,
59 | 'Query_label': Targets.index(query_text),
60 | 'Support_label': Targets.index(comparison_text),
61 | 'label': _label
62 | }
63 | idx += 1
64 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
65 | self.data.to_csv(save_path, index=False)
66 | else:
67 | self.data = pd.read_csv(save_path)
68 |
69 | self.data = self.data.values.tolist()
70 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
71 |
72 |
73 | def __getitem__(
74 | self,
75 | index
76 | ):
77 | ids, Query_text, Query_wav, Support_text, Query_label, Support_label, label = self.data[index]
78 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
79 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
80 | lm_embed = self.text_embedder[Support_text]['lm_embed']
81 | fbank_feature = fbank(
82 | Query_wav,
83 | num_mel_bins=80
84 | )
85 | g2p_embed = torch.from_numpy(g2p_embed)
86 | g2p_embed = g2p_embed.type_as(fbank_feature)
87 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
88 |
89 | def __len__(
90 | self
91 | ):
92 | return len(self.data)
93 |
94 |
95 |
96 |
97 |
98 | def collate_fn(batch):
99 | ids, fbank_feature, g2p_embed, lm_embed, Query_label, Support_label, label = zip(*batch)
100 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
101 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
102 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
103 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
104 | label_tensor = torch.tensor(label)
105 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, Query_label, Support_label, label_tensor
106 |
107 |
108 |
109 | # spc = SPC_NO_Dataset()
110 |
--------------------------------------------------------------------------------
/mm-kws/dataloaders/SPC_N1_ALL.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | # import dataaug
13 | import torch.nn.functional as F
14 | from torch.nn.utils.rnn import pad_sequence
15 | import Levenshtein
16 | import re
17 | import random
18 |
19 | import os
20 |
21 | def get_files(path, endswith):
22 | _files = []
23 | for root, dirs, files in os.walk(path):
24 | for file in files:
25 | if file.endswith(endswith):
26 | _files.append(os.path.join(root, file))
27 | return _files
28 | from tqdm import tqdm
29 |
30 | class SPC_N1_Dataset(Dataset):
31 | def __init__(
32 | self,
33 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
34 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
35 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_ALL.csv",
36 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
37 | preprocess=True,
38 | ):
39 | super().__init__()
40 | if preprocess:
41 | target_dict = {}
42 | idx = 0
43 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label'])
44 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
45 | classes = os.listdir(test_dir)
46 | random.shuffle(classes)
47 | Targets = classes[:10]
48 | supports_wavs = {}
49 | for comparison_text in Targets:
50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy')
51 | random.shuffle(supports_wav)
52 | supports_wavs[comparison_text] = supports_wav
53 | for wav_idx in range(len(wav_id)):
54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
55 | query_text = wav_id[wav_idx].split('_')[0]
56 | if query_text in Targets:
57 | for comparison_text in Targets:
58 | support_wav = random.choices(supports_wavs[comparison_text], k=2)
59 | _label = 1 if comparison_text == query_text else 0
60 | target_dict[idx] = {
61 | 'id': wav_id[wav_idx],
62 | 'Query_text': query_text,
63 | 'Query_wav': wav,
64 | 'Support_text': comparison_text,
65 | 'Support_wav': support_wav[0],
66 | 'Query_label': Targets.index(query_text),
67 | 'Support_label': Targets.index(comparison_text),
68 | 'label': _label
69 | }
70 | idx += 1
71 | else:
72 | for comparison_text in Targets:
73 | support_wav = random.choices(supports_wavs[comparison_text], k=5)
74 | _label = 0
75 | target_dict[idx] = {
76 | 'id': wav_id[wav_idx],
77 | 'Query_text': query_text,
78 | 'Query_wav': wav,
79 | 'Support_text': comparison_text,
80 | 'Support_wav': support_wav[0],
81 | 'Query_label': 10,
82 | 'Support_label': Targets.index(comparison_text),
83 | 'label': 10,
84 | }
85 | idx += 1
86 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
87 | self.data.to_csv(save_path, index=False)
88 | else:
89 | self.data = pd.read_csv(save_path)
90 |
91 | self.data = self.data.values.tolist()
92 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
93 |
94 |
95 | def __getitem__(
96 | self,
97 | index
98 | ):
99 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index]
100 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
101 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
102 | lm_embed = self.text_embedder[Support_text]['lm_embed']
103 | audiolm_embed = np.load(Support_wav)
104 | fbank_feature = fbank(
105 | Query_wav,
106 | num_mel_bins=80
107 | )
108 | g2p_embed = torch.from_numpy(g2p_embed)
109 | g2p_embed = g2p_embed.type_as(fbank_feature)
110 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
111 |
112 | def __len__(
113 | self
114 | ):
115 | return len(self.data)
116 |
117 |
118 |
119 |
120 |
121 | def collate_fn(batch):
122 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch)
123 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
124 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
125 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
126 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
127 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
128 | label_tensor = torch.tensor(label)
129 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor
130 |
131 |
132 |
133 | spc = SPC_N1_Dataset()
134 | # spc[0]
--------------------------------------------------------------------------------
/mm-kws/dataloaders/SPC_N1_TARGET.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import Levenshtein
15 | import re
16 | import random
17 |
18 | import os
19 |
20 | def get_files(path, endswith):
21 | _files = []
22 | for root, dirs, files in os.walk(path):
23 | for file in files:
24 | if file.endswith(endswith):
25 | _files.append(os.path.join(root, file))
26 | return _files
27 | from tqdm import tqdm
28 |
29 | class SPC_N1_Dataset(Dataset):
30 | def __init__(
31 | self,
32 | test_dir="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/data",
33 | test_list="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/test/text",
34 | save_path="/nvme01/aizq/mmkws/mmkws_submits/spc/SPC1/SPC_N1_TARGET.csv",
35 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/spc_text_embeddings.pickle",
36 | preprocess=True,
37 | ):
38 | super().__init__()
39 | if preprocess:
40 | target_dict = {}
41 | idx = 0
42 | self.data = pd.DataFrame(columns=['id', 'Query_text', 'Query_wav', 'Support_text', 'Support_wav', 'Query_label', 'Support_label', 'label'])
43 | wav_id, _ = zip(*(line.strip().split() for line in open(test_list)))
44 | classes = os.listdir(test_dir)
45 | random.shuffle(classes)
46 | Targets = classes[:10]
47 | # Targets = ['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go']
48 | supports_wavs = {}
49 | for comparison_text in Targets:
50 | supports_wav = get_files(os.path.join(test_dir, comparison_text), '_18.npy')
51 | random.shuffle(supports_wav)
52 | supports_wavs[comparison_text] = supports_wav
53 | for wav_idx in range(len(wav_id)):
54 | wav = os.path.join(test_dir, *wav_id[wav_idx].split('_', 1)) + '.wav'
55 | query_text = wav_id[wav_idx].split('_')[0]
56 | if query_text in Targets:
57 | for comparison_text in Targets:
58 | _label = 1 if comparison_text == query_text else 0
59 | target_dict[idx] = {
60 | 'id': wav_id[wav_idx],
61 | 'Query_text': query_text,
62 | 'Query_wav': wav,
63 | 'Support_text': comparison_text,
64 | 'Support_wav': supports_wavs[comparison_text][0],
65 | 'Query_label': Targets.index(query_text),
66 | 'Support_label': Targets.index(comparison_text),
67 | 'label': _label
68 | }
69 | idx += 1
70 | self.data = self.data._append(pd.DataFrame.from_dict(target_dict, 'index'), ignore_index=True)
71 | self.data.to_csv(save_path, index=False)
72 | else:
73 | self.data = pd.read_csv(save_path)
74 |
75 | self.data = self.data.values.tolist()
76 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
77 |
78 |
79 | def __getitem__(
80 | self,
81 | index
82 | ):
83 | ids, Query_text, Query_wav, Support_text, Support_wav, Query_label, Support_label, label = self.data[index]
84 | Query_wav, _ = torchaudio.load(Query_wav) # waveform -> fbank
85 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
86 | lm_embed = self.text_embedder[Support_text]['lm_embed']
87 | audiolm_embed = np.load(Support_wav)
88 | fbank_feature = fbank(
89 | Query_wav,
90 | num_mel_bins=80
91 | )
92 | g2p_embed = torch.from_numpy(g2p_embed)
93 | g2p_embed = g2p_embed.type_as(fbank_feature)
94 | return ids, fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), Query_label, Support_label, torch.tensor(label)
95 |
96 | def __len__(
97 | self
98 | ):
99 | return len(self.data)
100 |
101 |
102 |
103 |
104 |
105 | def collate_fn(batch):
106 | ids, fbank_feature, g2p_embed, lm_embed, audiolm_embed, Query_label, Support_label, label = zip(*batch)
107 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
108 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
109 | padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
110 | padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
111 | padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
112 | label_tensor = torch.tensor(label)
113 | return ids, padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, Query_label, Support_label, label_tensor
114 |
115 |
116 |
117 | spc = SPC_N1_Dataset()
118 | # spc[0]
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N0_ALL.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N0_TARGET.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N1_ALL.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/SPC_N1_TARGET.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_test.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_test_18.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_train.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_trainY.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/__pycache__/libriphrase_train_18.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__init__.py:
--------------------------------------------------------------------------------
1 | from .g2p import G2p
2 |
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/expand.cpython-39.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-310.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/__pycache__/g2p.cpython-39.pyc
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/checkpoint20.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/dataloaders/g2p/g2p_en/checkpoint20.npz
--------------------------------------------------------------------------------
/mm-kws/dataloaders/g2p/g2p_en/expand.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | Borrowed
5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py
6 | By kyubyong park. kbpark.linguist@gmail.com.
7 | https://www.github.com/kyubyong/g2p
8 | '''
9 | from __future__ import print_function
10 | import inflect
11 | import re
12 |
13 |
14 |
15 | _inflect = inflect.engine()
16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
21 | _number_re = re.compile(r'[0-9]+')
22 |
23 |
24 | def _remove_commas(m):
25 | return m.group(1).replace(',', '')
26 |
27 |
28 | def _expand_decimal_point(m):
29 | return m.group(1).replace('.', ' point ')
30 |
31 |
32 | def _expand_dollars(m):
33 | match = m.group(1)
34 | parts = match.split('.')
35 | if len(parts) > 2:
36 | return match + ' dollars' # Unexpected format
37 | dollars = int(parts[0]) if parts[0] else 0
38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
39 | if dollars and cents:
40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
41 | cent_unit = 'cent' if cents == 1 else 'cents'
42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
43 | elif dollars:
44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
45 | return '%s %s' % (dollars, dollar_unit)
46 | elif cents:
47 | cent_unit = 'cent' if cents == 1 else 'cents'
48 | return '%s %s' % (cents, cent_unit)
49 | else:
50 | return 'zero dollars'
51 |
52 |
53 | def _expand_ordinal(m):
54 | return _inflect.number_to_words(m.group(0))
55 |
56 |
57 | def _expand_number(m):
58 | num = int(m.group(0))
59 | if num > 1000 and num < 3000:
60 | if num == 2000:
61 | return 'two thousand'
62 | elif num > 2000 and num < 2010:
63 | return 'two thousand ' + _inflect.number_to_words(num % 100)
64 | elif num % 100 == 0:
65 | return _inflect.number_to_words(num // 100) + ' hundred'
66 | else:
67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
68 | else:
69 | return _inflect.number_to_words(num, andword='')
70 |
71 |
72 | def normalize_numbers(text):
73 | text = re.sub(_comma_number_re, _remove_commas, text)
74 | text = re.sub(_pounds_re, r'\1 pounds', text)
75 | text = re.sub(_dollars_re, _expand_dollars, text)
76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
77 | text = re.sub(_ordinal_re, _expand_ordinal, text)
78 | text = re.sub(_number_re, _expand_number, text)
79 | return text
80 |
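For reference, a quick example of what normalize_numbers produces (approximate output shown in the comment):

print(normalize_numbers("I paid $2.50 for 1,000 stickers on the 3rd of May"))
# roughly: "I paid two dollars, fifty cents for one thousand stickers on the third of May"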
--------------------------------------------------------------------------------
/mm-kws/dataloaders/libriphrase_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class LibriPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test",
20 | csv=[
21 | "libriphrase_diffspk_all_1word.csv",
22 | "libriphrase_diffspk_all_2word.csv",
23 | "libriphrase_diffspk_all_3word.csv",
24 | "libriphrase_diffspk_all_4word.csv"
25 | ],
26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv',
27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle",
28 | preprocess=False,
29 | types='easy'
30 | ):
31 | if preprocess:
32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type'])
33 | for path in csv:
34 | n_word = os.path.join(test_dir, path)
35 | df = pd.read_csv(n_word)
36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']]
37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']]
38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True)
39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True)
40 | self.data.to_csv(save_path, index=False)
41 | else:
42 | self.data = pd.read_csv(save_path)
43 | # print(self.data)/
44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1)
45 | if types == 'easy':
46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])]
47 | elif types == 'hard':
48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])]
49 |
50 |
51 | self.data = self.data.values.tolist()
52 | # self.data = self.data[:1000]
53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
54 | self.test_dir = test_dir
55 |
56 | def __getitem__(
57 | self,
58 | index
59 | ):
60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index]
62 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
64 | phoneme = self.text_embedder[Support_text]['phoneme']
65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
66 | lm_embed = self.text_embedder[Support_text]['lm_embed']
67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '.npy')
68 | fbank_feature = fbank(
69 | Query_wav,
70 | num_mel_bins=80
71 | )
72 | g2p_embed = torch.from_numpy(g2p_embed)
73 | g2p_embed = g2p_embed.type_as(fbank_feature)
74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
75 |
76 |
77 | def __len__(
78 | self
79 | ):
80 | return len(self.data)
81 |
82 |
83 | def collate_fn(batch):
84 |     # Group the fields of each sample in the batch
85 |     fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
86 |     # Pad each variable-length feature to the batch maximum
87 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
88 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
89 |     padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
90 |     padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
91 |     padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
92 |     # Convert the labels to a Tensor
93 |     label_tensor = torch.tensor(label)
94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
95 |
96 |
97 |
98 |
99 |
100 | # test_data = LibriPhrase_Test_Dataset()
101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False)
102 | # from tqdm import tqdm
103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
105 | # print(dist_tensor)
106 | # break
107 | # # pass
108 | # pass
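
A sketch of the evaluation wiring (assumes the preprocessed CSV, text-embedding pickle, and precomputed *.npy support embeddings referenced by the default paths exist; `types` selects the easy-negative or hard-negative split):

from torch.utils.data import DataLoader

test_data = LibriPhrase_Test_Dataset(types='hard')   # or types='easy'
loader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False)
for feats, lengths, g2p, lm, audiolm, labels in loader:
    # feats: (batch, max_time, 80) padded fbank; g2p/lm/audiolm: support-phrase embeddings; labels: 0/1
    break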
--------------------------------------------------------------------------------
/mm-kws/dataloaders/libriphrase_test_18.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class LibriPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/LibriPhrase_Test",
20 | csv=[
21 | "libriphrase_diffspk_all_1word.csv",
22 | "libriphrase_diffspk_all_2word.csv",
23 | "libriphrase_diffspk_all_3word.csv",
24 | "libriphrase_diffspk_all_4word.csv"
25 | ],
26 | save_path='/nvme01/aizq/mmkws/datasets/LibriPhrase_Test/test_phrase.csv',
27 | test_text_embedding="/nvme01/aizq/mmkws/datasets/LibriPhrase_Train_MIN_20/test_text_embeddings.pickle",
28 | preprocess=False,
29 | types='easy'
30 | ):
31 | if preprocess:
32 | self.data = pd.DataFrame(columns=['Query_text', 'Query_wav', 'Query_dur', 'Support_text', 'Support_wav', 'Support_dur', 'label', 'type'])
33 | for path in csv:
34 | n_word = os.path.join(test_dir, path)
35 | df = pd.read_csv(n_word)
36 | anc = df[['anchor_text', 'anchor', 'anchor_dur', 'comparison_text', 'comparison', 'comparison_dur', 'target', 'type']]
37 | com = df[['comparison_text', 'comparison', 'comparison_dur', 'anchor_text', 'anchor', 'anchor_dur', 'target', 'type']]
38 | self.data = self.data._append(anc.rename(columns={y: x for x, y in zip(self.data.columns, anc.columns)}), ignore_index=True)
39 | self.data = self.data._append(com.rename(columns={y: x for x, y in zip(self.data.columns, com.columns)}), ignore_index=True)
40 | self.data.to_csv(save_path, index=False)
41 | else:
42 | self.data = pd.read_csv(save_path)
43 | # print(self.data)/
44 | # self.data['dist'] = self.data.apply(lambda x: Levenshtein.ratio(re.sub(r"[^a-zA-Z0-9]+", ' ', x['Support_text']), re.sub(r"[^a-zA-Z0-9]+", ' ', x['Query_text'])), axis=1)
45 | if types == 'easy':
46 | self.data = self.data.loc[self.data['type'].isin(['diffspk_easyneg', 'diffspk_positive'])]
47 | elif types == 'hard':
48 | self.data = self.data.loc[self.data['type'].isin(['diffspk_hardneg', 'diffspk_positive'])]
49 |
50 |
51 | self.data = self.data.values.tolist()
52 | # self.data = self.data[:1000]
53 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
54 | self.test_dir = test_dir
55 |
56 | def __getitem__(
57 | self,
58 | index
59 | ):
60 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
61 | Query_text, Query_wav, _, Support_text, Support_wav, _, label, _ = self.data[index]
62 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
63 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
64 | phoneme = self.text_embedder[Support_text]['phoneme']
65 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
66 | lm_embed = self.text_embedder[Support_text]['lm_embed']
67 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy')
68 | fbank_feature = fbank(
69 | Query_wav,
70 | num_mel_bins=80
71 | )
72 | g2p_embed = torch.from_numpy(g2p_embed)
73 | g2p_embed = g2p_embed.type_as(fbank_feature)
74 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
75 |
76 |
77 | def __len__(
78 | self
79 | ):
80 | return len(self.data)
81 |
82 |
83 | def collate_fn(batch):
84 |     # Group the fields of each sample in the batch
85 |     fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
86 |     # Pad each variable-length feature to the batch maximum
87 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
88 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
89 |     padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
90 |     padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
91 |     padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
92 |     # Convert the labels to a Tensor
93 |     label_tensor = torch.tensor(label)
94 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
95 |
96 |
97 |
98 |
99 |
100 | # test_data = LibriPhrase_Test_Dataset()
101 | # dataloader = DataLoader(test_data, batch_size=128, collate_fn=collate_fn, num_workers=16, shuffle=False)
102 | # from tqdm import tqdm
103 | # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
104 | # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
105 | # print(dist_tensor)
106 | # break
107 | # # pass
108 | # pass
--------------------------------------------------------------------------------
/mm-kws/dataloaders/wenetphrase_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 |
4 | import os
5 | import pandas as pd
6 | import warnings
7 | warnings.filterwarnings('ignore')
8 | import torchaudio
9 | import pickle
10 | import numpy as np
11 | from torchaudio.compliance.kaldi import fbank
12 | import torch.nn.functional as F
13 | from torch.nn.utils.rnn import pad_sequence
14 | import re
15 |
16 | class WenetPhrase_Test_Dataset(Dataset):
17 | def __init__(
18 | self,
19 | test_dir="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/WenetPhrase2/S",
20 | csv="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/wenetphrase_test.csv",
21 | test_text_embedding="/nvme01/aizq/mmkws/datasets/WenetPhrase_Clips/zh_test_text_embeddings.pickle",
22 | types='easy'
23 | ):
24 | self.data = pd.read_csv(csv)
25 | if types == 'easy':
26 | self.data = self.data.loc[self.data['type'].isin(['easy', 'pos'])]
27 | elif types == 'hard':
28 | self.data = self.data.loc[self.data['type'].isin(['hard', 'pos'])]
29 | self.data = self.data.values.tolist()
30 | print(len(self.data))
31 | with open(test_text_embedding, 'rb') as pickle_file: self.text_embedder = pickle.load(pickle_file)
32 | self.test_dir = test_dir
33 |
34 | def __getitem__(
35 | self,
36 | index
37 | ):
38 | # Query_wav_fbank, phoneme, g2p_embed. lm_embed, audiolm_embed, label
39 | Query_text, Query_wav, Support_text, Support_wav, label, _ = self.data[index]
40 | # print(Query_text, Query_wav, Support_text, Support_wav, label)
41 | Query_wav, _ = torchaudio.load(os.path.join(self.test_dir, Query_wav)) # waveform -> fbank
42 | phoneme = self.text_embedder[Support_text]['phoneme']
43 | g2p_embed = self.text_embedder[Support_text]['g2p_embed']
44 | lm_embed = self.text_embedder[Support_text]['lm_embed']
45 | audiolm_embed = np.load(os.path.join(self.test_dir, Support_wav)[:-4] + '_18.npy')
46 | fbank_feature = fbank(
47 | Query_wav,
48 | num_mel_bins=80
49 | )
50 | g2p_embed = torch.from_numpy(g2p_embed)
51 | g2p_embed = g2p_embed.type_as(fbank_feature)
52 | return fbank_feature, g2p_embed, torch.from_numpy(lm_embed).squeeze(0), torch.from_numpy(audiolm_embed).squeeze(0), torch.tensor(label)
53 |
54 |
55 | def __len__(
56 | self
57 | ):
58 | return len(self.data)
59 |
60 |
61 | def collate_fn(batch):
62 |     # Group the fields of each sample in the batch
63 |     fbank_feature, g2p_embed, lm_embed, audiolm_embed, label = zip(*batch)
64 |     # Pad each variable-length feature to the batch maximum
65 |     padded_fbank_feature = pad_sequence(fbank_feature, batch_first=True)
66 |     lengths = [len(seq) for seq in fbank_feature]  # true (unpadded) lengths, not the padded size
67 |     padded_g2p_embed = pad_sequence(g2p_embed, batch_first=True)
68 |     padded_lm_embed = pad_sequence(lm_embed, batch_first=True).squeeze(0)
69 |     padded_audiolm_embed = pad_sequence(audiolm_embed, batch_first=True).squeeze(0)
70 |     # Convert the labels to a Tensor
71 |     label_tensor = torch.tensor(label)
72 | return padded_fbank_feature, torch.tensor(lengths), padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor
73 |
74 |
75 |
76 |
77 |
78 | # test_data = WenetPhrase_Test_Dataset()
79 | # print(test_data[0])
80 | # # dataloader = DataLoader(test_data, batch_size=2, collate_fn=collate_fn, num_workers=1, shuffle=False)
81 | # # from tqdm import tqdm
82 | # # for i, data in tqdm(enumerate(dataloader), total=len(dataloader)):
83 | # # padded_fbank_feature, lengths, padded_g2p_embed, padded_lm_embed, padded_audiolm_embed, label_tensor, dist_tensor = data
84 | # # print(dist_tensor)
85 | # # break
86 | # # # pass
87 | # # pass
--------------------------------------------------------------------------------
/mm-kws/g2p/g2p_en/__init__.py:
--------------------------------------------------------------------------------
1 | from .g2p import G2p
2 |
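Assuming this vendored copy keeps the upstream g2p_en calling convention (it ships its own retrained checkpoint20.npz, so exact outputs may differ), usage would look roughly like the sketch below; the import path depends on where the package sits relative to the working directory.

# Hypothetical usage; the import path and the exact phoneme strings are assumptions.
from g2p.g2p_en import G2p   # or `from g2p_en import G2p` if installed as a package

g2p = G2p()
print(g2p("hello world"))    # e.g. ['HH', 'AH0', 'L', 'OW1', ' ', 'W', 'ER1', 'L', 'D']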
--------------------------------------------------------------------------------
/mm-kws/g2p/g2p_en/checkpoint20.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/g2p/g2p_en/checkpoint20.npz
--------------------------------------------------------------------------------
/mm-kws/g2p/g2p_en/expand.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python2
3 | '''
4 | Borrowed
5 | from https://github.com/keithito/tacotron/blob/master/text/numbers.py
6 | By kyubyong park. kbpark.linguist@gmail.com.
7 | https://www.github.com/kyubyong/g2p
8 | '''
9 | from __future__ import print_function
10 | import inflect
11 | import re
12 |
13 |
14 |
15 | _inflect = inflect.engine()
16 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
17 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
18 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
19 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
20 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
21 | _number_re = re.compile(r'[0-9]+')
22 |
23 |
24 | def _remove_commas(m):
25 | return m.group(1).replace(',', '')
26 |
27 |
28 | def _expand_decimal_point(m):
29 | return m.group(1).replace('.', ' point ')
30 |
31 |
32 | def _expand_dollars(m):
33 | match = m.group(1)
34 | parts = match.split('.')
35 | if len(parts) > 2:
36 | return match + ' dollars' # Unexpected format
37 | dollars = int(parts[0]) if parts[0] else 0
38 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
39 | if dollars and cents:
40 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
41 | cent_unit = 'cent' if cents == 1 else 'cents'
42 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
43 | elif dollars:
44 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
45 | return '%s %s' % (dollars, dollar_unit)
46 | elif cents:
47 | cent_unit = 'cent' if cents == 1 else 'cents'
48 | return '%s %s' % (cents, cent_unit)
49 | else:
50 | return 'zero dollars'
51 |
52 |
53 | def _expand_ordinal(m):
54 | return _inflect.number_to_words(m.group(0))
55 |
56 |
57 | def _expand_number(m):
58 | num = int(m.group(0))
59 | if num > 1000 and num < 3000:
60 | if num == 2000:
61 | return 'two thousand'
62 | elif num > 2000 and num < 2010:
63 | return 'two thousand ' + _inflect.number_to_words(num % 100)
64 | elif num % 100 == 0:
65 | return _inflect.number_to_words(num // 100) + ' hundred'
66 | else:
67 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
68 | else:
69 | return _inflect.number_to_words(num, andword='')
70 |
71 |
72 | def normalize_numbers(text):
73 | text = re.sub(_comma_number_re, _remove_commas, text)
74 | text = re.sub(_pounds_re, r'\1 pounds', text)
75 | text = re.sub(_dollars_re, _expand_dollars, text)
76 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
77 | text = re.sub(_ordinal_re, _expand_ordinal, text)
78 | text = re.sub(_number_re, _expand_number, text)
79 | return text
80 |
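A quick illustration of what normalize_numbers does to a mixed string; the exact wording comes from the inflect package, so the output shown is approximate.

# Example of the number-normalization pass; wording depends on the installed inflect version.
print(normalize_numbers("I paid $3.50 for 2 tickets on the 3rd of May 1999."))
# -> "I paid three dollars, fifty cents for two tickets on the third of May nineteen ninety-nine."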
--------------------------------------------------------------------------------
/mm-kws/libriphrase_hardneg.json.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/libriphrase_hardneg.json.zip
--------------------------------------------------------------------------------
/mm-kws/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
4 | import math
5 |
6 | from conformer.conformer.model_def import Conformer
7 |
8 | class PositionalEmbedding(nn.Module):
9 | def __init__(self, d_model=512, max_len=512):
10 | super().__init__()
11 | # Compute the positional encodings once in log space.
12 | pe = torch.zeros(max_len, d_model).float()
13 | pe.requires_grad = False
14 | position = torch.arange(0, max_len).float().unsqueeze(1)
15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0)
19 | self.register_buffer('pe', pe)
20 |
21 | def forward(self, x):
22 | return self.pe[:, :x.size(1)]
23 |
24 |
25 | class TEXT_Fusion_transformer_encoder(nn.Module):
26 | def __init__(
27 | self,
28 | d_model,
29 | nlayers,
30 | nhead,
31 | dim_feedforward,
32 | dropout=0.1
33 | ):
34 | super().__init__()
35 | self.position_audio = PositionalEmbedding(d_model=128)
36 | self.position_text_g2p = PositionalEmbedding(d_model=128)
37 | self.position_text_lm = PositionalEmbedding(d_model=128)
38 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm
39 | self.dropout = nn.Dropout(p=dropout)
40 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
41 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
42 |
43 | def forward(self,audio_embedding, g2p_embedding, lm_embedding):
44 | position_audio_encoding = self.position_audio(audio_embedding)
45 | position_g2p_encoding = self.position_text_g2p(g2p_embedding)
46 | position_lm_encoding = self.position_text_lm(lm_embedding)
47 |
48 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
49 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device))
50 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device))
51 |
52 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
53 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p
54 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm
55 |
56 | #(3) concat tokens
57 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1)
58 | input_tokens = self.dropout(input_tokens)
59 |
60 | output = self.transformer_encoder(input_tokens)
61 | return output
62 |
63 |
64 | class Audio_Fusion_transformer_encoder(nn.Module):
65 | def __init__(
66 | self,
67 | d_model,
68 | nlayers,
69 | nhead,
70 | dim_feedforward,
71 | dropout=0.1
72 | ):
73 | super().__init__()
74 | self.position_audio = PositionalEmbedding(d_model=128)
75 | self.position_audio_lm = PositionalEmbedding(d_model=128)
76 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 for audiolm
77 | self.dropout = nn.Dropout(p=dropout)
78 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
79 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
80 |
81 | def forward(self,audio_embedding, audiolm_embedding):
82 | position_audio_encoding = self.position_audio(audio_embedding)
83 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding)
84 |
85 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
86 | modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device))
87 |
88 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
89 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm
90 |
91 | #(3) concat tokens
92 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1)
93 | input_tokens = self.dropout(input_tokens)
94 |
95 | output = self.transformer_encoder(input_tokens)
96 | return output
97 |
98 |
99 | class GRUFCModel(nn.Module):
100 | def __init__(self, input_dim, hidden_dim, output_dim):
101 | super(GRUFCModel, self).__init__()
102 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
103 | self.fc = nn.Linear(hidden_dim, output_dim)
104 |
105 | def forward(self, x):
106 | gru_out, _ = self.gru(x)
107 | gru_last_output = gru_out[:, -1, :]
108 | fc_out = self.fc(gru_last_output)
109 | return fc_out
110 |
111 |
112 | import torch.nn as nn
113 | class Projection(nn.Module):
114 | def __init__(self, input_dim, output_dim):
115 | super(Projection, self).__init__()
116 | layers = []
117 | layers.append(nn.LayerNorm(input_dim))
118 | layers.append(nn.Linear(input_dim, output_dim))
119 | layers.append(nn.SiLU())
120 | self.projection_block = nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | return self.projection_block(x)
124 |
125 |
126 | class MMKWS(nn.Module):
127 | def __init__(self):
128 | super().__init__()
129 | self.audioencoder = Conformer(
130 | input_dim= 80,
131 | encoder_dim= 128,
132 | num_encoder_layers= 6,
133 | num_attention_heads= 4,
134 | )
135 | self.g2p_projection = Projection(input_dim=256, output_dim=128)
136 | self.lm_projection = Projection(input_dim=768, output_dim=128)
137 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128)
138 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
139 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
140 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
141 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
142 | self.fc = nn.Linear(64, 1)
143 | self.phoneme_fc = nn.Linear(128, 1)
144 | self.text_fc = nn.Linear(128, 1)
145 |
146 |
147 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed):
148 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0]
149 | g2p_embedding = self.g2p_projection(g2p_embed)
150 | lm_embedding = self.lm_projection(lm_embed)
151 | audiolm_embedding = self.audiolm_projection(audiolm_embed)
152 |
153 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding)
154 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding)
155 | fusion = self.gru1(fusion_text)+self.gru2(fusion_audio)
156 | fusion_pred = self.fc(fusion)
157 |
158 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :])
159 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :])
160 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred
161 |
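As a rough smoke test of the shapes this model expects (a sketch only: the embedding widths follow the projection layers defined above, the sequence lengths are arbitrary, and the Conformer is the repo's local copy, so its output length after subsampling will differ from the input frame count):

# Toy forward pass with random tensors, assuming the local conformer package is importable.
import torch

model = MMKWS()
fbank   = torch.randn(2, 200, 80)      # [batch, frames, mel bins]
lengths = torch.tensor([200, 160])     # unpadded frame counts per utterance
g2p     = torch.randn(2, 12, 256)      # phoneme (G2P) embeddings of the support text
lm      = torch.randn(2, 8, 768)       # text LM embeddings of the support text
audiolm = torch.randn(2, 8, 1024)      # audio LM embeddings of the support audio
utt, phn, txt = model(fbank, lengths, g2p, lm, audiolm)
# utt: [2, 1] utterance-level logit; phn: [2, 12, 1]; txt: [2, 8, 1]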
--------------------------------------------------------------------------------
/mm-kws/test.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pytorch_lightning.utilities.types import STEP_OUTPUT
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from pytorch_lightning import LightningModule, Trainer
6 | import pytorch_lightning as pl
7 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn
8 | import torch.nn as nn
9 | from models import MMKWS
10 | from sklearn.metrics import roc_auc_score
11 | from sklearn.metrics import roc_curve
12 | import sklearn
13 | import numpy as np
14 | import torch.nn as nn
15 | import torch
16 |
17 | def compute_eer(label, pred):
18 | fpr, tpr, threshold = sklearn.metrics.roc_curve(label, pred)
19 | fnr = 1 - tpr
20 |
21 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
22 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
23 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
24 |
25 | eer = (eer_1 + eer_2) / 2
26 | return eer
27 |
28 | class EER(nn.Module):
29 | def __init__(self):
30 | super(EER, self).__init__()
31 | self.score = 0.0
32 | self.count = 0.0
33 |
34 | def forward(self, y_true, y_pred):
35 | label_np = y_true.flatten() # Convert to numpy array
36 | pred_np = y_pred.flatten() # Convert to numpy array
37 |
38 | eer_value = compute_eer(label_np, pred_np)
39 |
40 | self.score += eer_value
41 | self.count += 1
42 |
43 | return torch.tensor(self.score / self.count)
44 |
45 | # 1. Define the LightningModule wrapper
46 | class MMKWS_Wrapper(LightningModule):
47 | def __init__(self):
48 | super().__init__()
49 | self.model = MMKWS()
50 | self.criterion = nn.BCEWithLogitsLoss()
51 | self.test_preds, self.test_labels = [], []
52 |
53 |
54 | def test_step(self, batch, batch_idx):
55 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch
56 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
57 | preds = torch.sigmoid(preds)
58 | preds = preds.squeeze(dim=1)
59 | self.test_preds.append(preds)
60 | self.test_labels.append(label)
61 |
62 |
63 | def on_test_epoch_end(self):
64 | eer_loss = EER()
65 | all_preds = torch.cat(self.test_preds)
66 | all_labels = torch.cat(self.test_labels)
67 | y_true = all_labels.cpu().detach().numpy()
68 | y_scores = all_preds.cpu().detach().numpy()
69 | auc = roc_auc_score(y_true, y_scores)
70 | eer = eer_loss(y_true, y_scores)
71 | self.log('test_auc', auc)
72 | self.log('test_eer', eer)
73 |
74 | def configure_optimizers(self):
75 | optim = torch.optim.Adam(self.model.parameters(), lr=5e-4)
76 | return optim
77 |
78 |
79 | # 3. Set up the Trainer and run evaluation
80 | if __name__ == "__main__":
81 | import os
82 | os.environ['CUDA_VISIBLE_DEVICES'] = '5'
83 | test_dataset = LibriPhrase_Test_Dataset(types='easy')
84 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False)
85 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt")
86 | model.eval()
87 | trainer = Trainer(devices=1, accelerator='gpu') # trainer settings
88 | trainer.test(model, test_dataloader)
89 | pl.seed_everything(1234)
90 | test_dataset = LibriPhrase_Test_Dataset(types='hard')
91 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False)
92 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt")
93 | trainer = Trainer(devices=1, accelerator='gpu') # trainer settings
94 | model.eval()
95 | trainer.test(model, test_dataloader)
96 |
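A tiny sanity check of the compute_eer / AUC pair on hand-made scores, not model output: when positives all score above negatives, AUC should be 1.0 and the EER 0.0.

# Toy scores with perfect separation between the two classes.
import numpy as np
from sklearn.metrics import roc_auc_score

labels = np.array([0, 0, 1, 1])
scores = np.array([0.1, 0.2, 0.8, 0.9])
print(roc_auc_score(labels, scores))  # 1.0
print(compute_eer(labels, scores))    # 0.0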
--------------------------------------------------------------------------------
/mm-kws/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.optim.lr_scheduler import ExponentialLR
6 |
7 | import numpy as np
8 | from sklearn.metrics import roc_auc_score
9 | from sklearn.metrics import roc_curve
10 |
11 | import pytorch_lightning as pl
12 | from pytorch_lightning import LightningModule, Trainer
13 | from pytorch_lightning.callbacks import ModelCheckpoint
14 | from dataloaders.libriphrase_train import LibriPhrase_Train_Dataset, train_collate_fn
15 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn
16 | from models import MMKWS
17 |
18 | def compute_eer(label, pred):
19 | fpr, tpr, threshold = roc_curve(label, pred)
20 | fnr = 1 - tpr
21 |
22 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
23 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
24 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
25 |
26 | eer = (eer_1 + eer_2) / 2
27 | return eer
28 |
29 | class EER(nn.Module):
30 | def __init__(self):
31 | super(EER, self).__init__()
32 | self.score = 0.0
33 | self.count = 0.0
34 |
35 | def forward(self, y_true, y_pred):
36 | label_np = y_true.flatten()  # flatten labels to 1-D
37 | pred_np = y_pred.flatten()  # flatten predictions to 1-D
38 |
39 | eer_value = compute_eer(label_np, pred_np)
40 |
41 | self.score += eer_value
42 | self.count += 1
43 |
44 | return torch.tensor(self.score / self.count)
45 |
46 |
47 | class MMKWS_Wrapper(LightningModule):
48 | def __init__(self):
49 | super().__init__()
50 | self.model = MMKWS()
51 | self.criterion = nn.BCEWithLogitsLoss()
52 | self.test_preds = []
53 | self.test_labels = []
54 |
55 |
56 | def training_step(self, batch, batch_idx):
57 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label, pl, mask_pl, tl, mask_tl = batch
58 | preds, phoneme_preds, text_preds = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
59 | phoneme_preds = phoneme_preds.squeeze(dim=2)
60 | text_preds = text_preds.squeeze(dim=2)
61 | preds = preds.squeeze(dim=1) # Output is [Batch size, 1], but we want [Batch size]
62 | phoneme_loss = self.sequence_bce_loss(phoneme_preds, pl.float(), mask_pl)
63 | text_loss = self.sequence_bce_loss(text_preds, tl.float(), mask_tl)
64 | utt_loss = self.criterion(preds, label.float())
65 | all_loss = utt_loss + phoneme_loss + text_loss
66 | self.log('train/all_loss', all_loss, on_step=True, prog_bar=True)
67 | self.log('train/utt_loss', utt_loss, on_step=True, prog_bar=True)
68 | self.log('train/text_loss', text_loss, on_step=True, prog_bar=True)
69 | self.log('train/phoneme_loss', phoneme_loss, on_step=True, prog_bar=True)
70 | return all_loss
71 |
72 | def validation_step(self, batch, batch_idx):
73 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch
74 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
75 | preds = torch.sigmoid(preds)
76 | preds = preds.squeeze(dim=1)
77 | self.test_preds.append(preds)
78 | self.test_labels.append(label)
79 |
80 |
81 | def on_validation_epoch_end(self):
82 | if self.current_epoch > 0:
83 | eer_loss = EER()
84 | all_preds = torch.cat(self.test_preds)
85 | all_labels = torch.cat(self.test_labels)
86 | y_true = all_labels.cpu().detach().numpy()
87 | y_scores = all_preds.cpu().detach().numpy()
88 | # compute AUC
89 | auc = roc_auc_score(y_true, y_scores)
90 | eer = eer_loss(y_true, y_scores)
91 | self.log('test/test_auc', auc)
92 | self.log('test/test_eer', eer)
93 | self.test_preds.clear()
94 | self.test_labels.clear()
95 |
96 |
97 | def configure_optimizers(self):
98 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
99 | return {
100 | "optimizer": optimizer,
101 | "lr_scheduler": {
102 | "scheduler": ExponentialLR(optimizer, gamma=0.95),
103 | "frequency": 1,
104 | "interval": 'epoch',
105 | },
106 | }
107 |
108 | def sequence_bce_loss(self, preds, labels, mask):
109 | # flatten predictions, labels and mask into 1-D vectors
110 | preds_flat = preds.view(-1)
111 | labels_flat = labels.view(-1)
112 | mask_flat = mask.view(-1)
113 | # compute the binary cross-entropy only at positions where the mask is 1
114 | valid_indices = torch.where(mask_flat == 1)[0]
115 | valid_preds = preds_flat[valid_indices]
116 | valid_labels = labels_flat[valid_indices]
117 | # use PyTorch's built-in binary cross-entropy with logits
118 | loss = F.binary_cross_entropy_with_logits(valid_preds, valid_labels)
119 | return loss
120 |
121 |
122 | # 3. Set up the Trainer and train
123 | if __name__ == "__main__":
124 | pl.seed_everything(2024)
125 | import os
126 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
127 | train_dataset = LibriPhrase_Train_Dataset()
128 | train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=train_collate_fn, shuffle=True, num_workers=16, drop_last=True)
129 | test_dataset = LibriPhrase_Test_Dataset(types='hard')
130 | test_dataloader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn, shuffle=False, num_workers=8, drop_last=True)
131 | model = MMKWS_Wrapper()
132 | model_checkpoint = ModelCheckpoint(
133 | dirpath="/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts",
134 | filename='epoch{epoch:02d}',
135 | save_top_k=-1,
136 | )
137 | logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/', name='MMKWS+')
138 | trainer = Trainer(devices=4, accelerator='gpu', # strategy='ddp_find_unused_parameters_true',
139 | logger=logger, max_epochs=100, callbacks=[model_checkpoint], accumulate_grad_batches=4, precision='16-mixed') # trainer settings
140 | trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
141 |
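To make the masking in sequence_bce_loss concrete, here is a toy example with hand-made logits, labels and mask (not real data): only the positions where the mask is 1 contribute to the loss.

# Toy tensors illustrating the masked BCE above; the padded position (mask == 0) is excluded.
import torch
import torch.nn.functional as F

logits = torch.tensor([[ 2.0, -1.0,  0.5],
                       [ 0.0,  3.0,  0.0]])   # [batch, sequence] per-token logits
labels = torch.tensor([[ 1.0,  0.0,  1.0],
                       [ 0.0,  1.0,  0.0]])
mask   = torch.tensor([[ 1,    1,    1  ],
                       [ 1,    1,    0  ]])   # last token of the second sample is padding

valid = mask.view(-1) == 1
loss = F.binary_cross_entropy_with_logits(logits.view(-1)[valid], labels.view(-1)[valid])
print(loss)  # mean BCE over the five valid positions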
--------------------------------------------------------------------------------
/mm-kws/wenetphrase_hardneg.json.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/mm-kws/wenetphrase_hardneg.json.zip
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
4 | import math
5 |
6 | from conformer.conformer.model_def import Conformer
7 |
8 | class PositionalEmbedding(nn.Module):
9 | def __init__(self, d_model=512, max_len=512):
10 | super().__init__()
11 | # Compute the positional encodings once in log space.
12 | pe = torch.zeros(max_len, d_model).float()
13 | pe.requires_grad = False
14 | position = torch.arange(0, max_len).float().unsqueeze(1)
15 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
16 | pe[:, 0::2] = torch.sin(position * div_term)
17 | pe[:, 1::2] = torch.cos(position * div_term)
18 | pe = pe.unsqueeze(0)
19 | self.register_buffer('pe', pe)
20 |
21 | def forward(self, x):
22 | return self.pe[:, :x.size(1)]
23 |
24 |
25 | class TEXT_Fusion_transformer_encoder(nn.Module):
26 | def __init__(
27 | self,
28 | d_model,
29 | nlayers,
30 | nhead,
31 | dim_feedforward,
32 | dropout=0.1
33 | ):
34 | super().__init__()
35 | self.position_audio = PositionalEmbedding(d_model=128)
36 | self.position_text_g2p = PositionalEmbedding(d_model=128)
37 | self.position_text_lm = PositionalEmbedding(d_model=128)
38 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm
39 | self.dropout = nn.Dropout(p=dropout)
40 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
41 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
42 |
43 | def forward(self,audio_embedding, g2p_embedding, lm_embedding):
44 | position_audio_encoding = self.position_audio(audio_embedding)
45 | position_g2p_encoding = self.position_text_g2p(g2p_embedding)
46 | position_lm_encoding = self.position_text_lm(lm_embedding)
47 |
48 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
49 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device))
50 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device))
51 |
52 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
53 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p
54 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm
55 |
56 | #(3) concat tokens
57 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1)
58 | input_tokens = self.dropout(input_tokens)
59 |
60 | output = self.transformer_encoder(input_tokens)
61 | return output
62 |
63 |
64 | class Audio_Fusion_transformer_encoder(nn.Module):
65 | def __init__(
66 | self,
67 | d_model,
68 | nlayers,
69 | nhead,
70 | dim_feedforward,
71 | dropout=0.1
72 | ):
73 | super().__init__()
74 | self.position_audio = PositionalEmbedding(d_model=128)
75 | self.position_audio_lm = PositionalEmbedding(d_model=128)
76 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 for audiolm
77 | self.dropout = nn.Dropout(p=dropout)
78 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
79 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
80 |
81 | def forward(self,audio_embedding, audiolm_embedding):
82 | position_audio_encoding = self.position_audio(audio_embedding)
83 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding)
84 |
85 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
86 | modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device))
87 |
88 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
89 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm
90 |
91 | #(3) concat tokens
92 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1)
93 | input_tokens = self.dropout(input_tokens)
94 |
95 | output = self.transformer_encoder(input_tokens)
96 | return output
97 |
98 |
99 | class GRUFCModel(nn.Module):
100 | def __init__(self, input_dim, hidden_dim, output_dim):
101 | super(GRUFCModel, self).__init__()
102 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
103 | self.fc = nn.Linear(hidden_dim, output_dim)
104 |
105 | def forward(self, x):
106 | gru_out, _ = self.gru(x)
107 | gru_last_output = gru_out[:, -1, :]
108 | fc_out = self.fc(gru_last_output)
109 | return fc_out
110 |
111 |
112 | import torch.nn as nn
113 | class Projection(nn.Module):
114 | def __init__(self, input_dim, output_dim):
115 | super(Projection, self).__init__()
116 | layers = []
117 | layers.append(nn.LayerNorm(input_dim))
118 | layers.append(nn.Linear(input_dim, output_dim))
119 | layers.append(nn.SiLU())
120 | self.projection_block = nn.Sequential(*layers)
121 |
122 | def forward(self, x):
123 | return self.projection_block(x)
124 |
125 |
126 | class MMKWS(nn.Module):
127 | def __init__(self):
128 | super().__init__()
129 | self.audioencoder = Conformer(
130 | input_dim= 80,
131 | encoder_dim= 128,
132 | num_encoder_layers= 6,
133 | num_attention_heads= 4,
134 | )
135 | self.g2p_projection = Projection(input_dim=256, output_dim=128)
136 | self.lm_projection = Projection(input_dim=768, output_dim=128)
137 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128)
138 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
139 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
140 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
141 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
142 | self.fc = nn.Linear(64, 1)
143 | self.phoneme_fc = nn.Linear(128, 1)
144 | self.text_fc = nn.Linear(128, 1)
145 |
146 |
147 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed):
148 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0]
149 | g2p_embedding = self.g2p_projection(g2p_embed)
150 | lm_embedding = self.lm_projection(lm_embed)
151 | audiolm_embedding = self.audiolm_projection(audiolm_embed)
152 |
153 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding)
154 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding)
155 | fusion = self.gru1(fusion_text)+self.gru2(fusion_audio)
156 | fusion_pred = self.fc(fusion)
157 |
158 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :])
159 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :])
160 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred
161 |
--------------------------------------------------------------------------------
/models_tiny.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
4 | import math
5 |
6 | from conformer.conformer.model_def import Conformer
7 | from mdtc import MDTC_KWS
8 |
9 | class PositionalEmbedding(nn.Module):
10 | def __init__(self, d_model=512, max_len=512):
11 | super().__init__()
12 | # Compute the positional encodings once in log space.
13 | pe = torch.zeros(max_len, d_model).float()
14 | pe.requires_grad = False
15 | position = torch.arange(0, max_len).float().unsqueeze(1)
16 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
17 | pe[:, 0::2] = torch.sin(position * div_term)
18 | pe[:, 1::2] = torch.cos(position * div_term)
19 | pe = pe.unsqueeze(0)
20 | self.register_buffer('pe', pe)
21 |
22 | def forward(self, x):
23 | return self.pe[:, :x.size(1)]
24 |
25 |
26 | class TEXT_Fusion_transformer_encoder(nn.Module):
27 | def __init__(
28 | self,
29 | d_model,
30 | nlayers,
31 | nhead,
32 | dim_feedforward,
33 | dropout=0.1
34 | ):
35 | super().__init__()
36 | self.position_audio = PositionalEmbedding(d_model=128)
37 | self.position_text_g2p = PositionalEmbedding(d_model=128)
38 | self.position_text_lm = PositionalEmbedding(d_model=128)
39 | self.modality = nn.Embedding(4, 128, padding_idx=0) # 1 for audio, 2 for g2p, 3 for text lm
40 | self.dropout = nn.Dropout(p=dropout)
41 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
42 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
43 |
44 | def forward(self,audio_embedding, g2p_embedding, lm_embedding):
45 | position_audio_encoding = self.position_audio(audio_embedding)
46 | position_g2p_encoding = self.position_text_g2p(g2p_embedding)
47 | position_lm_encoding = self.position_text_lm(lm_embedding)
48 |
49 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
50 | modality_g2p = self.modality(2 * torch.ones((position_g2p_encoding.size(0), g2p_embedding.shape[1]), dtype=torch.int).to(g2p_embedding.device))
51 | modality_lm = self.modality(3 * torch.ones((position_lm_encoding.size(0), lm_embedding.shape[1]), dtype=torch.int).to(lm_embedding.device))
52 |
53 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
54 | g2p_tokens = g2p_embedding + position_g2p_encoding + modality_g2p
55 | lm_tokens = lm_embedding + position_lm_encoding + modality_lm
56 |
57 | #(3) concat tokens
58 | input_tokens = torch.cat((audio_tokens, g2p_tokens, lm_tokens), dim=1)
59 | input_tokens = self.dropout(input_tokens)
60 |
61 | output = self.transformer_encoder(input_tokens)
62 | return output
63 |
64 |
65 | class Audio_Fusion_transformer_encoder(nn.Module):
66 | def __init__(
67 | self,
68 | d_model,
69 | nlayers,
70 | nhead,
71 | dim_feedforward,
72 | dropout=0.1
73 | ):
74 | super().__init__()
75 | self.position_audio = PositionalEmbedding(d_model=128)
76 | self.position_audio_lm = PositionalEmbedding(d_model=128)
77 | self.modality = nn.Embedding(3, 128, padding_idx=0) # 1 for audio, 2 for audiolm
78 | self.dropout = nn.Dropout(p=dropout)
79 | encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
80 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
81 |
82 | def forward(self,audio_embedding, audiolm_embedding):
83 | position_audio_encoding = self.position_audio(audio_embedding)
84 | position_audiolm_encoding = self.position_audio_lm(audiolm_embedding)
85 |
86 | modality_audio = self.modality(1 * torch.ones((position_audio_encoding.size(0), audio_embedding.shape[1]), dtype=torch.int).to(audio_embedding.device))
87 | modality_audiolm = self.modality(2 * torch.ones((position_audiolm_encoding.size(0), audiolm_embedding.shape[1]), dtype=torch.int).to(audiolm_embedding.device))
88 |
89 | audio_tokens = audio_embedding + position_audio_encoding + modality_audio
90 | audiolm_tokens = audiolm_embedding + position_audiolm_encoding + modality_audiolm
91 |
92 | #(3) concat tokens
93 | input_tokens = torch.cat((audio_tokens, audiolm_tokens), dim=1)
94 | input_tokens = self.dropout(input_tokens)
95 |
96 | output = self.transformer_encoder(input_tokens)
97 | return output
98 |
99 |
100 | class GRUFCModel(nn.Module):
101 | def __init__(self, input_dim, hidden_dim, output_dim):
102 | super(GRUFCModel, self).__init__()
103 | self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
104 | self.fc = nn.Linear(hidden_dim, output_dim)
105 |
106 | def forward(self, x):
107 | gru_out, _ = self.gru(x)
108 | gru_last_output = gru_out[:, -1, :]
109 | fc_out = self.fc(gru_last_output)
110 | return fc_out
111 |
112 |
113 | import torch.nn as nn
114 | class Projection(nn.Module):
115 | def __init__(self, input_dim, output_dim):
116 | super(Projection, self).__init__()
117 | layers = []
118 | layers.append(nn.LayerNorm(input_dim))
119 | layers.append(nn.Linear(input_dim, output_dim))
120 | layers.append(nn.SiLU())
121 | self.projection_block = nn.Sequential(*layers)
122 |
123 | def forward(self, x):
124 | return self.projection_block(x)
125 |
126 |
127 | class MMKWS(nn.Module):
128 | def __init__(self):
129 | super().__init__()
130 | self.audioencoder = MDTC_KWS()
131 | # self.audioencoder = Conformer(
132 | # input_dim= 80,
133 | # encoder_dim= 128,
134 | # num_encoder_layers= 6,
135 | # num_attention_heads= 4,
136 | # )
137 | self.g2p_projection = Projection(input_dim=256, output_dim=128)
138 | self.lm_projection = Projection(input_dim=768, output_dim=128)
139 | self.audiolm_projection = Projection(input_dim=1024, output_dim=128)
140 | self.text_fusion_transformer = TEXT_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
141 | self.audio_fusion_transformer = Audio_Fusion_transformer_encoder(d_model=128,nlayers=2,nhead=4,dim_feedforward=512,dropout=0.1)
142 | self.gru1 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
143 | self.gru2 = GRUFCModel(input_dim=128, hidden_dim=128, output_dim=64)
144 | self.fc = nn.Linear(64, 1)
145 | self.phoneme_fc = nn.Linear(128, 1)
146 | self.text_fc = nn.Linear(128, 1)
147 |
148 |
149 | def forward(self, fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed):
150 | audio_embedding = self.audioencoder(fbank_feature, lengths)[0]
151 | g2p_embedding = self.g2p_projection(g2p_embed)
152 | lm_embedding = self.lm_projection(lm_embed)
153 | audiolm_embedding = self.audiolm_projection(audiolm_embed)
154 |
155 | fusion_text = self.text_fusion_transformer(audio_embedding, g2p_embedding, lm_embedding)
156 | fusion_audio = self.audio_fusion_transformer(audio_embedding, audiolm_embedding)
157 | fusion = self.gru1(fusion_text)+self.gru2(fusion_audio)
158 | fusion_pred = self.fc(fusion)
159 |
160 | fusion_phoneme_pred = self.phoneme_fc(fusion_text[:, audio_embedding.shape[1]:(audio_embedding.shape[1]+g2p_embedding.shape[1]), :])
161 | fusion_text_pred = self.text_fc(fusion_text[:, (audio_embedding.shape[1]+g2p_embedding.shape[1]):, :])
162 | return fusion_pred, fusion_phoneme_pred, fusion_text_pred
163 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from pytorch_lightning.utilities.types import STEP_OUTPUT
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from pytorch_lightning import LightningModule, Trainer
6 | import pytorch_lightning as pl
7 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn
8 | import torch.nn as nn
9 | from models import MMKWS
10 | from sklearn.metrics import roc_auc_score
11 | from sklearn.metrics import roc_curve
12 | import sklearn
13 | import numpy as np
14 | import torch.nn as nn
15 | import torch
16 |
17 | def compute_eer(label, pred):
18 | fpr, tpr, threshold = sklearn.metrics.roc_curve(label, pred)
19 | fnr = 1 - tpr
20 |
21 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
22 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
23 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
24 |
25 | eer = (eer_1 + eer_2) / 2
26 | return eer
27 |
28 | class EER(nn.Module):
29 | def __init__(self):
30 | super(EER, self).__init__()
31 | self.score = 0.0
32 | self.count = 0.0
33 |
34 | def forward(self, y_true, y_pred):
35 | label_np = y_true.flatten()  # flatten labels to 1-D
36 | pred_np = y_pred.flatten()  # flatten predictions to 1-D
37 |
38 | eer_value = compute_eer(label_np, pred_np)
39 |
40 | self.score += eer_value
41 | self.count += 1
42 |
43 | return torch.tensor(self.score / self.count)
44 |
45 | # 1. Define the LightningModule wrapper
46 | class MMKWS_Wrapper(LightningModule):
47 | def __init__(self):
48 | super().__init__()
49 | self.model = MMKWS()
50 | self.criterion = nn.BCEWithLogitsLoss()
51 | self.test_preds, self.test_labels = [], []
52 |
53 |
54 | def test_step(self, batch, batch_idx):
55 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch
56 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
57 | preds = torch.sigmoid(preds)
58 | preds = preds.squeeze(dim=1)
59 | self.test_preds.append(preds)
60 | self.test_labels.append(label)
61 |
62 |
63 | def on_test_epoch_end(self):
64 | eer_loss = EER()
65 | all_preds = torch.cat(self.test_preds)
66 | all_labels = torch.cat(self.test_labels)
67 | y_true = all_labels.cpu().detach().numpy()
68 | y_scores = all_preds.cpu().detach().numpy()
69 | auc = roc_auc_score(y_true, y_scores)
70 | eer = eer_loss(y_true, y_scores)
71 | self.log('test_auc', auc)
72 | self.log('test_eer', eer)
73 |
74 | def configure_optimizers(self):
75 | optim = torch.optim.Adam(self.model.parameters(), lr=5e-4)
76 | return optim
77 |
78 |
79 | # 3. Set up the Trainer and run evaluation
80 | if __name__ == "__main__":
81 | import os
82 | os.environ['CUDA_VISIBLE_DEVICES'] = '5'
83 | test_dataset = LibriPhrase_Test_Dataset(types='easy')
84 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False)
85 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt")
86 | model.eval()
87 | trainer = Trainer(devices=1, accelerator='gpu') # trainer settings
88 | trainer.test(model, test_dataloader)
89 | pl.seed_everything(1234)
90 | test_dataset = LibriPhrase_Test_Dataset(types='hard')
91 | test_dataloader = DataLoader(test_dataset, batch_size=1024, collate_fn=collate_fn, shuffle=False, num_workers=24, drop_last=False)
92 | model = MMKWS_Wrapper.load_from_checkpoint("/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts/epochepoch=19.ckpt")
93 | trainer = Trainer(devices=1, accelerator='gpu') # trainer settings
94 | model.eval()
95 | trainer.test(model, test_dataloader)
96 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.optim.lr_scheduler import ExponentialLR
6 |
7 | import numpy as np
8 | from sklearn.metrics import roc_auc_score
9 | from sklearn.metrics import roc_curve
10 |
11 | import pytorch_lightning as pl
12 | from pytorch_lightning import LightningModule, Trainer
13 | from pytorch_lightning.callbacks import ModelCheckpoint
14 | from dataloaders.libriphrase_train import LibriPhrase_Train_Dataset, train_collate_fn
15 | from dataloaders.libriphrase_test import LibriPhrase_Test_Dataset, collate_fn
16 | from models import MMKWS
17 |
18 | def compute_eer(label, pred):
19 | fpr, tpr, threshold = roc_curve(label, pred)
20 | fnr = 1 - tpr
21 |
22 | eer_threshold = threshold[np.nanargmin(np.absolute((fnr - fpr)))]
23 | eer_1 = fpr[np.nanargmin(np.absolute((fnr - fpr)))]
24 | eer_2 = fnr[np.nanargmin(np.absolute((fnr - fpr)))]
25 |
26 | eer = (eer_1 + eer_2) / 2
27 | return eer
28 |
29 | class EER(nn.Module):
30 | def __init__(self):
31 | super(EER, self).__init__()
32 | self.score = 0.0
33 | self.count = 0.0
34 |
35 | def forward(self, y_true, y_pred):
36 | label_np = y_true.flatten()  # flatten labels to 1-D
37 | pred_np = y_pred.flatten()  # flatten predictions to 1-D
38 |
39 | eer_value = compute_eer(label_np, pred_np)
40 |
41 | self.score += eer_value
42 | self.count += 1
43 |
44 | return torch.tensor(self.score / self.count)
45 |
46 |
47 | class MMKWS_Wrapper(LightningModule):
48 | def __init__(self):
49 | super().__init__()
50 | self.model = MMKWS()
51 | self.criterion = nn.BCEWithLogitsLoss()
52 | self.test_preds = []
53 | self.test_labels = []
54 |
55 |
56 | def training_step(self, batch, batch_idx):
57 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label, pl, mask_pl, tl, mask_tl = batch
58 | preds, phoneme_preds, text_preds = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
59 | phoneme_preds = phoneme_preds.squeeze(dim=2)
60 | text_preds = text_preds.squeeze(dim=2)
61 | preds = preds.squeeze(dim=1) # Output is [Batch size, 1], but we want [Batch size]
62 | phoneme_loss = self.sequence_bce_loss(phoneme_preds, pl.float(), mask_pl)
63 | text_loss = self.sequence_bce_loss(text_preds, tl.float(), mask_tl)
64 | utt_loss = self.criterion(preds, label.float())
65 | all_loss = utt_loss + phoneme_loss + text_loss
66 | self.log('train/all_loss', all_loss, on_step=True, prog_bar=True)
67 | self.log('train/utt_loss', utt_loss, on_step=True, prog_bar=True)
68 | self.log('train/text_loss', text_loss, on_step=True, prog_bar=True)
69 | self.log('train/phoneme_loss', phoneme_loss, on_step=True, prog_bar=True)
70 | return all_loss
71 |
72 | def validation_step(self, batch, batch_idx):
73 | fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed, label = batch
74 | preds, _, _ = self.model(fbank_feature, lengths, g2p_embed, lm_embed, audiolm_embed)
75 | preds = torch.sigmoid(preds)
76 | preds = preds.squeeze(dim=1)
77 | self.test_preds.append(preds)
78 | self.test_labels.append(label)
79 |
80 |
81 | def on_validation_epoch_end(self):
82 | if self.current_epoch > 0:
83 | eer_loss = EER()
84 | all_preds = torch.cat(self.test_preds)
85 | all_labels = torch.cat(self.test_labels)
86 | y_true = all_labels.cpu().detach().numpy()
87 | y_scores = all_preds.cpu().detach().numpy()
88 | # compute AUC
89 | auc = roc_auc_score(y_true, y_scores)
90 | eer = eer_loss(y_true, y_scores)
91 | self.log('test/test_auc', auc)
92 | self.log('test/test_eer', eer)
93 | self.test_preds.clear()
94 | self.test_labels.clear()
95 |
96 |
97 | def configure_optimizers(self):
98 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
99 | return {
100 | "optimizer": optimizer,
101 | "lr_scheduler": {
102 | "scheduler": ExponentialLR(optimizer, gamma=0.95),
103 | "frequency": 1,
104 | "interval": 'epoch',
105 | },
106 | }
107 |
108 | def sequence_bce_loss(self, preds, labels, mask):
109 | # flatten predictions, labels and mask into 1-D vectors
110 | preds_flat = preds.view(-1)
111 | labels_flat = labels.view(-1)
112 | mask_flat = mask.view(-1)
113 | # compute the binary cross-entropy only at positions where the mask is 1
114 | valid_indices = torch.where(mask_flat == 1)[0]
115 | valid_preds = preds_flat[valid_indices]
116 | valid_labels = labels_flat[valid_indices]
117 | # use PyTorch's built-in binary cross-entropy with logits
118 | loss = F.binary_cross_entropy_with_logits(valid_preds, valid_labels)
119 | return loss
120 |
121 |
122 | # 3. Set up the Trainer and train
123 | if __name__ == "__main__":
124 | pl.seed_everything(2024)
125 | import os
126 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
127 | train_dataset = LibriPhrase_Train_Dataset()
128 | train_dataloader = DataLoader(train_dataset, batch_size=4, collate_fn=train_collate_fn, shuffle=True, num_workers=16, drop_last=True)
129 | test_dataset = LibriPhrase_Test_Dataset(types='hard')
130 | test_dataloader = DataLoader(test_dataset, batch_size=256, collate_fn=collate_fn, shuffle=False, num_workers=8, drop_last=True)
131 | model = MMKWS_Wrapper()
132 | model_checkpoint = ModelCheckpoint(
133 | dirpath="/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/MMKWS+/ckpts",
134 | filename='epoch{epoch:02d}',
135 | save_top_k=-1,
136 | )
137 | logger = pl.loggers.TensorBoardLogger('/nvme01/aizq/mmkws/mmkws_submits/MMKWS_EN_Base+/logs/', name='MMKWS+')
138 | trainer = Trainer(devices=4, accelerator='gpu', # strategy='ddp_find_unused_parameters_true',
139 | logger=logger, max_epochs=100, callbacks=[model_checkpoint], accumulate_grad_batches=4, precision='16-mixed') # trainer settings
140 | trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
141 |
--------------------------------------------------------------------------------
/wenetphrase_hardneg.json.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aizhiqi-work/MM-KWS/a2bbb15ea4dff82851a69a67e82283826c74dd35/wenetphrase_hardneg.json.zip
--------------------------------------------------------------------------------