├── .idea
│   ├── .gitignore
│   ├── fspen_m.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── config.json
├── configs
│   └── train_configs.py
├── models
│   └── fspen.py
├── modules
│   ├── en_decoder.py
│   └── sequence_modules.py
└── run_train.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Description
An unofficial implementation of [FSPEN: AN ULTRA-LIGHTWEIGHT NETWORK FOR REAL-TIME SPEECH
ENHANCEMENT](https://ieeexplore.ieee.org/document/10446016).

The model also supports streaming inference.
--------------------------------------------------------------------------------
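
For quick experimentation, here is a minimal streaming sketch. It is hypothetical usage, assuming the default TrainConfig and the FullSubPathExtension forward signature shown in models/fspen.py below: feed one STFT frame per step and carry the returned hidden state.

# Hypothetical streaming loop; shapes follow run_train.py below.
import torch
from configs.train_configs import TrainConfig
from models.fspen import FullSubPathExtension

configs = TrainConfig()
model = FullSubPathExtension(configs).eval()

groups = configs.dual_path_extension["parameters"]["groups"]
inter_hidden = configs.dual_path_extension["parameters"]["inter_hidden_size"]
num_bands = sum(configs.bands_num_in_groups)  # 32 bands after merging

# zero initial state: one tensor per group, per dual-path module
state = [[torch.zeros(1, num_bands, inter_hidden // groups) for _ in range(groups)]
         for _ in range(configs.dual_path_extension["num_modules"])]

with torch.no_grad():
    for _ in range(10):  # one STFT frame per step
        complex_frame = torch.randn(1, 1, 2, configs.n_fft // 2 + 1)  # (batch, frames=1, 2, frequency)
        amplitude_frame = complex_frame.pow(2).sum(dim=2, keepdim=True).sqrt()
        enhanced_frame, state = model(complex_frame, amplitude_frame, state)
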
/config.json:
--------------------------------------------------------------------------------
{
    "sample_rate": 16000,
    "n_fft": 512,
    "hop_length": 256,
    "train_frames": 62,
    "train_points": 15616,
    "full_band_encoder": {
        "encoder1": {
            "in_channels": 2,
            "out_channels": 4,
            "kernel_size": 6,
            "stride": 2,
            "padding": 2
        },
        "encoder2": {
            "in_channels": 4,
            "out_channels": 16,
            "kernel_size": 8,
            "stride": 2,
            "padding": 3
        },
        "encoder3": {
            "in_channels": 16,
            "out_channels": 32,
            "kernel_size": 6,
            "stride": 2,
            "padding": 2
        }
    },
    "full_band_decoder": {
        "decoder1": {
            "in_channels": 64,
            "out_channels": 16,
            "kernel_size": 6,
            "stride": 2,
            "padding": 2
        },
        "decoder2": {
            "in_channels": 32,
            "out_channels": 4,
            "kernel_size": 8,
            "stride": 2,
            "padding": 3
        },
        "decoder3": {
            "in_channels": 8,
            "out_channels": 2,
            "kernel_size": 6,
            "stride": 2,
            "padding": 2
        }
    },
    "sub_band_encoder": {
        "encoder1": {
            "group_width": 16,
            "conv": {
                "start_frequency": 0,
                "end_frequency": 16,
                "in_channels": 1,
                "out_channels": 32,
                "kernel_size": 4,
                "stride": 2,
                "padding": 1
            }
        },
        "encoder2": {
            "group_width": 18,
            "conv": {
                "start_frequency": 16,
                "end_frequency": 34,
                "in_channels": 1,
                "out_channels": 32,
                "kernel_size": 7,
                "stride": 3,
                "padding": 2
            }
        },
        "encoder3": {
            "group_width": 36,
            "conv": {
                "start_frequency": 34,
                "end_frequency": 70,
                "in_channels": 1,
                "out_channels": 32,
                "kernel_size": 11,
                "stride": 5,
                "padding": 2
            }
        },
        "encoder4": {
            "group_width": 66,
            "conv": {
                "start_frequency": 70,
                "end_frequency": 136,
                "in_channels": 1,
                "out_channels": 32,
                "kernel_size": 20,
                "stride": 10,
                "padding": 4
            }
        },
        "encoder5": {
            "group_width": 121,
            "conv": {
                "start_frequency": 136,
                "end_frequency": 257,
                "in_channels": 1,
                "out_channels": 32,
                "kernel_size": 30,
                "stride": 20,
                "padding": 5
            }
        }
    },
    "merge_split": {
        "channels": 64,
        "bands": 32,
        "compress_rate": 2
    },
    "bands_num_in_groups": [8, 6, 6, 6, 6],
    "band_width_in_groups": [2, 3, 6, 11, 20],
    "sub_band_decoder": {
        "decoder0": {"in_features": 64, "out_features": 2},
        "decoder1": {"in_features": 64, "out_features": 3},
        "decoder2": {"in_features": 64, "out_features": 6},
        "decoder3": {"in_features": 64, "out_features": 11},
        "decoder4": {"in_features": 64, "out_features": 20}
    },
    "dual_path_extension": {
        "num_modules": 3,
        "parameters": {
            "input_size": 16,
            "intra_hidden_size": 16,
            "inter_hidden_size": 16,
            "groups": 8,
            "rnn_type": "GRU"
        }
    }
}
--------------------------------------------------------------------------------
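
The derived numbers in this file can be checked by hand: train_points = (train_frames - 1) * hop_length = 61 * 256 = 15616, and each sub-band group yields (group_width - kernel_size + 2*padding) // stride + 1 bands, e.g. (16 - 4 + 2) // 2 + 1 = 8 for encoder1, which gives bands_num_in_groups = [8, 6, 6, 6, 6] and band_width_in_groups = [2, 3, 6, 11, 20]. A quick check, mirroring get_sub_bands in configs/train_configs.py:

# Verify the band geometry implied by the sub_band_encoder convolutions.
convs = [  # (group_width, kernel_size, stride, padding)
    (16, 4, 2, 1), (18, 7, 3, 2), (36, 11, 5, 2), (66, 20, 10, 4), (121, 30, 20, 5),
]
bands = [(w - k + 2 * p) // s + 1 for w, k, s, p in convs]
widths = [w // b for (w, _, _, _), b in zip(convs, bands)]
assert bands == [8, 6, 6, 6, 6] and widths == [2, 3, 6, 11, 20]
assert (62 - 1) * 256 == 15616  # train_points from train_frames and hop_length
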
/configs/train_configs.py:
--------------------------------------------------------------------------------
from typing import Dict, Tuple

from pydantic import BaseModel, field_validator


def get_sub_bands(band_parameters: dict):
    group_bands = list()
    group_band_width = list()
    for key, value in band_parameters.items():
        num_band = (value["group_width"] - value["conv"]["kernel_size"] +
                    2 * value["conv"]["padding"]) // value["conv"]["stride"] + 1
        sub_band_width = value["group_width"] // num_band
        group_bands.append(num_band)
        group_band_width.append(sub_band_width)

    return tuple(group_bands), tuple(group_band_width)


class TrainConfig(BaseModel):
    sample_rate: int = 16000
    n_fft: int = 512
    hop_length: int = 256
    train_frames: int = 62
    train_points: int = (train_frames - 1) * hop_length

    full_band_encoder: Dict[str, dict] = {
        "encoder1": {"in_channels": 2, "out_channels": 4, "kernel_size": 6, "stride": 2, "padding": 2},
        "encoder2": {"in_channels": 4, "out_channels": 16, "kernel_size": 8, "stride": 2, "padding": 3},
        "encoder3": {"in_channels": 16, "out_channels": 32, "kernel_size": 6, "stride": 2, "padding": 2}
    }
    full_band_decoder: Dict[str, dict] = {
        "decoder1": {"in_channels": 64, "out_channels": 16, "kernel_size": 6, "stride": 2, "padding": 2},
        "decoder2": {"in_channels": 32, "out_channels": 4, "kernel_size": 8, "stride": 2, "padding": 3},
        "decoder3": {"in_channels": 8, "out_channels": 2, "kernel_size": 6, "stride": 2, "padding": 2}
    }

    sub_band_encoder: Dict[str, dict] = {
        "encoder1": {"group_width": 16, "conv": {"start_frequency": 0, "end_frequency": 16, "in_channels": 1,
                                                 "out_channels": 32, "kernel_size": 4, "stride": 2, "padding": 1}},
        "encoder2": {"group_width": 18, "conv": {"start_frequency": 16, "end_frequency": 34, "in_channels": 1,
                                                 "out_channels": 32, "kernel_size": 7, "stride": 3, "padding": 2}},
        "encoder3": {"group_width": 36, "conv": {"start_frequency": 34, "end_frequency": 70, "in_channels": 1,
                                                 "out_channels": 32, "kernel_size": 11, "stride": 5, "padding": 2}},
        "encoder4": {"group_width": 66, "conv": {"start_frequency": 70, "end_frequency": 136, "in_channels": 1,
                                                 "out_channels": 32, "kernel_size": 20, "stride": 10, "padding": 4}},
        "encoder5": {"group_width": 121, "conv": {"start_frequency": 136, "end_frequency": 257, "in_channels": 1,
                                                  "out_channels": 32, "kernel_size": 30, "stride": 20, "padding": 5}}
    }
    merge_split: dict = {"channels": 64, "bands": 32, "compress_rate": 2}
    bands_num_in_groups: Tuple[int, ...] = get_sub_bands(sub_band_encoder)[0]
    band_width_in_groups: Tuple[int, ...] = get_sub_bands(sub_band_encoder)[1]

    sub_band_decoder: Dict[str, dict] = {f"decoder{idx}": {"in_features": 64, "out_features": width}
                                         for idx, width in enumerate(band_width_in_groups)}

    dual_path_extension: dict = {
        "num_modules": 3,
        "parameters": {"input_size": 16, "intra_hidden_size": 16, "inter_hidden_size": 16,
                       "groups": 8, "rnn_type": "GRU"}
    }
    @field_validator("sub_band_decoder")
    def sub_band_decoder_validate(cls, decoders):
        for decoder in decoders.values():
            if decoder["out_features"] < 2:
                raise ValueError(f"out_features should be >= 2, but got {decoder['out_features']}")
        return decoders


if __name__ == "__main__":
    test_configs = TrainConfig()

    for (encoder_name, parameters), num_bands in zip(test_configs.sub_band_encoder.items(),
                                                     test_configs.bands_num_in_groups):
        print(encoder_name, num_bands, parameters)
--------------------------------------------------------------------------------
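
As a sanity check, the field_validator above rejects any decoder whose out_features would collapse a band below two bins. A minimal sketch of the failure path (hypothetical usage):

# Hypothetical check: an out_features of 1 should fail validation.
from pydantic import ValidationError
from configs.train_configs import TrainConfig

try:
    TrainConfig(sub_band_decoder={"decoder0": {"in_features": 64, "out_features": 1}})
except ValidationError as error:
    print(error)  # out_features should be >= 2, but got 1
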
/models/fspen.py:
--------------------------------------------------------------------------------
import torch

from torch import nn, Tensor
from modules.en_decoder import FullBandEncoderBlock, FullBandDecoderBlock
from modules.en_decoder import SubBandEncoderBlock, SubBandDecoderBlock
from modules.sequence_modules import DualPathExtensionRNN
from configs.train_configs import TrainConfig


class FullBandEncoder(nn.Module):
    def __init__(self, configs: TrainConfig):
        super().__init__()

        last_channels = 0
        self.full_band_encoder = nn.ModuleList()
        for encoder_name, conv_parameter in configs.full_band_encoder.items():
            self.full_band_encoder.append(FullBandEncoderBlock(**conv_parameter))
            last_channels = conv_parameter["out_channels"]

        self.global_features = nn.Conv1d(in_channels=last_channels, out_channels=last_channels,
                                         kernel_size=1, stride=1)

    def forward(self, complex_spectrum: Tensor):
        """
        :param complex_spectrum: (batch*frames, channels, frequency)
        :return:
        """
        full_band_encodes = []
        for encoder in self.full_band_encoder:
            complex_spectrum = encoder(complex_spectrum)
            full_band_encodes.append(complex_spectrum)

        global_feature = self.global_features(complex_spectrum)

        return full_band_encodes[::-1], global_feature


class SubBandEncoder(nn.Module):
    def __init__(self, configs: TrainConfig):
        super().__init__()

        self.sub_band_encoders = nn.ModuleList()
        for encoder_name, conv_parameters in configs.sub_band_encoder.items():
            self.sub_band_encoders.append(SubBandEncoderBlock(**conv_parameters["conv"]))

    def forward(self, amplitude_spectrum: Tensor):
        """
        :param amplitude_spectrum: (batch*frames, channels, frequency)
        :return:
        """
        sub_band_encodes = list()
        for encoder in self.sub_band_encoders:
            encode_out = encoder(amplitude_spectrum)
            sub_band_encodes.append(encode_out)

        local_feature = torch.cat(sub_band_encodes, dim=2)  # concat along the band axis

        return sub_band_encodes, local_feature


class FullBandDecoder(nn.Module):
    def __init__(self, configs: TrainConfig):
        super().__init__()
        self.full_band_decoders = nn.ModuleList()
        for decoder_name, parameters in configs.full_band_decoder.items():
            self.full_band_decoders.append(FullBandDecoderBlock(**parameters))

    def forward(self, feature: Tensor, encode_outs: list):
        for decoder, encode_out in zip(self.full_band_decoders, encode_outs):
            feature = decoder(feature, encode_out)

        return feature


class SubBandDecoder(nn.Module):
    def __init__(self, configs: TrainConfig):
        super().__init__()
        start_idx = 0
        self.sub_band_decoders = nn.ModuleList()
        for (decoder_name, parameters), bands in zip(configs.sub_band_decoder.items(), configs.bands_num_in_groups):
            end_idx = start_idx + bands
            self.sub_band_decoders.append(SubBandDecoderBlock(start_idx=start_idx, end_idx=end_idx, **parameters))
            start_idx = end_idx  # advance so each decoder reads its own group of bands

    def forward(self, feature: Tensor, sub_encodes: list):
        """
        :param feature: (batch*frames, channels, bands)
        :param sub_encodes: [sub_encode_0, sub_encode_1, ...], each element is (batch*frames, channels, sub_bands)
        :return: (batch*frames, full-frequency)
        """
        sub_decoder_outs = []
        for decoder, sub_encode in zip(self.sub_band_decoders, sub_encodes):
            sub_decoder_out = decoder(feature, sub_encode)
            sub_decoder_outs.append(sub_decoder_out)

        sub_decoder_outs = torch.cat(tensors=sub_decoder_outs, dim=1)  # concat along the frequency axis

        return sub_decoder_outs


class FullSubPathExtension(nn.Module):
    def __init__(self, configs: TrainConfig):
        super().__init__()
        self.full_band_encoder = FullBandEncoder(configs)
        self.sub_band_encoder = SubBandEncoder(configs)

        merge_split = configs.merge_split
        merge_channels = merge_split["channels"]
        merge_bands = merge_split["bands"]
        compress_rate = merge_split["compress_rate"]

        self.feature_merge_layer = nn.Sequential(
            nn.Linear(in_features=merge_channels, out_features=merge_channels // compress_rate),
            nn.ELU(),
            nn.Conv1d(in_channels=merge_bands, out_channels=merge_bands // compress_rate, kernel_size=1, stride=1)
        )

        self.dual_path_extension_rnn_list = nn.ModuleList()
        for _ in range(configs.dual_path_extension["num_modules"]):
            self.dual_path_extension_rnn_list.append(DualPathExtensionRNN(**configs.dual_path_extension["parameters"]))

        self.feature_split_layer = nn.Sequential(
            nn.Conv1d(in_channels=merge_bands // compress_rate, out_channels=merge_bands, kernel_size=1, stride=1),
            nn.Linear(in_features=merge_channels // compress_rate, out_features=merge_channels),
            nn.ELU()
        )

        self.full_band_decoder = FullBandDecoder(configs)
        self.sub_band_decoder = SubBandDecoder(configs)

        self.mask_padding = nn.ConstantPad2d(padding=(1, 0, 0, 0), value=0.0)

    def forward(self, in_complex_spectrum: Tensor, in_amplitude_spectrum: Tensor, hidden_state: list):
        """
        :param in_complex_spectrum: (batch, frames, channels, frequency)
        :param in_amplitude_spectrum: (batch, frames, 1, frequency)
        :param hidden_state: per-module hidden states for the dual-path RNNs
        :return:
        """
        batch, frames, channels, frequency = in_complex_spectrum.shape
        complex_spectrum = torch.reshape(in_complex_spectrum, shape=(batch * frames, channels, frequency))
        amplitude_spectrum = torch.reshape(in_amplitude_spectrum, shape=(batch * frames, 1, frequency))

        full_band_encode_outs, global_feature = self.full_band_encoder(complex_spectrum)
        sub_band_encode_outs, local_feature = self.sub_band_encoder(amplitude_spectrum)

        merge_feature = torch.cat(tensors=[global_feature, local_feature], dim=2)  # feature concat
        merge_feature = self.feature_merge_layer(merge_feature)
        # (batch*frames, channels, frequency) -> (batch*frames, channels//2, frequency//2)
        _, channels, frequency = merge_feature.shape
        merge_feature = torch.reshape(merge_feature, shape=(batch, frames, channels, frequency))
        merge_feature = torch.permute(merge_feature, dims=(0, 3, 1, 2)).contiguous()
        # (batch, frequency, frames, channels)
        out_hidden_state = list()
        for idx, rnn_layer in enumerate(self.dual_path_extension_rnn_list):
            merge_feature, state = rnn_layer(merge_feature, hidden_state[idx])
            out_hidden_state.append(state)

        merge_feature = torch.permute(merge_feature, dims=(0, 2, 3, 1)).contiguous()
        merge_feature = torch.reshape(merge_feature, shape=(batch * frames, channels, frequency))

        split_feature = self.feature_split_layer(merge_feature)
        first_dim, channels, frequency = split_feature.shape
        split_feature = torch.reshape(split_feature, shape=(first_dim, channels, -1, 2))

        full_band_mask = self.full_band_decoder(split_feature[..., 0], full_band_encode_outs)
        sub_band_mask = self.sub_band_decoder(split_feature[..., 1], sub_band_encode_outs)

        full_band_mask = torch.reshape(full_band_mask, shape=(batch, frames, 2, -1))
        sub_band_mask = torch.reshape(sub_band_mask, shape=(batch, frames, 1, -1))

        # zero-pad the masks at the DC bin, so masking removes the DC component
        full_band_mask = self.mask_padding(full_band_mask)
        sub_band_mask = self.mask_padding(sub_band_mask)

        full_band_out = in_complex_spectrum * full_band_mask
        sub_band_out = in_amplitude_spectrum * sub_band_mask
        # output is (batch, frames, 2, frequency), complex style

        return full_band_out + sub_band_out, out_hidden_state
--------------------------------------------------------------------------------
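
With the default TrainConfig the merge/split path works out as follows; a shape trace (assuming batch B and frames T, as exercised in run_train.py below):

# Shape trace for the default config (assumed B=1, T=62 as in run_train.py):
# full-band encoder:  (B*T, 2, 257) -> (B*T, 4, 128) -> (B*T, 16, 64) -> (B*T, 32, 32)
# sub-band encoders:  five groups of (B*T, 32, bands), bands = [8, 6, 6, 6, 6] -> cat -> (B*T, 32, 32)
# merge:              cat dim=2 -> (B*T, 32, 64) -> Linear + Conv1d -> (B*T, 16, 32)
# dual-path input:    reshape/permute -> (B, frequency=32, frames=T, channels=16)
# split:              (B*T, 16, 32) -> (B*T, 32, 64) -> reshape -> (B*T, 32, 32, 2), one plane per decoder
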
/modules/en_decoder.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, Tensor


class FullBandEncoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                              kernel_size=kernel_size, stride=stride, padding=padding)

        self.norm = nn.BatchNorm1d(num_features=out_channels)

        self.activate = nn.ELU()

    def forward(self, complex_spectrum: Tensor):
        """
        :param complex_spectrum: (batch*frames, channels, frequency)
        :return:
        """
        complex_spectrum = self.conv(complex_spectrum)
        complex_spectrum = self.norm(complex_spectrum)
        complex_spectrum = self.activate(complex_spectrum)

        return complex_spectrum


class FullBandDecoderBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // 2,
                              kernel_size=1, stride=1, padding=0)
        self.convT = nn.ConvTranspose1d(in_channels // 2, out_channels, kernel_size=kernel_size, stride=stride,
                                        padding=padding)

        self.norm = nn.BatchNorm1d(num_features=out_channels)
        self.activate = nn.ELU()

    def forward(self, encode_complex_spectrum: Tensor, decode_complex_spectrum: Tensor):
        """
        :param encode_complex_spectrum: (batch*frames, channels1, frequency)
        :param decode_complex_spectrum: (batch*frames, channels2, frequency)
        :return:
        """
        complex_spectrum = torch.cat([encode_complex_spectrum, decode_complex_spectrum], dim=1)
        complex_spectrum = self.conv(complex_spectrum)
        complex_spectrum = self.convT(complex_spectrum)
        complex_spectrum = self.norm(complex_spectrum)
        complex_spectrum = self.activate(complex_spectrum)

        return complex_spectrum


class SubBandEncoderBlock(nn.Module):
    def __init__(self, start_frequency: int,
                 end_frequency: int,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int,
                 stride: int,
                 padding: int):
        super().__init__()
        self.start_frequency = start_frequency
        self.end_frequency = end_frequency

        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.activate = nn.ReLU()

    def forward(self, amplitude_spectrum: Tensor):
        """
        :param amplitude_spectrum: (batch*frames, channels, frequency)
        :return:
        """
        sub_spectrum = amplitude_spectrum[:, :, self.start_frequency:self.end_frequency]

        sub_spectrum = self.conv(sub_spectrum)  # (batch*frames, out_channels, sub_bands)
        sub_spectrum = self.activate(sub_spectrum)

        return sub_spectrum


class SubBandDecoderBlock(nn.Module):
    def __init__(self, in_features: int, out_features: int, start_idx: int, end_idx: int):
        super().__init__()
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.fc = nn.Linear(in_features=in_features, out_features=out_features)
        self.activate = nn.ReLU()

    def forward(self, encode_amplitude_spectrum: Tensor, decode_amplitude_spectrum: Tensor):
        """
        :param encode_amplitude_spectrum: (batch*frames, channels, sub_bands)
        :param decode_amplitude_spectrum: (batch*frames, channels, sub_bands)
        :return:
        """
        encode_amplitude_spectrum = encode_amplitude_spectrum[:, :, self.start_idx: self.end_idx]
        spectrum = torch.cat([encode_amplitude_spectrum, decode_amplitude_spectrum], dim=1)  # channels concat
        spectrum = torch.transpose(spectrum, dim0=1, dim1=2).contiguous()  # (*, bands, channels)

        spectrum = self.fc(spectrum)  # (*, bands, band_width)
        spectrum = self.activate(spectrum)
        first_dim, bands, band_width = spectrum.shape
        spectrum = torch.reshape(spectrum, shape=(first_dim, bands * band_width))

        return spectrum
--------------------------------------------------------------------------------
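
The full-band blocks are sized so the decoder exactly mirrors the encoder: Conv1d gives L_out = (L_in - kernel_size + 2*padding) // stride + 1 and ConvTranspose1d gives L_out = (L_in - 1) * stride - 2*padding + kernel_size, so the frequency axis goes 257 -> 128 -> 64 -> 32 on the way down and 32 -> 64 -> 128 -> 256 on the way up; the missing DC bin is restored by the zero padding in FullSubPathExtension. A quick check of that arithmetic:

# Verify encoder/decoder frequency sizes for the default config.
def conv_out(length, kernel, stride, pad):
    return (length - kernel + 2 * pad) // stride + 1

def conv_transpose_out(length, kernel, stride, pad):
    return (length - 1) * stride - 2 * pad + kernel

freq = 257  # n_fft // 2 + 1
for kernel, stride, pad in [(6, 2, 2), (8, 2, 3), (6, 2, 2)]:
    freq = conv_out(freq, kernel, stride, pad)
assert freq == 32

for kernel, stride, pad in [(6, 2, 2), (8, 2, 3), (6, 2, 2)]:
    freq = conv_transpose_out(freq, kernel, stride, pad)
assert freq == 256  # zero-padding the DC bin brings it back to 257
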
/modules/sequence_modules.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
# File    : sequence_modules.py
# Time    : 2024/4/10 09:35
# Author  : wukeyi
# Version : python3.9
"""
from typing import List

import torch
from torch import nn, Tensor


class GroupRNN(nn.Module):
    def __init__(self, input_size: int,
                 hidden_size: int,
                 groups: int,
                 rnn_type: str,
                 num_layers: int = 1,
                 bidirectional: bool = False,
                 batch_first: bool = True):
        super().__init__()
        assert input_size % groups == 0, \
            f"input_size % groups must be equal to 0, but got {input_size} % {groups} = {input_size % groups}"

        self.groups = groups
        self.rnn_list = nn.ModuleList()
        for _ in range(groups):
            self.rnn_list.append(
                getattr(nn, rnn_type)(input_size=input_size // groups, hidden_size=hidden_size // groups,
                                      num_layers=num_layers,
                                      bidirectional=bidirectional, batch_first=batch_first)
            )

    def forward(self, inputs: Tensor, hidden_state: List[Tensor]):
        """
        :param inputs: (batch, steps, input_size)
        :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups.
            Each state has shape (num_layers*num_directions, batch, hidden_size//groups) if rnn_type is GRU or RNN;
            otherwise state = (h0, c0), each of shape (num_layers*num_directions, batch, hidden_size//groups).
        :return:
        """
        outputs = []
        out_states = []
        batch, steps, _ = inputs.shape

        inputs = torch.reshape(inputs, shape=(batch, steps, self.groups, -1))  # (batch, steps, groups, width)
        for idx, rnn in enumerate(self.rnn_list):
            out, state = rnn(inputs[:, :, idx, :], hidden_state[idx])
            outputs.append(out)  # (batch, steps, hidden_size//groups)
            out_states.append(state)

        outputs = torch.cat(outputs, dim=2)  # (batch, steps, hidden_size)

        return outputs, out_states


class DualPathExtensionRNN(nn.Module):
    def __init__(self, input_size: int,
                 intra_hidden_size: int,
                 inter_hidden_size: int,
                 groups: int,
                 rnn_type: str):
        super().__init__()
        assert rnn_type in ["RNN", "GRU", "LSTM"], f"rnn_type should be RNN/GRU/LSTM, but got {rnn_type}!"

        self.intra_chunk_rnn = getattr(nn, rnn_type)(input_size=input_size, hidden_size=intra_hidden_size,
                                                     num_layers=1, bidirectional=True, batch_first=True)
        self.intra_chunk_fc = nn.Linear(in_features=intra_hidden_size * 2, out_features=input_size)
        self.intra_chunk_norm = nn.LayerNorm(normalized_shape=input_size, elementwise_affine=True)

        self.inter_chunk_rnn = GroupRNN(input_size=input_size, hidden_size=inter_hidden_size, groups=groups,
                                        rnn_type=rnn_type)
        self.inter_chunk_fc = nn.Linear(in_features=inter_hidden_size, out_features=input_size)

    def forward(self, inputs: Tensor, hidden_state: List[Tensor]):
        """
        :param inputs: (B, F, T, N)
        :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups.
            Each state has shape (num_layers*num_directions, B*F, inter_hidden_size//groups) for GRU/RNN;
            for LSTM each state is a tuple (h0, c0) of that shape.
        :return:
        """
        B, F, T, N = inputs.shape
        intra_out = torch.transpose(inputs, dim0=1, dim1=2).contiguous()  # (B, T, F, N)
        intra_out = torch.reshape(intra_out, shape=(B * T, F, N))
        intra_out, _ = self.intra_chunk_rnn(intra_out)
        intra_out = self.intra_chunk_fc(intra_out)  # (B*T, F, N)
        intra_out = torch.reshape(intra_out, shape=(B, T, F, N))
        intra_out = torch.transpose(intra_out, dim0=1, dim1=2).contiguous()  # (B, F, T, N)
        intra_out = self.intra_chunk_norm(intra_out)  # (B, F, T, N)

        intra_out = inputs + intra_out  # residual add

        inter_out = torch.reshape(intra_out, shape=(B * F, T, N))  # (B*F, T, N)
        inter_out, hidden_state = self.inter_chunk_rnn(inter_out, hidden_state)
        inter_out = torch.reshape(inter_out, shape=(B, F, T, -1))  # (B, F, T, inter_hidden_size)
        inter_out = self.inter_chunk_fc(inter_out)  # (B, F, T, N)

        inter_out = inter_out + intra_out  # residual add

        return inter_out, hidden_state


if __name__ == "__main__":
    test_model = DualPathExtensionRNN(input_size=32, intra_hidden_size=16, inter_hidden_size=16,
                                      groups=8, rnn_type="LSTM")
    test_data = torch.randn(5, 32, 10, 32)
    # each group LSTM has hidden_size = inter_hidden_size // groups = 2, with batch dim B*F = 5*32
    test_state = [(torch.randn(1, 5 * 32, 2), torch.randn(1, 5 * 32, 2)) for _ in range(8)]
    test_out = test_model(test_data, test_state)
--------------------------------------------------------------------------------
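
For the GRU configuration used in config.json (input_size=16, inter_hidden_size=16, groups=8), each GroupRNN state is a single tensor of shape (1, B*F, 2) rather than an (h0, c0) tuple. A minimal sketch (hypothetical dimensions):

# Hypothetical GRU state-layout check for DualPathExtensionRNN.
import torch
from modules.sequence_modules import DualPathExtensionRNN

B, F, T, N = 1, 32, 4, 16
model = DualPathExtensionRNN(input_size=N, intra_hidden_size=16, inter_hidden_size=16,
                             groups=8, rnn_type="GRU")
state = [torch.zeros(1, B * F, 16 // 8) for _ in range(8)]  # one (1, B*F, 2) tensor per group
out, state = model(torch.randn(B, F, T, N), state)
assert out.shape == (B, F, T, N)
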
/run_train.py:
--------------------------------------------------------------------------------
import json

import torch
from thop import profile, clever_format

from configs.train_configs import TrainConfig
from models.fspen import FullSubPathExtension


if __name__ == "__main__":
    configs = TrainConfig()
    with open("config.json", mode="w", encoding="utf-8") as file:
        json.dump(configs.__dict__, file, indent=4)

    model = FullSubPathExtension(configs)

    batch = 1
    groups = configs.dual_path_extension["parameters"]["groups"]
    inter_hidden_size = configs.dual_path_extension["parameters"]["inter_hidden_size"]
    num_modules = configs.dual_path_extension["num_modules"]
    num_bands = sum(configs.bands_num_in_groups)

    in_wav = torch.randn(1, configs.train_points)
    complex_spectrum = torch.stft(in_wav, n_fft=configs.n_fft, hop_length=configs.hop_length,
                                  window=torch.hamming_window(configs.n_fft), return_complex=True)  # (B, F, T)
    amplitude_spectrum = torch.abs(complex_spectrum)

    complex_spectrum = torch.view_as_real(complex_spectrum)  # (B, F, T, 2)
    complex_spectrum = torch.permute(complex_spectrum, dims=(0, 2, 3, 1))  # (B, T, 2, F)
    _, frames, channels, frequency = complex_spectrum.shape
    complex_spectrum = torch.reshape(complex_spectrum, shape=(batch, frames, channels, frequency))
    amplitude_spectrum = torch.permute(amplitude_spectrum, dims=(0, 2, 1))
    amplitude_spectrum = torch.reshape(amplitude_spectrum, shape=(batch, frames, 1, frequency))

    # one zero GRU state per group, per dual-path module: (1, batch*num_bands, inter_hidden_size//groups)
    in_hidden_state = [[torch.zeros(1, batch * num_bands, inter_hidden_size // groups) for _ in range(groups)]
                       for _ in range(num_modules)]

    flops, params = profile(model, inputs=(complex_spectrum, amplitude_spectrum, in_hidden_state))
    flops, params = clever_format(nums=[flops, params], format="%0.4f")
    print(f"flops: {flops} \nparams: {params}")
--------------------------------------------------------------------------------
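
To get audio back from the model output (batch, frames, 2, frequency), the complex spectrum can be rebuilt and inverted with torch.istft. A minimal sketch under the same STFT settings, reusing the tensors from the script above:

# Hypothetical reconstruction: (B, T, 2, F) model output -> waveform.
enhanced, _ = model(complex_spectrum, amplitude_spectrum, in_hidden_state)
enhanced = torch.permute(enhanced, dims=(0, 3, 1, 2)).contiguous()  # (B, F, T, 2)
enhanced = torch.view_as_complex(enhanced)                          # (B, F, T)
out_wav = torch.istft(enhanced, n_fft=configs.n_fft, hop_length=configs.hop_length,
                      window=torch.hamming_window(configs.n_fft), length=configs.train_points)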