├── .idea ├── .gitignore ├── fspen_m.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── config.json ├── configs └── train_configs.py ├── models └── fspen.py ├── modules ├── en_decoder.py └── sequence_modules.py └── run_train.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/fspen_m.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 66 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # description 2 | un official implement of [FSPEN: AN ULTRA-LIGHTWEIGHT NETWORK FOR REAL TIME SPEECH 3 | ENAHNCMENT](https://ieeexplore.ieee.org/document/10446016)
4 | 5 | the model is also can use for stream inference. 6 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "sample_rate": 16000, 3 | "n_fft": 512, 4 | "hop_length": 256, 5 | "train_frames": 62, 6 | "train_points": 15616, 7 | "full_band_encoder": { 8 | "encoder1": { 9 | "in_channels": 2, 10 | "out_channels": 4, 11 | "kernel_size": 6, 12 | "stride": 2, 13 | "padding": 2 14 | }, 15 | "encoder2": { 16 | "in_channels": 4, 17 | "out_channels": 16, 18 | "kernel_size": 8, 19 | "stride": 2, 20 | "padding": 3 21 | }, 22 | "encoder3": { 23 | "in_channels": 16, 24 | "out_channels": 32, 25 | "kernel_size": 6, 26 | "stride": 2, 27 | "padding": 2 28 | } 29 | }, 30 | "full_band_decoder": { 31 | "decoder1": { 32 | "in_channels": 64, 33 | "out_channels": 16, 34 | "kernel_size": 6, 35 | "stride": 2, 36 | "padding": 2 37 | }, 38 | "decoder2": { 39 | "in_channels": 32, 40 | "out_channels": 4, 41 | "kernel_size": 8, 42 | "stride": 2, 43 | "padding": 3 44 | }, 45 | "decoder3": { 46 | "in_channels": 8, 47 | "out_channels": 2, 48 | "kernel_size": 6, 49 | "stride": 2, 50 | "padding": 2 51 | } 52 | }, 53 | "sub_band_encoder": { 54 | "encoder1": { 55 | "group_width": 16, 56 | "conv": { 57 | "start_frequency": 0, 58 | "end_frequency": 16, 59 | "in_channels": 1, 60 | "out_channels": 32, 61 | "kernel_size": 4, 62 | "stride": 2, 63 | "padding": 1 64 | } 65 | }, 66 | "encoder2": { 67 | "group_width": 18, 68 | "conv": { 69 | "start_frequency": 16, 70 | "end_frequency": 34, 71 | "in_channels": 1, 72 | "out_channels": 32, 73 | "kernel_size": 7, 74 | "stride": 3, 75 | "padding": 2 76 | } 77 | }, 78 | "encoder3": { 79 | "group_width": 36, 80 | "conv": { 81 | "start_frequency": 34, 82 | "end_frequency": 70, 83 | "in_channels": 1, 84 | "out_channels": 32, 85 | "kernel_size": 11, 86 | "stride": 5, 87 | "padding": 2 88 | } 89 | }, 90 | "encoder4": { 91 | "group_width": 66, 92 | "conv": { 93 | "start_frequency": 70, 94 | "end_frequency": 136, 95 | "in_channels": 1, 96 | "out_channels": 32, 97 | "kernel_size": 20, 98 | "stride": 10, 99 | "padding": 4 100 | } 101 | }, 102 | "encoder5": { 103 | "group_width": 121, 104 | "conv": { 105 | "start_frequency": 136, 106 | "end_frequency": 257, 107 | "in_channels": 1, 108 | "out_channels": 32, 109 | "kernel_size": 30, 110 | "stride": 20, 111 | "padding": 5 112 | } 113 | } 114 | }, 115 | "merge_split": { 116 | "channels": 64, 117 | "bands": 32, 118 | "compress_rate": 2 119 | }, 120 | "bands_num_in_groups": [ 121 | 8, 122 | 6, 123 | 6, 124 | 6, 125 | 6 126 | ], 127 | "band_width_in_groups": [ 128 | 2, 129 | 3, 130 | 6, 131 | 11, 132 | 20 133 | ], 134 | "sub_band_decoder": { 135 | "decoder0": { 136 | "in_features": 64, 137 | "out_features": 2 138 | }, 139 | "decoder1": { 140 | "in_features": 64, 141 | "out_features": 3 142 | }, 143 | "decoder2": { 144 | "in_features": 64, 145 | "out_features": 6 146 | }, 147 | "decoder3": { 148 | "in_features": 64, 149 | "out_features": 11 150 | }, 151 | "decoder4": { 152 | "in_features": 64, 153 | "out_features": 20 154 | } 155 | }, 156 | "dual_path_extension": { 157 | "num_modules": 3, 158 | "parameters": { 159 | "input_size": 16, 160 | "intra_hidden_size": 16, 161 | "inter_hidden_size": 16, 162 | "groups": 8, 163 | "rnn_type": "GRU" 164 | } 165 | } 166 | } -------------------------------------------------------------------------------- /configs/train_configs.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | 3 | from pydantic import BaseModel, field_validator 4 | 5 | 6 | def get_sub_bands(band_parameters: dict): 7 | group_bands = list() 8 | group_band_width = list() 9 | for key, value in band_parameters.items(): 10 | num_band = (value["group_width"] - value["conv"]["kernel_size"] + 11 | 2 * value["conv"]["padding"]) // value["conv"]["stride"] + 1 12 | sub_band_width = value["group_width"] // num_band 13 | group_bands.append(num_band) 14 | group_band_width.append(sub_band_width) 15 | 16 | return tuple(group_bands), tuple(group_band_width) 17 | 18 | 19 | class TrainConfig(BaseModel): 20 | sample_rate: int = 16000 21 | n_fft: int = 512 22 | hop_length: int = 256 23 | train_frames: int = 62 24 | train_points: int = (train_frames - 1) * hop_length 25 | 26 | full_band_encoder: Dict[str, dict] = { 27 | "encoder1": {"in_channels": 2, "out_channels": 4, "kernel_size": 6, "stride": 2, "padding": 2}, 28 | "encoder2": {"in_channels": 4, "out_channels": 16, "kernel_size": 8, "stride": 2, "padding": 3}, 29 | "encoder3": {"in_channels": 16, "out_channels": 32, "kernel_size": 6, "stride": 2, "padding": 2} 30 | } 31 | full_band_decoder: Dict[str, dict] = { 32 | "decoder1": {"in_channels": 64, "out_channels": 16, "kernel_size": 6, "stride": 2, "padding": 2}, 33 | "decoder2": {"in_channels": 32, "out_channels": 4, "kernel_size": 8, "stride": 2, "padding": 3}, 34 | "decoder3": {"in_channels": 8, "out_channels": 2, "kernel_size": 6, "stride": 2, "padding": 2} 35 | } 36 | 37 | sub_band_encoder: Dict[str, dict] = { 38 | "encoder1": {"group_width": 16, "conv": {"start_frequency": 0, "end_frequency": 16, "in_channels": 1, 39 | "out_channels": 32, "kernel_size": 4, "stride": 2, "padding": 1}}, 40 | "encoder2": {"group_width": 18, "conv": {"start_frequency": 16, "end_frequency": 34, "in_channels": 1, 41 | "out_channels": 32, "kernel_size": 7, "stride": 3, "padding": 2}}, 42 | "encoder3": {"group_width": 36, "conv": {"start_frequency": 34, "end_frequency": 70, "in_channels": 1, 43 | "out_channels": 32, "kernel_size": 11, "stride": 5, "padding": 2}}, 44 | "encoder4": {"group_width": 66, "conv": {"start_frequency": 70, "end_frequency": 136, "in_channels": 1, 45 | "out_channels": 32, "kernel_size": 20, "stride": 10, "padding": 4}}, 46 | "encoder5": {"group_width": 121, "conv": {"start_frequency": 136, "end_frequency": 257, "in_channels": 1, 47 | "out_channels": 32, "kernel_size": 30, "stride": 20, "padding": 5}} 48 | } 49 | merge_split: dict = {"channels": 64, "bands": 32, "compress_rate": 2} 50 | bands_num_in_groups: Tuple[int] = get_sub_bands(sub_band_encoder)[0] 51 | band_width_in_groups: Tuple[int] = get_sub_bands(sub_band_encoder)[1] 52 | 53 | sub_band_decoder: Dict[str, dict] = {f"decoder{idx}": {"in_features": 64, "out_features": width} 54 | for idx, width in enumerate(band_width_in_groups)} 55 | 56 | dual_path_extension: dict = { 57 | "num_modules": 3, 58 | "parameters": {"input_size": 16, "intra_hidden_size": 16, "inter_hidden_size": 16, 59 | "groups": 8, "rnn_type": "GRU"} 60 | } 61 | 62 | @field_validator("sub_band_decoder") 63 | def sub_band_decoder_validate(cls, decoders): 64 | for decoder in decoders: 65 | if decoder["out_feature"] < 2: 66 | raise ValueError(f"values should > 2, but got {decoder['out_feature']}") 67 | 68 | 69 | if __name__ == "__main__": 70 | test_configs = TrainConfig() 71 | 72 | for (decoder_name, parameters), _, in zip(test_configs.sub_band_encoder.items(), test_configs.bands_num_in_groups): 73 | print(parameters) 74 | -------------------------------------------------------------------------------- /models/fspen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn, Tensor 4 | from modules.en_decoder import FullBandEncoderBlock, FullBandDecoderBlock 5 | from modules.en_decoder import SubBandEncoderBlock, SubBandDecoderBlock 6 | from modules.sequence_modules import DualPathExtensionRNN 7 | from configs.train_configs import TrainConfig 8 | 9 | 10 | class FullBandEncoder(nn.Module): 11 | def __init__(self, configs: TrainConfig): 12 | super().__init__() 13 | 14 | last_channels = 0 15 | self.full_band_encoder = nn.ModuleList() 16 | for encoder_name, conv_parameter in configs.full_band_encoder.items(): 17 | self.full_band_encoder.append(FullBandEncoderBlock(**conv_parameter)) 18 | last_channels = conv_parameter["out_channels"] 19 | 20 | self.global_features = nn.Conv1d(in_channels=last_channels, out_channels=last_channels, kernel_size=1, stride=1) 21 | 22 | def forward(self, complex_spectrum: Tensor): 23 | """ 24 | :param complex_spectrum: (batch*frame, channels, frequency) 25 | :return: 26 | """ 27 | full_band_encodes = [] 28 | for encoder in self.full_band_encoder: 29 | complex_spectrum = encoder(complex_spectrum) 30 | full_band_encodes.append(complex_spectrum) 31 | 32 | global_feature = self.global_features(complex_spectrum) 33 | 34 | return full_band_encodes[::-1], global_feature 35 | 36 | 37 | class SubBandEncoder(nn.Module): 38 | def __init__(self, configs: TrainConfig): 39 | super().__init__() 40 | 41 | self.sub_band_encoders = nn.ModuleList() 42 | for encoder_name, conv_parameters in configs.sub_band_encoder.items(): 43 | self.sub_band_encoders.append(SubBandEncoderBlock(**conv_parameters["conv"])) 44 | 45 | def forward(self, amplitude_spectrum: Tensor): 46 | """ 47 | :param amplitude_spectrum: (batch * frames, channels, frequency) 48 | :return: 49 | """ 50 | sub_band_encodes = list() 51 | for encoder in self.sub_band_encoders: 52 | encode_out = encoder(amplitude_spectrum) 53 | sub_band_encodes.append(encode_out) 54 | 55 | local_feature = torch.cat(sub_band_encodes, dim=2) # feature cat 56 | 57 | return sub_band_encodes, local_feature 58 | 59 | 60 | class FullBandDecoder(nn.Module): 61 | def __init__(self, configs: TrainConfig): 62 | super().__init__() 63 | self.full_band_decoders = nn.ModuleList() 64 | for decoder_name, parameters in configs.full_band_decoder.items(): 65 | self.full_band_decoders.append( 66 | FullBandDecoderBlock(**parameters)) 67 | 68 | def forward(self, feature: Tensor, encode_outs: list): 69 | for decoder, encode_out in zip(self.full_band_decoders, encode_outs): 70 | feature = decoder(feature, encode_out) 71 | 72 | return feature 73 | 74 | 75 | class SubBandDecoder(nn.Module): 76 | def __init__(self, configs: TrainConfig): 77 | super().__init__() 78 | start_idx = 0 79 | self.sub_band_decoders = nn.ModuleList() 80 | for (decoder_name, parameters), bands in zip(configs.sub_band_decoder.items(), configs.bands_num_in_groups): 81 | end_idx = start_idx + bands 82 | self.sub_band_decoders.append(SubBandDecoderBlock(start_idx=start_idx, end_idx=end_idx, **parameters)) 83 | 84 | def forward(self, feature: Tensor, sub_encodes: list): 85 | """ 86 | :param feature: (batch*frames, channels, bands) 87 | :param sub_encodes: [sub_encode_0, sub_encode_1, ...], each element is (batch*frames, channels, sub_bands) 88 | :return: (batch*frames, full-frequency) 89 | """ 90 | sub_decoder_outs = [] 91 | for decoder, sub_encode in zip(self.sub_band_decoders, sub_encodes): 92 | sub_decoder_out = decoder(feature, sub_encode) 93 | sub_decoder_outs.append(sub_decoder_out) 94 | 95 | sub_decoder_outs = torch.cat(tensors=sub_decoder_outs, dim=1) # feature cat 96 | 97 | return sub_decoder_outs 98 | 99 | 100 | class FullSubPathExtension(nn.Module): 101 | def __init__(self, configs: TrainConfig): 102 | super().__init__() 103 | self.full_band_encoder = FullBandEncoder(configs) 104 | self.sub_band_encoder = SubBandEncoder(configs) 105 | 106 | merge_split = configs.merge_split 107 | merge_channels = merge_split["channels"] 108 | merge_bands = merge_split["bands"] 109 | compress_rate = merge_split["compress_rate"] 110 | 111 | self.feature_merge_layer = nn.Sequential( 112 | nn.Linear(in_features=merge_channels, out_features=merge_channels//compress_rate), 113 | nn.ELU(), 114 | nn.Conv1d(in_channels=merge_bands, out_channels=merge_bands//compress_rate, kernel_size=1, stride=1) 115 | ) 116 | 117 | self.dual_path_extension_rnn_list = nn.ModuleList() 118 | for _ in range(configs.dual_path_extension["num_modules"]): 119 | self.dual_path_extension_rnn_list.append(DualPathExtensionRNN(**configs.dual_path_extension["parameters"])) 120 | 121 | self.feature_split_layer = nn.Sequential( 122 | nn.Conv1d(in_channels=merge_bands//compress_rate, out_channels=merge_bands, kernel_size=1, stride=1), 123 | nn.Linear(in_features=merge_channels//compress_rate, out_features=merge_channels), 124 | nn.ELU() 125 | ) 126 | 127 | self.full_band_decoder = FullBandDecoder(configs) 128 | self.sub_band_decoder = SubBandDecoder(configs) 129 | 130 | self.mask_padding = nn.ConstantPad2d(padding=(1, 0, 0, 0), value=0.0) 131 | 132 | def forward(self, in_complex_spectrum: Tensor, in_amplitude_spectrum: Tensor, hidden_state: list): 133 | """ 134 | :param in_amplitude_spectrum: (batch, frames, 1, frequency) 135 | :param hidden_state: 136 | :param in_complex_spectrum: (batch, frames, channels, frequency) 137 | :return: 138 | """ 139 | batch, frames, channels, frequency = in_complex_spectrum.shape 140 | complex_spectrum = torch.reshape(in_complex_spectrum, shape=(batch * frames, channels, frequency)) 141 | amplitude_spectrum = torch.reshape(in_amplitude_spectrum, shape=(batch*frames, 1, frequency)) 142 | 143 | full_band_encode_outs, global_feature = self.full_band_encoder(complex_spectrum) 144 | sub_band_encode_outs, local_feature = self.sub_band_encoder(amplitude_spectrum) 145 | 146 | merge_feature = torch.cat(tensors=[global_feature, local_feature], dim=2) # feature cat 147 | merge_feature = self.feature_merge_layer(merge_feature) 148 | # (batch*frames, channels, frequency) -> (batch*frames, channels//2, frequency//2) 149 | _, channels, frequency = merge_feature.shape 150 | merge_feature = torch.reshape(merge_feature, shape=(batch, frames, channels, frequency)) 151 | merge_feature = torch.permute(merge_feature, dims=(0, 3, 1, 2)).contiguous() 152 | # (batch, frequency, frames, channels) 153 | out_hidden_state = list() 154 | for idx, rnn_layer in enumerate(self.dual_path_extension_rnn_list): 155 | merge_feature, state = rnn_layer(merge_feature, hidden_state[idx]) 156 | out_hidden_state.append(state) 157 | 158 | merge_feature = torch.permute(merge_feature, dims=(0, 2, 3, 1)).contiguous() 159 | merge_feature = torch.reshape(merge_feature, shape=(batch * frames, channels, frequency)) 160 | 161 | split_feature = self.feature_split_layer(merge_feature) 162 | first_dim, channels, frequency = split_feature.shape 163 | split_feature = torch.reshape(split_feature, shape=(first_dim, channels, -1, 2)) 164 | 165 | full_band_mask = self.full_band_decoder(split_feature[..., 0], full_band_encode_outs) 166 | sub_band_mask = self.sub_band_decoder(split_feature[..., 1], sub_band_encode_outs) 167 | 168 | full_band_mask = torch.reshape(full_band_mask, shape=(batch, frames, 2, -1)) 169 | sub_band_mask = torch.reshape(sub_band_mask, shape=(batch, frames, 1, -1)) 170 | 171 | # Zero padding in the DC signal part removes the DC component 172 | full_band_mask = self.mask_padding(full_band_mask) 173 | sub_band_mask = self.mask_padding(sub_band_mask) 174 | 175 | full_band_out = in_complex_spectrum * full_band_mask 176 | sub_band_out = in_amplitude_spectrum * sub_band_mask 177 | # outputs is (batch, frames, 2, frequency), complex style. 178 | 179 | return full_band_out + sub_band_out, out_hidden_state 180 | -------------------------------------------------------------------------------- /modules/en_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class FullBandEncoderBlock(nn.Module): 6 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int): 7 | super().__init__() 8 | self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, 9 | kernel_size=kernel_size, stride=stride, padding=padding) 10 | 11 | self.norm = nn.BatchNorm1d(num_features=out_channels) 12 | 13 | self.activate = nn.ELU() 14 | 15 | def forward(self, complex_spectrum: Tensor): 16 | """ 17 | :param complex_spectrum: (batch * frames, channels, frequency) 18 | :return: 19 | """ 20 | complex_spectrum = self.conv(complex_spectrum) 21 | complex_spectrum = self.norm(complex_spectrum) 22 | complex_spectrum = self.activate(complex_spectrum) 23 | 24 | return complex_spectrum 25 | 26 | 27 | class FullBandDecoderBlock(nn.Module): 28 | def __init__(self, in_channels: int, out_channels: int, kernel_size: int, stride: int, padding: int): 29 | super().__init__() 30 | self.conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // 2, 31 | kernel_size=1, stride=1, padding=0) 32 | self.convT = nn.ConvTranspose1d(in_channels // 2, out_channels, kernel_size=kernel_size, stride=stride, 33 | padding=padding) 34 | 35 | self.norm = nn.BatchNorm1d(num_features=out_channels) 36 | self.activate = nn.ELU() 37 | 38 | def forward(self, encode_complex_spectrum: Tensor, decode_complex_spectrum): 39 | """ 40 | :param decode_complex_spectrum: (batch * frames, channels1, frequency) 41 | :param encode_complex_spectrum: (batch * frames, channels2, frequency) 42 | :return: 43 | """ 44 | complex_spectrum = torch.cat([encode_complex_spectrum, decode_complex_spectrum], dim=1) 45 | complex_spectrum = self.conv(complex_spectrum) 46 | complex_spectrum = self.convT(complex_spectrum) 47 | complex_spectrum = self.norm(complex_spectrum) 48 | complex_spectrum = self.activate(complex_spectrum) 49 | 50 | return complex_spectrum 51 | 52 | 53 | class SubBandEncoderBlock(nn.Module): 54 | def __init__(self, start_frequency: int, 55 | end_frequency: int, 56 | in_channels: int, 57 | out_channels: int, 58 | kernel_size: int, 59 | stride: int, 60 | padding: int): 61 | super().__init__() 62 | self.start_frequency = start_frequency 63 | self.end_frequency = end_frequency 64 | 65 | self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 66 | stride=stride, padding=padding) 67 | self.activate = nn.ReLU() 68 | 69 | def forward(self, amplitude_spectrum: Tensor): 70 | """ 71 | :param amplitude_spectrum: (batch*frames, channels, frequency) 72 | :return: 73 | """ 74 | sub_spectrum = amplitude_spectrum[:, :, self.start_frequency:self.end_frequency] 75 | 76 | sub_spectrum = self.conv(sub_spectrum) # (batch*frames, out_channels, sub_bands) 77 | sub_spectrum = self.activate(sub_spectrum) 78 | 79 | return sub_spectrum 80 | 81 | 82 | class SubBandDecoderBlock(nn.Module): 83 | def __init__(self, in_features: int, out_features: int, start_idx: int, end_idx: int): 84 | super().__init__() 85 | self.start_idx = start_idx 86 | self.end_idx = end_idx 87 | self.fc = nn.Linear(in_features=in_features, out_features=out_features) 88 | self.activate = nn.ReLU() 89 | 90 | def forward(self, encode_amplitude_spectrum: Tensor, decode_amplitude_spectrum: Tensor): 91 | """ 92 | 93 | :param encode_amplitude_spectrum: (batch * frames, channels, sub_bands) 94 | :param decode_amplitude_spectrum: (batch * frames, channels, sub_bands) 95 | :return: 96 | """ 97 | encode_amplitude_spectrum = encode_amplitude_spectrum[:, :, self.start_idx: self.end_idx] 98 | spectrum = torch.cat([encode_amplitude_spectrum, decode_amplitude_spectrum], dim=1) # channels cat 99 | spectrum = torch.transpose(spectrum, dim0=1, dim1=2).contiguous() # (*, bands, channels) 100 | 101 | spectrum = self.fc(spectrum) # (*, bands, band-width) 102 | spectrum = self.activate(spectrum) 103 | first_dim, bands, band_width = spectrum.shape 104 | spectrum = torch.reshape(spectrum, shape=(first_dim, bands*band_width)) 105 | 106 | return spectrum 107 | -------------------------------------------------------------------------------- /modules/sequence_modules.py: -------------------------------------------------------------------------------- 1 | # !/user/bin/env python 2 | # -*-coding:utf-8 -*- 3 | 4 | """ 5 | # File : sequence_modules.py 6 | # Time : 2024/4/10 上午9:35 7 | # Author : wukeyi 8 | # version : python3.9 9 | """ 10 | from typing import List 11 | 12 | import torch 13 | from torch import nn, Tensor 14 | 15 | 16 | class GroupRNN(nn.Module): 17 | def __init__(self, input_size: int, 18 | hidden_size: int, 19 | groups: int, 20 | rnn_type: str, 21 | num_layers: int = 1, 22 | bidirectional: bool = False, 23 | batch_first: bool = True): 24 | super().__init__() 25 | assert input_size % groups == 0, \ 26 | f"input_size % groups must be equal to 0, but got {input_size} % {groups} = {input_size % groups}" 27 | 28 | self.groups = groups 29 | self.rnn_list = nn.ModuleList() 30 | for _ in range(groups): 31 | self.rnn_list.append( 32 | getattr(nn, rnn_type)(input_size=input_size // groups, hidden_size=hidden_size//groups, 33 | num_layers=num_layers, 34 | bidirectional=bidirectional, batch_first=batch_first) 35 | ) 36 | 37 | def forward(self, inputs: Tensor, hidden_state: List[Tensor]): 38 | """ 39 | :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups 40 | state shape = (num_layers*bidirectional, batch*[], hidden_size) if rnn_type is GRU or RNN, otherwise, 41 | state = (h0, c0), h0/c0 shape = (num_layers*bidirectional, batch*[], hidden_size). 42 | :param inputs: (batch, steps, input_size) 43 | :return: 44 | """ 45 | outputs = [] 46 | out_states = [] 47 | batch, steps, _ = inputs.shape 48 | 49 | inputs = torch.reshape(inputs, shape=(batch, steps, self.groups, -1)) # (batch, steps, groups, width) 50 | for idx, rnn in enumerate(self.rnn_list): 51 | out, state = rnn(inputs[:, :, idx, :], hidden_state[idx]) 52 | outputs.append(out) # (batch, steps, hidden_size) 53 | out_states.append(state) # (num_layers*bidirectional, batch*[], hidden_size) 54 | 55 | outputs = torch.cat(outputs, dim=2) # (batch, steps, hidden_size * groups) 56 | 57 | return outputs, out_states 58 | 59 | 60 | class DualPathExtensionRNN(nn.Module): 61 | def __init__(self, input_size: int, 62 | intra_hidden_size: int, 63 | inter_hidden_size: int, 64 | groups: int, 65 | rnn_type: str): 66 | super().__init__() 67 | assert rnn_type in ["RNN", "GRU", "LSTM"], f"rnn_type should be RNN/GRU/LSTM, but got {rnn_type}!" 68 | 69 | self.intra_chunk_rnn = getattr(nn, rnn_type)(input_size=input_size, hidden_size=intra_hidden_size, 70 | num_layers=1, bidirectional=True, batch_first=True) 71 | self.intra_chunk_fc = nn.Linear(in_features=intra_hidden_size*2, out_features=input_size) 72 | self.intra_chunk_norm = nn.LayerNorm(normalized_shape=input_size, elementwise_affine=True) 73 | 74 | self.inter_chunk_rnn = GroupRNN(input_size=input_size, hidden_size=inter_hidden_size, groups=groups, 75 | rnn_type=rnn_type) 76 | self.inter_chunk_fc = nn.Linear(in_features=inter_hidden_size, out_features=input_size) 77 | 78 | def forward(self, inputs: Tensor, hidden_state: List[Tensor]): 79 | """ 80 | :param hidden_state: List[state1, state2, ...], len(hidden_state) = groups 81 | state shape = (num_layers*bidirectional, batch*[], hidden_size) if rnn_type is GRU or RNN, otherwise, 82 | state = (h0, c0), h0/c0 shape = (num_layers*bidirectional, batch*[], hidden_size). 83 | :param inputs: (B, F, T, N) 84 | :return: 85 | """ 86 | B, F, T, N = inputs.shape 87 | intra_out = torch.transpose(inputs, dim0=1, dim1=2).contiguous() # (B, T, F, N) 88 | intra_out = torch.reshape(intra_out, shape=(B * T, F, N)) 89 | intra_out, _ = self.intra_chunk_rnn(intra_out) 90 | intra_out = self.intra_chunk_fc(intra_out) # (B, T, F, N) 91 | intra_out = torch.reshape(intra_out, shape=(B, T, F, N)) 92 | intra_out = torch.transpose(intra_out, dim0=1, dim1=2).contiguous() # (B, F, T, N) 93 | intra_out = self.intra_chunk_norm(intra_out) # (B, F, T, N) 94 | 95 | intra_out = inputs + intra_out # residual add 96 | 97 | inter_out = torch.reshape(intra_out, shape=(B * F, T, N)) # (B*F, T, N) 98 | inter_out, hidden_state = self.inter_chunk_rnn(inter_out, hidden_state) 99 | inter_out = torch.reshape(inter_out, shape=(B, F, T, -1)) # (B, F, T, groups * N) 100 | inter_out = self.inter_chunk_fc(inter_out) # (B, F, T, N) 101 | 102 | inter_out = inter_out + intra_out # residual add 103 | 104 | return inter_out, hidden_state 105 | 106 | 107 | if __name__ == "__main__": 108 | test_model = DualPathExtensionRNN(input_size=32, intra_hidden_size=16, inter_hidden_size=16, 109 | groups=8, rnn_type="LSTM") 110 | test_data = torch.randn(5, 32, 10, 32) 111 | test_state = [(torch.randn(1, 5*32, 16), torch.randn(1, 5*32, 16)) for _ in range(8)] 112 | test_out = test_model(test_data, test_state) 113 | -------------------------------------------------------------------------------- /run_train.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | from thop import profile, clever_format 5 | 6 | from configs.train_configs import TrainConfig 7 | from models.fspen import FullSubPathExtension 8 | 9 | 10 | if __name__ == "__main__": 11 | configs = TrainConfig() 12 | with open("config.json", mode="w", encoding="utf-8") as file: 13 | json.dump(configs.__dict__, file, indent=4) 14 | 15 | model = FullSubPathExtension(configs) 16 | 17 | batch = 1 18 | groups = configs.dual_path_extension["parameters"]["groups"] 19 | inter_hidden_size = configs.dual_path_extension["parameters"]["inter_hidden_size"] 20 | num_modules = configs.dual_path_extension["num_modules"] 21 | num_bands = sum(configs.bands_num_in_groups) 22 | 23 | in_wav = torch.randn(1, configs.train_points) 24 | complex_spectrum = torch.stft(in_wav, n_fft=configs.n_fft, hop_length=configs.hop_length, 25 | window=torch.hamming_window(configs.n_fft), return_complex=True) # (B, F, T) 26 | amplitude_spectrum = torch.abs(complex_spectrum) 27 | 28 | complex_spectrum = torch.view_as_real(complex_spectrum) # (B, F, T, 2) 29 | complex_spectrum = torch.permute(complex_spectrum, dims=(0, 2, 3, 1)) 30 | _, frames, channels, frequency = complex_spectrum.shape 31 | complex_spectrum = torch.reshape(complex_spectrum, shape=(batch, frames, channels, frequency)) 32 | amplitude_spectrum = torch.permute(amplitude_spectrum, dims=(0, 2, 1)) 33 | amplitude_spectrum = torch.reshape(amplitude_spectrum, shape=(batch, frames, 1, frequency)) 34 | # 35 | in_hidden_state = [[torch.zeros(1, batch * num_bands, inter_hidden_size//groups) for _ in range(groups)] 36 | for _ in range(num_modules)] 37 | 38 | flops, params = profile(model, inputs=(complex_spectrum, amplitude_spectrum, in_hidden_state)) 39 | flops, params = clever_format(nums=[flops, params], format="%0.4f") 40 | print(f"flops: {flops} \nparams: {params}") 41 | --------------------------------------------------------------------------------