├── README.md
├── TasNet.py
├── complexnn.py
├── conv_stft.py
├── dataload.py
├── dataload_vad.py
├── dc_crn.py
├── dc_crn_test_attention.py
├── dc_crn_test_avg.py
├── eval.py
├── eval_avg.py
├── main.py
├── main_test.py
├── main_test_avg.py
├── model.py
├── network_selfattention
├── resnet.py
├── result
├── run.sh
├── test.py
├── test.sh
├── test_attention_visual.py
├── test_avg.py
└── train.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# CSENet
Csenet: Complex Squeeze-and-Excitation Network for Speech Depression Level Prediction (ICASSP 2022)

## Introduction
"Csenet: Complex Squeeze-and-Excitation Network for Speech Depression Level Prediction"
https://ieeexplore.ieee.org/document/9746011

## Citation
If you use this code for your research, please consider citing:

C. Fan, Z. Lv, S. Pei and M. Niu, "Csenet: Complex Squeeze-and-Excitation Network for Speech Depression Level Prediction," ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 2022, pp. 546-550, doi: 10.1109/ICASSP43922.2022.9746011.

## Run
1. Data

   First, segment the recordings so that each wave file is about 3 s long (see the sketch at the end of this README).

2. Train

   run.sh

3. Test

   test.sh

![image](https://user-images.githubusercontent.com/26382648/175199425-f03bbec7-b5e0-4c2b-988f-e79bda8bdb3e.png)
![image](https://user-images.githubusercontent.com/26382648/175199462-c446887d-3833-4f47-a1ed-93dc4a560efd.png)
![image](https://user-images.githubusercontent.com/26382648/175199490-cce70ed2-6217-4521-a986-3c85a84f2adc.png)

The authors would like to thank Cai Cong for kindly providing the data-processing programs.
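A minimal sketch of the 3 s segmentation step, assuming 16 kHz mono wav input cut into non-overlapping clips with soundfile; the helper name and paths are illustrative, not the exact preprocessing script used for the paper.

```python
import os
import soundfile as sf

def segment_wav(in_path, out_dir, clip_seconds=3):
    """Cut one wav file into consecutive non-overlapping clips of clip_seconds each."""
    audio, sr = sf.read(in_path)
    clip_len = int(clip_seconds * sr)
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.splitext(os.path.basename(in_path))[0]
    for i in range(len(audio) // clip_len):
        clip = audio[i * clip_len:(i + 1) * clip_len]
        sf.write(os.path.join(out_dir, "%s_%05d.wav" % (base, i)), clip, sr)
```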

--------------------------------------------------------------------------------
/TasNet.py:
--------------------------------------------------------------------------------

# wujian@2018
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
# import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pickle
from torchsummary import summary
from dc_crn import DCCRN

def param(nnet, Mb=True):
    """
    Return the number of parameters (not bytes) in nnet,
    in millions if Mb is True
    """
    neles = sum([param.nelement() for param in nnet.parameters()])
    return neles / 10**6 if Mb else neles

def foo_conv1d_block():
    nnet = Conv1DBlock(256, 512, 3, 20)
    print(param(nnet))

def foo_layernorm():
    C, T = 256, 20
    nnet1 = nn.LayerNorm([C, T], elementwise_affine=True)
    print(param(nnet1, Mb=False))
    nnet2 = nn.LayerNorm([C, T], elementwise_affine=False)
    print(param(nnet2, Mb=False))

def foo_conv_tas_net():
    x = th.rand(4, 1000)
    nnet = ConvTasNet(norm="cLN", causal=False)
    # print(nnet)
    print("ConvTasNet #param: {:.2f}".format(param(nnet)))
    x = nnet(x)
    s1 = x[0]
    print(s1.shape)


class ChannelWiseLayerNorm(nn.LayerNorm):
    """
    Channel wise layer normalization
    """

    def __init__(self, *args, **kwargs):
        super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: N x C x T
        """
        if x.dim() != 3:
            raise RuntimeError("{} accepts 3D tensor as input".format(
                type(self).__name__))
        # N x C x T => N x T x C
        x = th.transpose(x, 1, 2)
        # LN
        x = super().forward(x)
        # N x T x C => N x C x T
        x = th.transpose(x, 1, 2)
        return x

class GlobalChannelLayerNorm(nn.Module):
    """
    Global channel layer normalization
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalChannelLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        """
        x: N x C x T
        """
        if x.dim() != 3:
            raise RuntimeError("{} accepts 3D tensor as input".format(
                type(self).__name__))
        # N x 1 x 1
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        # N x C x T
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)

class Conv_regression_up(nn.Module):
    def __init__(self, channel, num, len):
        super(Conv_regression_up, self).__init__()
        self.conv1 = nn.Conv2d(channel, 32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.avgpool1 = nn.AdaptiveAvgPool2d((int(num/2), int(len/2)))
        self.avgpool2 = nn.AdaptiveAvgPool2d((int(num/4), int(len/4)))
        self.flatten = nn.Flatten(2)
        self.linear1 = nn.Linear(int(num/4)*int(len/4), 1)
        self.linear2 = nn.Linear(64, 1)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = x.float()
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.avgpool1(y)
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.avgpool2(y)
        y = self.flatten(y)
        y = self.linear1(y)
        y = torch.squeeze(y)
        y = self.linear2(y)
        return y

class Conv_regression(nn.Module):
    def __init__(self, channel, length):  # e.g. (64, 7999)
        super(Conv_regression, self).__init__()
        self.conv1 = nn.Conv2d(channel, 128, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        self.avgpool1 = nn.AdaptiveAvgPool2d(int(length/4))
        self.avgpool2 = nn.AdaptiveAvgPool2d(int(length/16))
        self.linear1 = nn.Linear(int(length/16), 1)
        self.linear2 = nn.Linear(256, 1)

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.avgpool1(y)
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.avgpool2(y)
        y = self.linear1(self.relu(y))
        y = torch.squeeze(y)
        y = self.linear2(self.relu(y))
        return y

def build_norm(norm, dim):
    """
    Build normalize layer
    LN costs more memory than BN
    """
    if norm not in ["cLN", "gLN", "BN"]:
        raise RuntimeError("Unsupported normalize layer: {}".format(norm))
    if norm == "cLN":
        return ChannelWiseLayerNorm(dim, elementwise_affine=True)
    elif norm == "BN":
        return nn.BatchNorm1d(dim)
    else:
        return GlobalChannelLayerNorm(dim, elementwise_affine=True)

class Conv1D(nn.Conv1d):

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} accepts 2/3D tensor as input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            x = th.squeeze(x)
        return x

class Conv1DBlock(nn.Module):

    def __init__(self, in_channels=256, conv_channels=512, kernel_size=3, dilation=1, norm="cLN", causal=False):
        super(Conv1DBlock, self).__init__()
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.ReLU()
        self.lnorm1 = build_norm(norm, conv_channels)
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(conv_channels, conv_channels, kernel_size, padding=dconv_pad, dilation=dilation, bias=True)
        self.prelu2 = nn.ReLU()
        self.lnorm2 = build_norm(norm, conv_channels)
        self.lnorm3 = build_norm(norm, in_channels)
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x):
        y = self.conv1x1(x)
        y = self.prelu1(self.lnorm1(y))
        y = self.dconv(y)
        y = self.prelu2(self.lnorm2(y))
        y = self.sconv(y)
        # y = self.prelu2(self.lnorm3(y))
        x = x + y
        return x


class eca_layer(nn.Module):
    """Constructs an ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # Two different branches of ECA module
        y = self.conv(y.transpose(-1, -2)).transpose(-1, -2)
        # Multi-scale information fusion
        y = self.sigmoid(y)
        return x * y.expand_as(x)

class ConvTasNet(nn.Module):
    def __init__(self,
                 L=10,   # initial conv kernel size
                 N=64,   # initial conv channels
                 X=8,
                 R=4,
                 B=64,   # 1x1-conv channels
                 H=128,  # block conv channels
                 P=3,    # block kernel size
                 norm="BN",
                 causal=False):
        super(ConvTasNet, self).__init__()
        self.encoder_1d = DCCRN(rnn_units=256, use_clstm=True, kernel_num=[32, 64, 128, 256, 256, 256])
        # self.decoder_1d = ConvTrans1D(N, 1, kernel_size=L, stride=L // 2, bias=True)
        # self.conv_regression_up = Conv_regression_up(1, 64, 7999)
        self.conv_regression = Conv_regression(64, 7999)
        self.conv_regression3s = Conv_regression(64, 4799)
        self.avgpooling = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten()
        self.eca = eca_layer()
        self.linear = nn.Linear(128, 1)
        # self.ln / self.proj / self.repeats are used in forward(); defined here
        # following the standard Conv-TasNet layout (the exact sizes are an assumption)
        self.ln = ChannelWiseLayerNorm(N)
        self.proj = Conv1D(N, B, 1)
        self.repeats = self._build_repeats(R, X, in_channels=B, conv_channels=H, kernel_size=P, norm=norm, causal=causal)

    def _build_blocks(self, num_blocks, **block_kwargs):
        blocks = [Conv1DBlock(**block_kwargs, dilation=(2**b))
                  for b in range(num_blocks)]
        return nn.Sequential(*blocks)

    def _build_repeats(self, num_repeats, num_blocks, **block_kwargs):
        repeats = [self._build_blocks(num_blocks, **block_kwargs)
                   for r in range(num_repeats)]
        return nn.Sequential(*repeats)

    def forward(self, x):
        print('input: ', x.size())
        w = F.relu(self.encoder_1d(x))
        print('encoder dccrn: ', w.size())
        y = self.proj(self.ln(w))
        y = self.repeats(y)
        y = self.eca(y)
        y = self.conv_regression3s(y)
        return y

# if __name__ == "__main__":
#     a = torch.tensor([[[1, 2, 3]]])
#     # create data3.pkl for writing
#     data_output = open('data3.pkl', 'wb')
#     # dump a into the pickle file
#     pickle.dump(a, data_output)
#     # close the file
#     data_output.close()
#     pathname = "data3.pkl"
#     fp = open(pathname, "rb")
#     x = pickle.load(fp)  # x is the loaded matrix
#     sns.set()
#     ax = sns.heatmap(x, cmap="rainbow")  # cmap sets the heatmap colormap
#     plt.show()

--------------------------------------------------------------------------------
/complexnn.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def get_casual_padding1d():
    pass

def get_casual_padding2d():
    pass

class cPReLU(nn.Module):

    def __init__(self, complex_axis=1):
        super(cPReLU, self).__init__()
        self.r_prelu = nn.PReLU()
        self.i_prelu = nn.PReLU()
        self.complex_axis = complex_axis

    def forward(self, inputs):
        real, imag = torch.chunk(inputs, 2, self.complex_axis)
        real = self.r_prelu(real)
        imag = self.i_prelu(imag)
        return torch.cat([real, imag], self.complex_axis)

class NavieComplexLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, projection_dim=None, bidirectional=False, batch_first=False):
        super(NavieComplexLSTM, self).__init__()

        self.input_dim = input_size // 2
        self.rnn_units = hidden_size // 2
        self.real_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, batch_first=False)
        self.imag_lstm = nn.LSTM(self.input_dim, self.rnn_units, num_layers=1, bidirectional=bidirectional, batch_first=False)
        if bidirectional:
            bidirectional = 2
        else:
            bidirectional = 1
        if projection_dim is not None:
            self.projection_dim = projection_dim // 2
            self.r_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
            self.i_trans = nn.Linear(self.rnn_units * bidirectional, self.projection_dim)
        else:
            self.projection_dim = None

    def forward(self, inputs):
        if isinstance(inputs, list):
            real, imag = inputs
        elif isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, -1)
        r2r_out = self.real_lstm(real)[0]
        self.real_lstm.flatten_parameters()
        r2i_out = self.imag_lstm(real)[0]
        self.imag_lstm.flatten_parameters()
        i2r_out = self.real_lstm(imag)[0]
        self.real_lstm.flatten_parameters()
        i2i_out = self.imag_lstm(imag)[0]
        self.imag_lstm.flatten_parameters()

        # complex multiplication rule: (a+bi)(c+di) -> real = ac-bd, imag = ad+bc
        real_out = r2r_out - i2i_out
        imag_out = i2r_out + r2i_out
        if self.projection_dim is not None:
            real_out = self.r_trans(real_out)
            imag_out = self.i_trans(imag_out)
        # print(real_out.shape, imag_out.shape)
        return [real_out, imag_out]

    def flatten_parameters(self):
        self.imag_lstm.flatten_parameters()
        self.real_lstm.flatten_parameters()

def complex_cat(inputs, axis):

    real, imag = [], []
    for idx, data in enumerate(inputs):
        r, i = torch.chunk(data, 2, axis)
        real.append(r)
        imag.append(i)
    real = torch.cat(real, axis)
    imag = torch.cat(imag, axis)
    outputs = torch.cat([real, imag], axis)
    return outputs

class ComplexConv2d(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        dilation=1,
        groups=1,
        causal=True,
        complex_axis=1,
    ):
        '''
        in_channels: real+imag
        out_channels: real+imag
        kernel_size : input [B,C,D,T] kernel size in [D,T]
        padding : input [B,C,D,T] padding in [D,T]
        causal: if True, pad only the left side of the time dimension,
                otherwise pad both sides
        '''
        super(ComplexConv2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.causal = causal
        self.groups = groups
        self.dilation = dilation
        self.complex_axis = complex_axis
        self.real_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride, padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)
        self.imag_conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size, self.stride, padding=[self.padding[0], 0], dilation=self.dilation, groups=self.groups)

        nn.init.normal_(self.real_conv.weight.data, std=0.05)
        nn.init.normal_(self.imag_conv.weight.data, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):
        if self.padding[1] != 0 and self.causal:
            inputs = F.pad(inputs, [self.padding[1], 0, 0, 0])
        else:
            inputs = F.pad(inputs, [self.padding[1], self.padding[1], 0, 0])

        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)

        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)

            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)

            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)

        return out


class ComplexConvTranspose2d(nn.Module):

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=(1, 1),
        stride=(1, 1),
        padding=(0, 0),
        output_padding=(0, 0),
        causal=False,
        complex_axis=1,
        groups=1
    ):
        '''
        in_channels: real+imag
        out_channels: real+imag
        '''
        super(ComplexConvTranspose2d, self).__init__()
        self.in_channels = in_channels // 2
        self.out_channels = out_channels // 2
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.output_padding = output_padding
        self.groups = groups

        self.real_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.imag_conv = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size, self.stride, padding=self.padding, output_padding=output_padding, groups=self.groups)
        self.complex_axis = complex_axis

        nn.init.normal_(self.real_conv.weight, std=0.05)
        nn.init.normal_(self.imag_conv.weight, std=0.05)
        nn.init.constant_(self.real_conv.bias, 0.)
        nn.init.constant_(self.imag_conv.bias, 0.)

    def forward(self, inputs):

        if isinstance(inputs, torch.Tensor):
            real, imag = torch.chunk(inputs, 2, self.complex_axis)
        elif isinstance(inputs, tuple) or isinstance(inputs, list):
            real = inputs[0]
            imag = inputs[1]
        if self.complex_axis == 0:
            real = self.real_conv(inputs)
            imag = self.imag_conv(inputs)
            real2real, imag2real = torch.chunk(real, 2, self.complex_axis)
            real2imag, imag2imag = torch.chunk(imag, 2, self.complex_axis)

        else:
            if isinstance(inputs, torch.Tensor):
                real, imag = torch.chunk(inputs, 2, self.complex_axis)

            real2real = self.real_conv(real)
            imag2imag = self.imag_conv(imag)

            real2imag = self.imag_conv(real)
            imag2real = self.real_conv(imag)

        real = real2real - imag2imag
        imag = real2imag + imag2real
        out = torch.cat([real, imag], self.complex_axis)

        return out



# Source: https://github.com/ChihebTrabelsi/deep_complex_networks/tree/pytorch
# from https://github.com/IMLHF/SE_DCUNet/blob/f28bf1661121c8901ad38149ea827693f1830715/models/layers/complexnn.py#L55

class ComplexBatchNorm(torch.nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, complex_axis=1):
        super(ComplexBatchNorm, self).__init__()
        self.num_features = num_features // 2
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        self.complex_axis = complex_axis

        if self.affine:
            self.Wrr = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wri = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Wii = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Br = torch.nn.Parameter(torch.Tensor(self.num_features))
            self.Bi = torch.nn.Parameter(torch.Tensor(self.num_features))
        else:
            self.register_parameter('Wrr', None)
            self.register_parameter('Wri', None)
            self.register_parameter('Wii', None)
            self.register_parameter('Br', None)
            self.register_parameter('Bi', None)

        if self.track_running_stats:
            self.register_buffer('RMr', torch.zeros(self.num_features))
            self.register_buffer('RMi', torch.zeros(self.num_features))
            self.register_buffer('RVrr', torch.ones(self.num_features))
            self.register_buffer('RVri', torch.zeros(self.num_features))
            self.register_buffer('RVii', torch.ones(self.num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('RMr', None)
            self.register_parameter('RMi', None)
            self.register_parameter('RVrr', None)
            self.register_parameter('RVri', None)
            self.register_parameter('RVii', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.RMr.zero_()
            self.RMi.zero_()
            self.RVrr.fill_(1)
            self.RVri.zero_()
            self.RVii.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            self.Br.data.zero_()
            self.Bi.data.zero_()
            self.Wrr.data.fill_(1)
            self.Wri.data.uniform_(-.9, +.9)  # W will be positive-definite
            self.Wii.data.fill_(1)

    def _check_input_dim(self, xr, xi):
        assert(xr.shape == xi.shape)
        assert(xr.size(1) == self.num_features)

    def forward(self, inputs):
        # self._check_input_dim(xr, xi)

        xr, xi = torch.chunk(inputs, 2, self.complex_axis)
        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            self.num_batches_tracked += 1
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        #
        # NOTE: The precise meaning of the "training flag" is:
        #       True:  Normalize using batch statistics, update running statistics
        #              if they are being collected.
        #       False: Normalize using running statistics, ignore batch statistics.
        #
        training = self.training or not self.track_running_stats
        redux = [i for i in reversed(range(xr.dim())) if i != 1]
        vdim = [1] * xr.dim()
        vdim[1] = xr.size(1)

        #
        # Mean M Computation and Centering
        #
        # Includes running mean update if training and running.
        #
        if training:
            Mr, Mi = xr, xi
            for d in redux:
                Mr = Mr.mean(d, keepdim=True)
                Mi = Mi.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RMr.lerp_(Mr.squeeze(), exponential_average_factor)
                self.RMi.lerp_(Mi.squeeze(), exponential_average_factor)
        else:
            Mr = self.RMr.view(vdim)
            Mi = self.RMi.view(vdim)
        xr, xi = xr - Mr, xi - Mi

        #
        # Variance Matrix V Computation
        #
        # Includes epsilon numerical stabilizer/Tikhonov regularizer.
        # Includes running variance update if training and running.
        #
        if training:
            Vrr = xr * xr
            Vri = xr * xi
            Vii = xi * xi
            for d in redux:
                Vrr = Vrr.mean(d, keepdim=True)
                Vri = Vri.mean(d, keepdim=True)
                Vii = Vii.mean(d, keepdim=True)
            if self.track_running_stats:
                self.RVrr.lerp_(Vrr.squeeze(), exponential_average_factor)
                self.RVri.lerp_(Vri.squeeze(), exponential_average_factor)
                self.RVii.lerp_(Vii.squeeze(), exponential_average_factor)
        else:
            Vrr = self.RVrr.view(vdim)
            Vri = self.RVri.view(vdim)
            Vii = self.RVii.view(vdim)
        Vrr = Vrr + self.eps
        Vri = Vri
        Vii = Vii + self.eps

        #
        # Matrix Inverse Square Root U = V^-0.5
        #
        # sqrt of a 2x2 matrix,
        # - https://en.wikipedia.org/wiki/Square_root_of_a_2_by_2_matrix
        tau = Vrr + Vii
        delta = torch.addcmul(Vrr * Vii, Vri, Vri, value=-1)  # det(V) = Vrr*Vii - Vri^2
        s = delta.sqrt()
        t = (tau + 2 * s).sqrt()

        # matrix inverse, http://mathworld.wolfram.com/MatrixInverse.html
        rst = (s * t).reciprocal()
        Urr = (s + Vii) * rst
        Uii = (s + Vrr) * rst
        Uri = (- Vri) * rst

        #
        # Optionally left-multiply U by affine weights W to produce combined
        # weights Z, left-multiply the inputs by Z, then optionally bias them.
        #
        # y = Zx + B
        # y = WUx + B
        # y = [Wrr Wri][Urr Uri] [xr] + [Br]
        #     [Wir Wii][Uir Uii] [xi]   [Bi]
        #
        if self.affine:
            Wrr, Wri, Wii = self.Wrr.view(vdim), self.Wri.view(vdim), self.Wii.view(vdim)
            Zrr = (Wrr * Urr) + (Wri * Uri)
            Zri = (Wrr * Uri) + (Wri * Uii)
            Zir = (Wri * Urr) + (Wii * Uri)
            Zii = (Wri * Uri) + (Wii * Uii)
        else:
            Zrr, Zri, Zir, Zii = Urr, Uri, Uri, Uii

        yr = (Zrr * xr) + (Zri * xi)
        yi = (Zir * xr) + (Zii * xi)

        if self.affine:
            yr = yr + self.Br.view(vdim)
            yi = yi + self.Bi.view(vdim)

        outputs = torch.cat([yr, yi], self.complex_axis)
        return outputs

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

if __name__ == '__main__':
    import dc_crn7
    torch.manual_seed(20)
    onet1 = dc_crn7.ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    onet2 = dc_crn7.ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    inputs = torch.randn([1, 12, 12, 10])
    # print(onet1.real_kernel[0,0,0,0])
    nnet1 = ComplexConv2d(12, 12, kernel_size=(3, 2), padding=(2, 1), causal=True)
    # print(nnet1.real_conv.weight[0,0,0,0])
    nnet2 = ComplexConvTranspose2d(12, 12, kernel_size=(3, 2), padding=(2, 1))
    print(torch.mean(nnet1(inputs) - onet1(inputs)))

--------------------------------------------------------------------------------
/conv_stft.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)  # **0.5

    N = fft_len
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConvSTFT, self).__init__()

        if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)
        inputs = F.pad(inputs, [self.win_len - self.stride, self.win_len - self.stride])
        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase

class ConviSTFT(nn.Module):

    def __init__(self, win_len, win_inc, fft_len=None, win_type='hamming', feature_type='real', fix=True):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(win_len, win_inc, self.fft_len, win_type, invers=True)
        # self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.register_buffer('weight', kernel)
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
        phase: [B, N//2+1, T] (if not none)
        """

        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        # outputs = torch.where(coff == 0, outputs, outputs/coff)
        outputs = outputs[..., self.win_len - self.stride:-(self.win_len - self.stride)]

        return outputs

def test_fft():
    torch.manual_seed(20)
    win_len = 320
    win_inc = 160
    fft_len = 512
    inputs = torch.randn([1, 1, 16000 * 4])
    fft = ConvSTFT(win_len, win_inc, fft_len, win_type='hanning', feature_type='real')
    import librosa

    outputs1 = fft(inputs)[0]
    outputs1 = outputs1.numpy()[0]
    np_inputs = inputs.numpy().reshape([-1])
    librosa_stft = librosa.stft(np_inputs, win_length=win_len, n_fft=fft_len, hop_length=win_inc, center=False)
    print(np.mean((outputs1 - np.abs(librosa_stft))**2))


def test_ifft1():
    import soundfile as sf
    N = 400
    inc = 100
    fft_len = 512
    torch.manual_seed(N)
    data = np.random.randn(16000 * 8)[None, None, :]
    # data = sf.read('../ori.wav')[0]
    inputs = data.reshape([1, 1, -1])
    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    inputs = torch.from_numpy(inputs.astype(np.float32))
    outputs1 = fft(inputs)
    print(outputs1.shape)
    outputs2 = ifft(outputs1)
    sf.write('conv_stft.wav', outputs2.numpy()[0, 0, :], 16000)
    print('wav MSE', torch.mean(torch.abs(inputs[..., :outputs2.size(2)] - outputs2)**2))


def test_ifft2():
    N = 400
    inc = 100
    fft_len = 512
    np.random.seed(20)
    torch.manual_seed(20)
    t = np.random.randn(16000 * 4) * 0.001
    t = np.clip(t, -1, 1)
    # input = torch.randn([1,16000*4])
    input = torch.from_numpy(t[None, None, :].astype(np.float32))

    fft = ConvSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')
    ifft = ConviSTFT(N, inc, fft_len=fft_len, win_type='hanning', feature_type='complex')

    out1 = fft(input)
    output = ifft(out1)
    print('random MSE', torch.mean(torch.abs(input - output)**2))
    import soundfile as sf
    sf.write('zero.wav', output[0, 0].numpy(), 16000)


if __name__ == '__main__':
    # test_fft()
    test_ifft1()
    # test_ifft2()

--------------------------------------------------------------------------------
/dataload.py:
--------------------------------------------------------------------------------

# -*- coding: utf-8 -*-
import os
import random
import wave
import numpy as np
import librosa
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from scipy.io import loadmat
from scipy import signal

class Depression_3dmel_random_train(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for j in range(160):
            list_shuffle = []
            labels = os.listdir(root_dir)
            for i in range(len(labels)):
                list = []
                for root, dirs, files in os.walk(root_dir + labels[i]):
                    for file in files:
                        list.append(os.path.join(root, file))
                random.shuffle(list)
                list = list[:10]
                list_shuffle += list
            random.shuffle(list_shuffle)
            self.data_list += list_shuffle
        for k in range(len(self.data_list)):
            self.label_list.append((self.data_list[k].split('/'))[3])

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        waveData = loadmat(self.data_list[index])['value']
        waveData = torch.from_numpy(waveData)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (data, label) tuple
        return sample

class Depression_3dmel_order_test(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".mat"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        waveData = loadmat(self.data_list[index])['value']
        name = self.data_list[index].split('/')[-2]
        waveData = torch.from_numpy(waveData)
        label = int(self.label_list[index])
        sample = (waveData, label, str(name))  # build the (data, label, name) tuple
        return sample

class Depression_3dmel_order_dev(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".mat"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        waveData = loadmat(self.data_list[index])['value']
        # name = self.data_list[index].split('/')[-2]
        waveData = torch.from_numpy(waveData)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (data, label) tuple
        return sample

def train_data_loader_3d(root='./audio_mat_3dmel_3s/train/', batch_size=64, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    traindata = Depression_3dmel_random_train(root, transform=transform)  # instantiate the dataset with its root path and transform
    trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=shuffle, num_workers=4)  # load the data with a DataLoader
    return trainloader

def test_data_loader_3d(root='./audio_mat_3dmel_3s/test/', batch_size=1, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_3dmel_order_test(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader

def val_data_loader_3d(root='./audio_mat_3dmel_3s/dev/', batch_size=1, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_3dmel_order_dev(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader


class Depression_wav_order_train14(Dataset):
    def __init__(self, root_dir, root_dir2, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.root_dir2 = root_dir2
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00
        for root, dirs, files in os.walk(root_dir2):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir2 + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_order_test14(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in sorted(dirs):
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        name = self.data_list[index].split('/')[-2]
        sample = (waveData, label, str(name))  # build the (waveform, label, name) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_random_train(Dataset):
    def __init__(self, root_dir, transform=None, flag=3):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        self.data_list = []
        self.label_list = []
        for j in range(160):
            list_shuffle = []
            labels = os.listdir(root_dir)
            for i in range(len(labels)):
                list = []
                for root, dirs, files in os.walk(os.path.join(root_dir, labels[i])):
                    for file in files:
                        if file.endswith(".wav"):
                            list.append(os.path.join(root, file))
                random.shuffle(list)
                list = list[:10]
                list_shuffle += list
            random.shuffle(list_shuffle)
            self.data_list += list_shuffle
        for k in range(len(self.data_list)):
            self.label_list.append((self.data_list[k].split('/'))[flag])

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_order_test(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in sorted(dirs):
                for root2, dirs2, files2 in os.walk(os.path.join(root_dir, dir)):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        name = self.data_list[index].split('/')[-2]
        sample = (waveData, label, str(name))  # build the (waveform, label, name) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_order_dev(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(os.path.join(root_dir, dir)):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

def train_data_loader(root, batch_size, shuffle, flag):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    traindata = Depression_wav_random_train(root, transform=transform, flag=flag)  # instantiate the dataset with its root path and transform
    # print('batch_size: ', batch_size)
    trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=shuffle, num_workers=4)  # load the data with a DataLoader
    return trainloader

def test_data_loader(root='./audio_wav_3s/test/', batch_size=1, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_wav_order_test(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader

def val_data_loader(root='./audio_wav_3s/dev/', batch_size=64, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_wav_order_dev(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader

def train_data_loader14(root, root2, batch_size, shuffle):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    traindata = Depression_wav_order_train14(root, root2, transform=transform)  # instantiate the dataset with its root paths and transform
    trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=shuffle, num_workers=4)  # load the data with a DataLoader
    return trainloader

--------------------------------------------------------------------------------
/dataload_vad.py:
--------------------------------------------------------------------------------

# -*- coding: utf-8 -*-
import os
import random
import wave
import numpy as np
import librosa
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from scipy.io import loadmat
from scipy import signal
import webrtcvad


class Depression_wav_order_train14(Dataset):
    def __init__(self, root_dir, root_dir2, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.root_dir2 = root_dir2
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00
        for root, dirs, files in os.walk(root_dir2):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(root_dir2 + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_order_test14(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in sorted(dirs):
                for root2, dirs2, files2 in os.walk(root_dir + dir):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        name = self.data_list[index].split('/')[-2]
        sample = (waveData, label, str(name))  # build the (waveform, label, name) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

class Depression_wav_random_train(Dataset):
    def __init__(self, root_dir, transform=None, flag=3):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        self.vad = webrtcvad.Vad(1)
        self.data_list = []
        self.label_list = []
        for j in range(160):
            list_shuffle = []
            labels = os.listdir(root_dir)
            for i in range(len(labels)):
                list = []
                for root, dirs, files in os.walk(os.path.join(root_dir, labels[i])):
                    for file in files:
                        if file.endswith(".wav"):
                            list.append(os.path.join(root, file))
                random.shuffle(list)
                list = list[:10]
                list_shuffle += list
            random.shuffle(list_shuffle)
            self.data_list += list_shuffle
        for k in range(len(self.data_list)):
            self.label_list.append((self.data_list[k].split('/'))[flag])

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16

        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample


class Depression_wav_order_test(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        # self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in sorted(dirs):
                for root2, dirs2, files2 in os.walk(os.path.join(root_dir, dir)):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            # self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.data_list[index].split('/')[-3])
        # label = int(self.label_list[index])
        name = self.data_list[index].split('/')[-2]
        sample = (waveData, label, str(name))
        return sample

        # if vad.is_speech(waveData, framerate):
        #
        # else:


class Depression_wav_order_dev(Dataset):
    def __init__(self, root_dir, transform=None):  # __init__ sets up the basic parameters of this dataset
        self.root_dir = root_dir  # data directory
        self.transform = transform  # transform
        # self.images = os.listdir(self.root_dir)  # all files in the directory
        self.data_list = []
        self.label_list = []
        for root, dirs, files in os.walk(root_dir):  # ./clip/train/
            for dir in dirs:
                for root2, dirs2, files2 in os.walk(os.path.join(root_dir, dir)):  # ./clip/train/00/
                    for file2 in files2:
                        if file2.endswith(".wav"):
                            self.data_list.append(os.path.join(root2, file2))  # ./clip/train/00/223_1/00235.jpg
                            self.label_list.append(dir)  # 00

    def __len__(self):  # return the size of the whole dataset
        return len(self.data_list)

    def __getitem__(self, index):  # return dataset[index] for the given index
        f = wave.open(self.data_list[index], 'rb')  # open the wav file at this index
        params = f.getparams()
        nchannels, sampwidth, framerate, nframes = params[:4]
        strData = f.readframes(nframes)  # read the audio frames as a byte string
        waveData = np.frombuffer(strData, dtype=np.int16)  # convert the byte string to int16
        waveData = torch.tensor(waveData).unsqueeze(0)
        waveData = waveData.type(torch.FloatTensor)
        label = int(self.label_list[index])
        sample = (waveData, label)  # build the (waveform, label) tuple
        # if self.transform:
        #     sample = self.transform(sample)  # apply the transform to the sample
        return sample

def train_data_loader(root, batch_size, shuffle, flag):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    traindata = Depression_wav_random_train(root, transform=transform, flag=flag)  # instantiate the dataset with its root path and transform
    # print('batch_size: ', batch_size)
    trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=shuffle, num_workers=4)  # load the data with a DataLoader
    return trainloader

def test_data_loader(root='./audio_wav_3s/test/', batch_size=1, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_wav_order_test(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader

def val_data_loader(root='./audio_wav_3s/dev/', batch_size=64, shuffle=False):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    testdata = Depression_wav_order_dev(root, transform=transform)  # instantiate the dataset with its root path and transform
    testloader = DataLoader(testdata, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=True)  # load the data with a DataLoader
    return testloader

def train_data_loader14(root, root2, batch_size, shuffle):
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    traindata = Depression_wav_order_train14(root, root2, transform=transform)  # instantiate the dataset with its root paths and transform
    trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=shuffle, num_workers=4)  # load the data with a DataLoader
    return trainloader
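# Usage sketch (illustrative): wiring up the training loader defined above,
# assuming the ./audio_wav_3s/ layout used by the default roots in this file;
# flag=3 picks the label directory out of paths like ./audio_wav_3s/train/XX/... .
if __name__ == '__main__':
    trainloader = train_data_loader(root='./audio_wav_3s/train/', batch_size=64, shuffle=False, flag=3)
    for waveData, label in trainloader:
        # with 3 s clips at a 16 kHz sample rate (an assumption), waveData is [64, 1, 48000]
        print(waveData.shape, label.shape)
        break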

--------------------------------------------------------------------------------
/dc_crn.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.init as init
import os
import sys
# from show import show_params, show_model
import torch.nn.functional as F
from conv_stft import ConvSTFT, ConviSTFT

from complexnn import ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm

def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class SelfAttention(nn.Module):
    def __init__(self, hidden_size, mean_only=False):
        super(SelfAttention, self).__init__()

        # self.output_size = output_size
        self.hidden_size = hidden_size
        self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size), requires_grad=True)

        self.mean_only = mean_only

        init.kaiming_uniform_(self.att_weights)

    def forward(self, inputs):

        batch_size = inputs.size(0)
        weights = torch.bmm(inputs, self.att_weights.permute(1, 0).unsqueeze(0).repeat(batch_size, 1, 1))

        if inputs.size(0) == 1:
            attentions = F.softmax(torch.tanh(weights), dim=1)
            weighted = torch.mul(inputs, attentions.expand_as(inputs))
        else:
            attentions = F.softmax(torch.tanh(weights.squeeze()), dim=1)
            weighted = torch.mul(inputs, attentions.unsqueeze(2).expand_as(inputs))

        if self.mean_only:
            return weighted.sum(1)
        else:
            noise = 1e-5 * torch.randn(weighted.size())

            if inputs.is_cuda:
                noise = noise.to(inputs.device)
            avg_repr, std_repr = weighted.sum(1), (weighted + noise).std(1)

            representations = torch.cat((avg_repr, std_repr), 1)

            return representations

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

class Conv_regression(nn.Module):
    def __init__(self, channel, length):  # e.g. (64, 7999)
        super(Conv_regression, self).__init__()
        self.conv1 = nn.Conv1d(channel, 512, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(512)
        self.conv2 = nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1)
nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1) 109 | self.bn2 = nn.BatchNorm1d(256) 110 | self.conv3 = nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1) 111 | self.bn3 = nn.BatchNorm1d(128) 112 | self.relu= nn.ReLU() 113 | self.avgpool1=nn.AdaptiveAvgPool1d(int(length/3)) 114 | self.avgpool2 = nn.AdaptiveAvgPool1d( int(length / 9)) 115 | self.avgpool3 = nn.AdaptiveAvgPool1d( int(length / 27)) 116 | self.linear1 = nn.Linear(int(length / 27),1) 117 | self.linear2 = nn.Linear(128, 1) 118 | 119 | 120 | def forward(self, x): 121 | y = self.relu(self.bn1(self.conv1(x))) 122 | # print('y1: ', y.size()) 123 | y = self.avgpool1(y) 124 | # print('y2: ', y.size()) 125 | y = self.relu(self.bn2(self.conv2(y))) 126 | # print('y3: ', y.size()) 127 | y = self.avgpool2(y) 128 | # print('y4: ', y.size()) 129 | y = self.relu(self.bn3(self.conv3(y))) 130 | # print('y5: ', y.size()) 131 | y = self.avgpool3(y) 132 | # print('y6: ', y.size()) 133 | y = self.linear1(self.relu(y)) 134 | y = torch.squeeze(y) 135 | y = self.linear2(self.relu(y)) 136 | return y 137 | 138 | class Conv_regression_selfattention(nn.Module): 139 | def __init__(self,channel): #(64,7999) 140 | super(Conv_regression_selfattention, self).__init__() 141 | self.conv5 = nn.Conv2d(channel, 256, kernel_size=(4, 3), stride=(1, 1), padding=(0, 1), bias=False) 142 | self.bn5 = nn.BatchNorm2d(256) 143 | 144 | self.activation = nn.ReLU() 145 | 146 | self.attention = SelfAttention(256) 147 | 148 | self.fc = nn.Linear(256 * 2, 128) 149 | self.fc_mu = nn.Linear(128, 1) 150 | 151 | 152 | def forward(self, x): 153 | # print('x1: ', x.size()) 154 | x = self.conv5(x) 155 | # print('x2: ', x.size()) 156 | x = self.activation(self.bn5(x)).squeeze(2) 157 | # print('x3: ', x.size()) 158 | 159 | stats = self.attention(x.permute(0, 2, 1).contiguous()) 160 | # print('stats: ', stats.size()) 161 | 162 | feat = self.fc(stats) 163 | # print('x4: ', feat.size()) 164 | 165 | mu = self.fc_mu(feat) 166 | # print('x5: ', mu.size()) 167 | return mu 168 | 169 | 170 | class PreActBlock(nn.Module): 171 | '''Pre-activation version of the BasicBlock.''' 172 | expansion = 1 173 | 174 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 175 | super(PreActBlock, self).__init__() 176 | self.bn1 = nn.BatchNorm2d(in_planes) 177 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 178 | self.bn2 = nn.BatchNorm2d(planes) 179 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 180 | 181 | if stride != 1 or in_planes != self.expansion*planes: 182 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 183 | 184 | def forward(self, x): 185 | out = F.relu(self.bn1(x)) 186 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 187 | out = self.conv1(out) 188 | out = self.conv2(F.relu(self.bn2(out))) 189 | out += shortcut 190 | return out 191 | 192 | 193 | class PreActBottleneck(nn.Module): 194 | '''Pre-activation version of the original Bottleneck module.''' 195 | expansion = 4 196 | 197 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 198 | super(PreActBottleneck, self).__init__() 199 | self.bn1 = nn.BatchNorm2d(in_planes) 200 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 201 | self.bn2 = nn.BatchNorm2d(planes) 202 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 203 | self.bn3 = nn.BatchNorm2d(planes) 204 | self.conv3 = 
nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 205 | 206 | if stride != 1 or in_planes != self.expansion*planes: 207 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 208 | 209 | def forward(self, x): 210 | out = F.relu(self.bn1(x)) 211 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 212 | out = self.conv1(out) 213 | out = self.conv2(F.relu(self.bn2(out))) 214 | out = self.conv3(F.relu(self.bn3(out))) 215 | out += shortcut 216 | return out 217 | 218 | def conv1x1(in_planes, out_planes, stride=1): 219 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 220 | 221 | # RESNET_CONFIGS = {'18': [[2, 2, 2, 2], PreActBlock], 222 | # '28': [[3, 4, 6, 3], PreActBlock], 223 | # '34': [[3, 4, 6, 3], PreActBlock], 224 | # '50': [[3, 4, 6, 3], PreActBottleneck], 225 | # '101': [[3, 4, 23, 3], PreActBottleneck] 226 | # } 227 | RESNET_CONFIGS = {'18': [[2, 2, 2, 2], SEBasicBlock]} 228 | 229 | class DCCRN(nn.Module): 230 | 231 | def __init__( 232 | self, 233 | rnn_layers=2, 234 | rnn_units=128, 235 | win_len=400, 236 | win_inc=100, 237 | fft_len=512, 238 | win_type='hanning', 239 | use_clstm=False, 240 | use_cbn = False, 241 | kernel_size=5, 242 | kernel_num=[16,32,64,128,256,256], 243 | resnet_type='18' 244 | ): 245 | ''' 246 | 247 | rnn_layers: the number of lstm layers in the crn, 248 | rnn_units: for clstm, rnn_units = real+imag 249 | 250 | ''' 251 | 252 | super(DCCRN, self).__init__() 253 | 254 | # for fft 255 | self.win_len = win_len 256 | self.win_inc = win_inc 257 | self.fft_len = fft_len 258 | self.win_type = win_type 259 | 260 | input_dim = win_len 261 | 262 | self.rnn_units = rnn_units 263 | self.input_dim = input_dim 264 | self.hidden_layers = rnn_layers 265 | self.kernel_size = kernel_size 266 | #self.kernel_num = [2, 8, 16, 32, 128, 128, 128] 267 | #self.kernel_num = [2, 16, 32, 64, 128, 256, 256] 268 | self.kernel_num = [2]+kernel_num 269 | self.use_clstm = use_clstm 270 | 271 | #bidirectional=True 272 | bidirectional=False 273 | fac = 2 if bidirectional else 1 274 | 275 | 276 | fix=True 277 | self.fix = fix 278 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 279 | 280 | # resnet 281 | self.in_planes = 16 282 | enc_dim = 256 283 | layers, block = RESNET_CONFIGS[resnet_type] 284 | self._norm_layer = nn.BatchNorm2d 285 | 286 | self.conv1 = nn.Conv2d(2, 16, kernel_size=(9, 3), stride=(3, 1), padding=(1, 1), bias=False) 287 | self.bn1 = nn.BatchNorm2d(16) 288 | self.activation = nn.ReLU() 289 | 290 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 291 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 292 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 293 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 294 | 295 | self.conv5 = nn.Conv2d(512 * block.expansion, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 296 | bias=False) 297 | self.bn5 = nn.BatchNorm2d(256) 298 | self.conv6 = nn.Conv2d(256, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 299 | bias=False) 300 | self.bn6 = nn.BatchNorm2d(256) 301 | self.fc = nn.Linear(256 * 2, enc_dim) 302 | self.fc_mu = nn.Linear(enc_dim, 1) 303 | 304 | self.initialize_params() 305 | self.attention = SelfAttention(256) 306 | 307 | def initialize_params(self): 308 | for layer in self.modules(): 309 | if isinstance(layer, torch.nn.Conv2d): 310 | init.kaiming_normal_(layer.weight, a=0, 
mode='fan_out') 311 | elif isinstance(layer, torch.nn.Linear): 312 | init.kaiming_uniform_(layer.weight) 313 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 314 | layer.weight.data.fill_(1) 315 | layer.bias.data.zero_() 316 | 317 | def _make_layer(self, block, planes, num_blocks, stride=1): 318 | norm_layer = self._norm_layer 319 | downsample = None 320 | if stride != 1 or self.in_planes != planes * block.expansion: 321 | downsample = nn.Sequential(conv1x1(self.in_planes, planes * block.expansion, stride), 322 | norm_layer(planes * block.expansion)) 323 | layers = [] 324 | layers.append(block(self.in_planes, planes, stride, downsample, 1, 64, 1, norm_layer)) 325 | self.in_planes = planes * block.expansion 326 | for _ in range(1, num_blocks): 327 | layers.append( 328 | block(self.in_planes, planes, 1, groups=1, base_width=64, dilation=False, norm_layer=norm_layer)) 329 | 330 | return nn.Sequential(*layers) 331 | 332 | 333 | def forward(self, inputs, lens=None): 334 | # print('input: ', inputs.size()) 335 | specs = self.stft(inputs) 336 | real = specs[:,:self.fft_len//2+1] 337 | imag = specs[:,self.fft_len//2+1:] 338 | spec_mags = torch.sqrt(real**2+imag**2+1e-8) 339 | spec_mags = spec_mags 340 | spec_phase = torch.atan2(imag, real) 341 | spec_phase = spec_phase 342 | cspecs = torch.stack([real,imag],1) 343 | cspecs = cspecs[:,:,1:] 344 | # print('cspecs: ', cspecs.size()) 345 | ''' 346 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 347 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 348 | normed_cspecs = (cspecs-means)/(std+1e-8) 349 | out = normed_cspecs 350 | ''' 351 | 352 | # print('cspecs: ', cspecs.size()) 353 | x = self.conv1(cspecs) 354 | # print('x2: ', x.size()) 355 | x = self.activation(self.bn1(x)) 356 | x = self.layer1(x) 357 | # print('x3: ', x.size()) 358 | x = self.layer2(x) 359 | # print('x4: ', x.size()) 360 | x = self.layer3(x) 361 | # print('x5: ', x.size()) 362 | x = self.layer4(x) 363 | # print('layer4: ', x.size()) 364 | x = self.bn5(self.conv5(x)) 365 | # print('conv5: ', x.size()) 366 | x = self.bn6(self.conv6(x)) 367 | x = self.activation(x).squeeze(2) 368 | # print('x8: ', x.size()) 369 | 370 | stats = self.attention(x.permute(0, 2, 1).contiguous()) 371 | # print('stats: ', stats.size()) 372 | 373 | feat = self.fc(stats) 374 | 375 | mu = self.fc_mu(feat) 376 | 377 | 378 | # y=self.Conv_regression_selfattention(out) 379 | return mu 380 | 381 | 382 | 383 | if __name__ == '__main__': 384 | torch.manual_seed(10) 385 | torch.autograd.set_detect_anomaly(True) 386 | inputs = torch.randn([10,16000*4]).clamp_(-1,1) 387 | labels = torch.randn([10,16000*4]).clamp_(-1,1) 388 | 389 | ''' 390 | # DCCRN-E 391 | net = DCCRN(rnn_units=256,masking_mode='E') 392 | outputs = net(inputs)[1] 393 | loss = net.loss(outputs, labels, loss_mode='SI-SNR') 394 | print(loss) 395 | 396 | # DCCRN-R 397 | net = DCCRN(rnn_units=256,masking_mode='R') 398 | outputs = net(inputs)[1] 399 | loss = net.loss(outputs, labels, loss_mode='SI-SNR') 400 | print(loss) 401 | 402 | # DCCRN-C 403 | net = DCCRN(rnn_units=256,masking_mode='C') 404 | outputs = net(inputs)[1] 405 | loss = net.loss(outputs, labels, loss_mode='SI-SNR') 406 | print(loss) 407 | 408 | ''' 409 | # DCCRN-CL 410 | net = DCCRN(rnn_units=256,masking_mode='E',use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 411 | outputs = net(inputs)[1] 412 | loss = net.loss(outputs, labels, loss_mode='SI-SNR') 413 | print(loss) 414 | 
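Editor's note: the `__main__` block above is carried over from the original speech-enhancement DCCRN. It passes a `masking_mode` argument this constructor does not accept, indexes `net(inputs)[1]`, and calls `net.loss`, none of which exist here, so it fails with a TypeError as written. A minimal smoke test consistent with the regression model actually defined in this file might look like the sketch below (the MSE target is a stand-in, not the real training loss pipeline):

```python
# Hedged replacement for the leftover __main__ block: the DCCRN in this
# file takes a raw waveform batch and returns one score per clip.
if __name__ == '__main__':
    torch.manual_seed(10)
    net = DCCRN(rnn_units=256, use_clstm=True,
                kernel_num=[32, 64, 128, 256, 256, 256])
    clips = torch.randn(10, 16000 * 3).clamp_(-1, 1)   # ten 3 s, 16 kHz clips
    scores = net(clips)                                # (10, 1) regression scores
    target = torch.randn(10)                           # stand-in labels
    loss = F.mse_loss(scores.squeeze(-1), target)      # stand-in MSE objective
    print(scores.shape, loss.item())
```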
-------------------------------------------------------------------------------- /dc_crn_test_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import os 5 | import sys 6 | # from show import show_params, show_model 7 | import torch.nn.functional as F 8 | from conv_stft import ConvSTFT, ConviSTFT 9 | 10 | from complexnn import ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class SelfAttention(nn.Module): 16 | def __init__(self, hidden_size, mean_only=False): 17 | super(SelfAttention, self).__init__() 18 | 19 | #self.output_size = output_size 20 | self.hidden_size = hidden_size 21 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size),requires_grad=True) 22 | 23 | self.mean_only = mean_only 24 | 25 | init.kaiming_uniform_(self.att_weights) 26 | 27 | def forward(self, inputs): 28 | 29 | batch_size = inputs.size(0) 30 | weights = torch.bmm(inputs, self.att_weights.permute(1, 0).unsqueeze(0).repeat(batch_size, 1, 1)) 31 | 32 | if inputs.size(0)==1: 33 | attentions = F.softmax(torch.tanh(weights),dim=1) 34 | weighted = torch.mul(inputs, attentions.expand_as(inputs)) 35 | else: 36 | attentions = F.softmax(torch.tanh(weights.squeeze()),dim=1) 37 | weighted = torch.mul(inputs, attentions.unsqueeze(2).expand_as(inputs)) 38 | 39 | if self.mean_only: 40 | return weighted.sum(1) 41 | else: 42 | noise = 1e-5*torch.randn(weighted.size()) 43 | 44 | if inputs.is_cuda: 45 | noise = noise.to(inputs.device) 46 | avg_repr, std_repr = weighted.sum(1), (weighted+noise).std(1) 47 | 48 | representations = torch.cat((avg_repr,std_repr),1) 49 | 50 | return representations, weighted 51 | 52 | class SELayer(nn.Module): 53 | def __init__(self, channel, reduction=16): 54 | super(SELayer, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.fc = nn.Sequential( 57 | nn.Linear(channel, channel // reduction, bias=False), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(channel // reduction, channel, bias=False), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1, 1) 67 | return x * y.expand_as(x) 68 | 69 | class SEBasicBlock(nn.Module): 70 | expansion = 1 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None, 74 | *, reduction=16): 75 | super(SEBasicBlock, self).__init__() 76 | self.conv1 = conv3x3(inplanes, planes, stride) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.conv2 = conv3x3(planes, planes, 1) 80 | self.bn2 = nn.BatchNorm2d(planes) 81 | self.se = SELayer(planes, reduction) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | residual = x 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.se(out) 94 | 95 | if self.downsample is not None: 96 | residual = self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | 105 | class PreActBlock(nn.Module): 106 | '''Pre-activation version of the BasicBlock.''' 107 | expansion = 1 108 | 109 | def __init__(self, in_planes, 
planes, stride, *args, **kwargs): 110 | super(PreActBlock, self).__init__() 111 | self.bn1 = nn.BatchNorm2d(in_planes) 112 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(planes) 114 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 115 | 116 | if stride != 1 or in_planes != self.expansion*planes: 117 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 118 | 119 | def forward(self, x): 120 | out = F.relu(self.bn1(x)) 121 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 122 | out = self.conv1(out) 123 | out = self.conv2(F.relu(self.bn2(out))) 124 | out += shortcut 125 | return out 126 | 127 | 128 | class PreActBottleneck(nn.Module): 129 | '''Pre-activation version of the original Bottleneck module.''' 130 | expansion = 4 131 | 132 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 133 | super(PreActBottleneck, self).__init__() 134 | self.bn1 = nn.BatchNorm2d(in_planes) 135 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 136 | self.bn2 = nn.BatchNorm2d(planes) 137 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 138 | self.bn3 = nn.BatchNorm2d(planes) 139 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 140 | 141 | if stride != 1 or in_planes != self.expansion*planes: 142 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 143 | 144 | def forward(self, x): 145 | out = F.relu(self.bn1(x)) 146 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 147 | out = self.conv1(out) 148 | out = self.conv2(F.relu(self.bn2(out))) 149 | out = self.conv3(F.relu(self.bn3(out))) 150 | out += shortcut 151 | return out 152 | 153 | def conv1x1(in_planes, out_planes, stride=1): 154 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 155 | 156 | # RESNET_CONFIGS = {'18': [[2, 2, 2, 2], PreActBlock], 157 | # '28': [[3, 4, 6, 3], PreActBlock], 158 | # '34': [[3, 4, 6, 3], PreActBlock], 159 | # '50': [[3, 4, 6, 3], PreActBottleneck], 160 | # '101': [[3, 4, 23, 3], PreActBottleneck] 161 | # } 162 | RESNET_CONFIGS = {'18': [[2, 2, 2, 2], SEBasicBlock]} 163 | 164 | class DCCRN(nn.Module): 165 | 166 | def __init__( 167 | self, 168 | rnn_layers=2, 169 | rnn_units=128, 170 | win_len=400, 171 | win_inc=100, 172 | fft_len=512, 173 | win_type='hanning', 174 | use_clstm=False, 175 | use_cbn = False, 176 | kernel_size=5, 177 | kernel_num=[16,32,64,128,256,256], 178 | resnet_type='18' 179 | ): 180 | ''' 181 | 182 | rnn_layers: the number of lstm layers in the crn, 183 | rnn_units: for clstm, rnn_units = real+imag 184 | 185 | ''' 186 | 187 | super(DCCRN, self).__init__() 188 | 189 | # for fft 190 | self.win_len = win_len 191 | self.win_inc = win_inc 192 | self.fft_len = fft_len 193 | self.win_type = win_type 194 | 195 | input_dim = win_len 196 | 197 | self.rnn_units = rnn_units 198 | self.input_dim = input_dim 199 | self.hidden_layers = rnn_layers 200 | self.kernel_size = kernel_size 201 | #self.kernel_num = [2, 8, 16, 32, 128, 128, 128] 202 | #self.kernel_num = [2, 16, 32, 64, 128, 256, 256] 203 | self.kernel_num = [2]+kernel_num 204 | self.use_clstm = use_clstm 205 | 206 | #bidirectional=True 207 | bidirectional=False 208 | fac = 2 if bidirectional else 1 209 | 210 | 211 | fix=True 212 | self.fix = 
fix 213 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 214 | 215 | # resnet 216 | self.in_planes = 16 217 | enc_dim = 256 218 | layers, block = RESNET_CONFIGS[resnet_type] 219 | self._norm_layer = nn.BatchNorm2d 220 | 221 | self.conv1 = nn.Conv2d(2, 16, kernel_size=(9, 3), stride=(3, 1), padding=(1, 1), bias=False) 222 | self.bn1 = nn.BatchNorm2d(16) 223 | self.activation = nn.ReLU() 224 | 225 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 226 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 227 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 228 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 229 | 230 | self.conv5 = nn.Conv2d(512 * block.expansion, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 231 | bias=False) 232 | self.bn5 = nn.BatchNorm2d(256) 233 | self.conv6 = nn.Conv2d(256, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 234 | bias=False) 235 | self.bn6 = nn.BatchNorm2d(256) 236 | self.fc = nn.Linear(256 * 2, enc_dim) 237 | self.fc_mu = nn.Linear(enc_dim, 1) 238 | 239 | self.initialize_params() 240 | self.attention = SelfAttention(256) 241 | 242 | def initialize_params(self): 243 | for layer in self.modules(): 244 | if isinstance(layer, torch.nn.Conv2d): 245 | init.kaiming_normal_(layer.weight, a=0, mode='fan_out') 246 | elif isinstance(layer, torch.nn.Linear): 247 | init.kaiming_uniform_(layer.weight) 248 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 249 | layer.weight.data.fill_(1) 250 | layer.bias.data.zero_() 251 | 252 | def _make_layer(self, block, planes, num_blocks, stride=1): 253 | norm_layer = self._norm_layer 254 | downsample = None 255 | if stride != 1 or self.in_planes != planes * block.expansion: 256 | downsample = nn.Sequential(conv1x1(self.in_planes, planes * block.expansion, stride), 257 | norm_layer(planes * block.expansion)) 258 | layers = [] 259 | layers.append(block(self.in_planes, planes, stride, downsample, 1, 64, 1, norm_layer)) 260 | self.in_planes = planes * block.expansion 261 | for _ in range(1, num_blocks): 262 | layers.append( 263 | block(self.in_planes, planes, 1, groups=1, base_width=64, dilation=False, norm_layer=norm_layer)) 264 | 265 | return nn.Sequential(*layers) 266 | 267 | 268 | def forward(self, inputs, lens=None): 269 | # print('input: ', inputs.size()) 270 | specs = self.stft(inputs) 271 | real = specs[:,:self.fft_len//2+1] 272 | imag = specs[:,self.fft_len//2+1:] 273 | spec_mags = torch.sqrt(real**2+imag**2+1e-8) 274 | spec_mags = spec_mags 275 | spec_phase = torch.atan2(imag, real) 276 | spec_phase = spec_phase 277 | cspecs = torch.stack([real,imag],1) 278 | cspecs = cspecs[:,:,1:] 279 | # print('cspecs: ', cspecs.size()) 280 | ''' 281 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 282 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 283 | normed_cspecs = (cspecs-means)/(std+1e-8) 284 | out = normed_cspecs 285 | ''' 286 | 287 | # print('cspecs: ', cspecs.size()) 288 | x = self.conv1(cspecs) 289 | # print('x2: ', x.size()) 290 | x = self.activation(self.bn1(x)) 291 | x = self.layer1(x) 292 | # print('x3: ', x.size()) 293 | x = self.layer2(x) 294 | # print('x4: ', x.size()) 295 | x = self.layer3(x) 296 | # print('x5: ', x.size()) 297 | x = self.layer4(x) 298 | # print('layer4: ', x.size()) 299 | x = self.bn5(self.conv5(x)) 300 | print('conv5: ', x.size()) 301 | x = self.bn6(self.conv6(x)) 302 | x = self.activation(x).squeeze(2) 303 | print('conv6: ', 
x.size()) 304 | 305 | stats, weighted = self.attention(x.permute(0, 2, 1).contiguous()) 306 | 307 | return stats, weighted 308 | 309 | -------------------------------------------------------------------------------- /dc_crn_test_avg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import os 5 | import sys 6 | # from show import show_params, show_model 7 | import torch.nn.functional as F 8 | from conv_stft import ConvSTFT, ConviSTFT 9 | 10 | from complexnn import ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class SelfAttention(nn.Module): 16 | def __init__(self, hidden_size, mean_only=False): 17 | super(SelfAttention, self).__init__() 18 | 19 | #self.output_size = output_size 20 | self.hidden_size = hidden_size 21 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size),requires_grad=True) 22 | 23 | self.mean_only = mean_only 24 | 25 | init.kaiming_uniform_(self.att_weights) 26 | 27 | def forward(self, inputs): 28 | 29 | batch_size = inputs.size(0) 30 | weights = torch.bmm(inputs, self.att_weights.permute(1, 0).unsqueeze(0).repeat(batch_size, 1, 1)) 31 | 32 | if inputs.size(0)==1: 33 | attentions = F.softmax(torch.tanh(weights),dim=1) 34 | weighted = torch.mul(inputs, attentions.expand_as(inputs)) 35 | else: 36 | attentions = F.softmax(torch.tanh(weights.squeeze()),dim=1) 37 | weighted = torch.mul(inputs, attentions.unsqueeze(2).expand_as(inputs)) 38 | 39 | # if self.mean_only: 40 | # return weighted.sum(1) 41 | # else: 42 | # noise = 1e-5*torch.randn(weighted.size()) 43 | 44 | # if inputs.is_cuda: 45 | # noise = noise.to(inputs.device) 46 | # avg_repr, std_repr = weighted.sum(1), (weighted+noise).std(1) 47 | 48 | # representations = torch.cat((avg_repr,std_repr),1) 49 | 50 | # return representations, weighted 51 | return weighted 52 | 53 | class SELayer(nn.Module): 54 | def __init__(self, channel, reduction=16): 55 | super(SELayer, self).__init__() 56 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 57 | self.fc = nn.Sequential( 58 | nn.Linear(channel, channel // reduction, bias=False), 59 | nn.ReLU(inplace=True), 60 | nn.Linear(channel // reduction, channel, bias=False), 61 | nn.Sigmoid() 62 | ) 63 | 64 | def forward(self, x): 65 | b, c, _, _ = x.size() 66 | y = self.avg_pool(x).view(b, c) 67 | y = self.fc(y).view(b, c, 1, 1) 68 | return x * y.expand_as(x) 69 | 70 | class SEBasicBlock(nn.Module): 71 | expansion = 1 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 74 | base_width=64, dilation=1, norm_layer=None, 75 | *, reduction=16): 76 | super(SEBasicBlock, self).__init__() 77 | self.conv1 = conv3x3(inplanes, planes, stride) 78 | self.bn1 = nn.BatchNorm2d(planes) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.conv2 = conv3x3(planes, planes, 1) 81 | self.bn2 = nn.BatchNorm2d(planes) 82 | self.se = SELayer(planes, reduction) 83 | self.downsample = downsample 84 | self.stride = stride 85 | 86 | def forward(self, x): 87 | residual = x 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.se(out) 95 | 96 | if self.downsample is not None: 97 | residual = self.downsample(x) 98 | 99 | out += residual 100 | out = self.relu(out) 101 | 102 | return out 
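# Editor's note: in this "avg" test variant, SelfAttention (above) returns
# the attention-weighted sequence itself, shape (B, T, 256), instead of the
# pooled mean/std statistics used in dc_crn.py. eval_avg.py accumulates these
# weighted features over all clips of one recording, averages them, and feeds
# the average to DCCRN_avg (test_avg.py) for the final regression score.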
103 | 104 | 105 | 106 | class PreActBlock(nn.Module): 107 | '''Pre-activation version of the BasicBlock.''' 108 | expansion = 1 109 | 110 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 111 | super(PreActBlock, self).__init__() 112 | self.bn1 = nn.BatchNorm2d(in_planes) 113 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(planes) 115 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 116 | 117 | if stride != 1 or in_planes != self.expansion*planes: 118 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 119 | 120 | def forward(self, x): 121 | out = F.relu(self.bn1(x)) 122 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 123 | out = self.conv1(out) 124 | out = self.conv2(F.relu(self.bn2(out))) 125 | out += shortcut 126 | return out 127 | 128 | 129 | class PreActBottleneck(nn.Module): 130 | '''Pre-activation version of the original Bottleneck module.''' 131 | expansion = 4 132 | 133 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 134 | super(PreActBottleneck, self).__init__() 135 | self.bn1 = nn.BatchNorm2d(in_planes) 136 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 137 | self.bn2 = nn.BatchNorm2d(planes) 138 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 139 | self.bn3 = nn.BatchNorm2d(planes) 140 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 141 | 142 | if stride != 1 or in_planes != self.expansion*planes: 143 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 144 | 145 | def forward(self, x): 146 | out = F.relu(self.bn1(x)) 147 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 148 | out = self.conv1(out) 149 | out = self.conv2(F.relu(self.bn2(out))) 150 | out = self.conv3(F.relu(self.bn3(out))) 151 | out += shortcut 152 | return out 153 | 154 | def conv1x1(in_planes, out_planes, stride=1): 155 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 156 | 157 | # RESNET_CONFIGS = {'18': [[2, 2, 2, 2], PreActBlock], 158 | # '28': [[3, 4, 6, 3], PreActBlock], 159 | # '34': [[3, 4, 6, 3], PreActBlock], 160 | # '50': [[3, 4, 6, 3], PreActBottleneck], 161 | # '101': [[3, 4, 23, 3], PreActBottleneck] 162 | # } 163 | RESNET_CONFIGS = {'18': [[2, 2, 2, 2], SEBasicBlock]} 164 | 165 | class DCCRN(nn.Module): 166 | 167 | def __init__( 168 | self, 169 | rnn_layers=2, 170 | rnn_units=128, 171 | win_len=400, 172 | win_inc=100, 173 | fft_len=512, 174 | win_type='hanning', 175 | use_clstm=False, 176 | use_cbn = False, 177 | kernel_size=5, 178 | kernel_num=[16,32,64,128,256,256], 179 | resnet_type='18' 180 | ): 181 | ''' 182 | 183 | rnn_layers: the number of lstm layers in the crn, 184 | rnn_units: for clstm, rnn_units = real+imag 185 | 186 | ''' 187 | 188 | super(DCCRN, self).__init__() 189 | 190 | # for fft 191 | self.win_len = win_len 192 | self.win_inc = win_inc 193 | self.fft_len = fft_len 194 | self.win_type = win_type 195 | 196 | input_dim = win_len 197 | 198 | self.rnn_units = rnn_units 199 | self.input_dim = input_dim 200 | self.hidden_layers = rnn_layers 201 | self.kernel_size = kernel_size 202 | #self.kernel_num = [2, 8, 16, 32, 128, 128, 128] 203 | #self.kernel_num = [2, 16, 32, 64, 128, 256, 256] 204 | self.kernel_num = [2]+kernel_num 205 
| self.use_clstm = use_clstm 206 | 207 | #bidirectional=True 208 | bidirectional=False 209 | fac = 2 if bidirectional else 1 210 | 211 | 212 | fix=True 213 | self.fix = fix 214 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 215 | 216 | # resnet 217 | self.in_planes = 16 218 | enc_dim = 256 219 | layers, block = RESNET_CONFIGS[resnet_type] 220 | self._norm_layer = nn.BatchNorm2d 221 | 222 | self.conv1 = nn.Conv2d(2, 16, kernel_size=(9, 3), stride=(3, 1), padding=(1, 1), bias=False) 223 | self.bn1 = nn.BatchNorm2d(16) 224 | self.activation = nn.ReLU() 225 | 226 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 227 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 228 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 229 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 230 | 231 | self.conv5 = nn.Conv2d(512 * block.expansion, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 232 | bias=False) 233 | self.bn5 = nn.BatchNorm2d(256) 234 | self.conv6 = nn.Conv2d(256, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 235 | bias=False) 236 | self.bn6 = nn.BatchNorm2d(256) 237 | self.fc = nn.Linear(256 * 2, enc_dim) 238 | self.fc_mu = nn.Linear(enc_dim, 1) 239 | 240 | self.initialize_params() 241 | self.attention = SelfAttention(256) 242 | 243 | def initialize_params(self): 244 | for layer in self.modules(): 245 | if isinstance(layer, torch.nn.Conv2d): 246 | init.kaiming_normal_(layer.weight, a=0, mode='fan_out') 247 | elif isinstance(layer, torch.nn.Linear): 248 | init.kaiming_uniform_(layer.weight) 249 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 250 | layer.weight.data.fill_(1) 251 | layer.bias.data.zero_() 252 | 253 | def _make_layer(self, block, planes, num_blocks, stride=1): 254 | norm_layer = self._norm_layer 255 | downsample = None 256 | if stride != 1 or self.in_planes != planes * block.expansion: 257 | downsample = nn.Sequential(conv1x1(self.in_planes, planes * block.expansion, stride), 258 | norm_layer(planes * block.expansion)) 259 | layers = [] 260 | layers.append(block(self.in_planes, planes, stride, downsample, 1, 64, 1, norm_layer)) 261 | self.in_planes = planes * block.expansion 262 | for _ in range(1, num_blocks): 263 | layers.append( 264 | block(self.in_planes, planes, 1, groups=1, base_width=64, dilation=False, norm_layer=norm_layer)) 265 | 266 | return nn.Sequential(*layers) 267 | 268 | 269 | def forward(self, inputs, lens=None): 270 | # print('input: ', inputs.size()) 271 | specs = self.stft(inputs) 272 | real = specs[:,:self.fft_len//2+1] 273 | imag = specs[:,self.fft_len//2+1:] 274 | spec_mags = torch.sqrt(real**2+imag**2+1e-8) 275 | spec_mags = spec_mags 276 | spec_phase = torch.atan2(imag, real) 277 | spec_phase = spec_phase 278 | cspecs = torch.stack([real,imag],1) 279 | cspecs = cspecs[:,:,1:] 280 | # print('cspecs: ', cspecs.size()) 281 | ''' 282 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 283 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 284 | normed_cspecs = (cspecs-means)/(std+1e-8) 285 | out = normed_cspecs 286 | ''' 287 | 288 | # print('cspecs: ', cspecs.size()) 289 | x = self.conv1(cspecs) 290 | # print('x2: ', x.size()) 291 | x = self.activation(self.bn1(x)) 292 | x = self.layer1(x) 293 | # print('x3: ', x.size()) 294 | x = self.layer2(x) 295 | # print('x4: ', x.size()) 296 | x = self.layer3(x) 297 | # print('x5: ', x.size()) 298 | x = self.layer4(x) 299 | # print('layer4: ', x.size()) 
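# Shape note (editor; assuming the default fft_len=512 and 3 s, 16 kHz clips):
# after conv1 (stride 3 on the 256-bin frequency axis) and the stride-2
# layer2-layer4 blocks, x is roughly (B, 512, 11, T'). conv5 and conv6 use
# kernel (6, 3) with no padding on the frequency axis (11 -> 6 -> 1), so the
# squeeze(2) below yields a (B, 256, T') sequence for the attention module.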
300 | x = self.bn5(self.conv5(x)) 301 | # print('conv5: ', x.size()) 302 | x = self.bn6(self.conv6(x)) 303 | x = self.activation(x).squeeze(2) 304 | # print('conv6: ', x.size()) 305 | 306 | weighted = self.attention(x.permute(0, 2, 1).contiguous()) 307 | 308 | return weighted 309 | 310 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import logging 6 | import os 7 | import pickle 8 | batch_size=64 9 | logging.basicConfig(level=logging.DEBUG) 10 | logger=logging.getLogger(__name__) 11 | 12 | 13 | def test(net,testloader,device): 14 | net.eval() 15 | batch_loss=0 16 | criterion = nn.MSELoss() 17 | with torch.no_grad(): 18 | for i,data in enumerate(testloader): 19 | images, labels = data 20 | images, labels=images.to(device),labels.to(device) 21 | outputs = net(images) 22 | outputs=outputs.squeeze(-1) 23 | outputs=outputs.to(device) 24 | loss = criterion(outputs, labels.float()) 25 | batch_loss += loss.item() 26 | batch_num=i+1 27 | mse=(batch_loss/batch_num)**0.5 28 | logger.info("test mse is: %5f" % mse) 29 | return mse 30 | 31 | def test_num(net,testloader,device): 32 | net.eval() 33 | batch_loss=0 34 | j=1 35 | tem_name="316_1" 36 | tem_label=0 37 | total_mse=0 38 | total_mae=0 39 | batch_num=0 40 | sum_outputs = torch.tensor([0]) 41 | sum_outputs = sum_outputs.to(device) 42 | with torch.no_grad(): 43 | for i,data in enumerate(testloader): 44 | images, labels,name = data 45 | name=str(name).split("'")[1] 46 | # print(name) 47 | if (tem_name==name): 48 | images, labels = images.to(device), labels.to(device) 49 | outputs = net(images) 50 | outputs = outputs.squeeze(-1) 51 | outputs = outputs.to(device) 52 | sum_outputs=sum_outputs+outputs 53 | batch_loss += 1 54 | batch_num = batch_num + 1 55 | else: 56 | if batch_num == 0: 57 | predict=sum_outputs 58 | else: 59 | predict=sum_outputs/batch_num 60 | logger.info("%s test label is : %d ,predict is: %5f, sum_outputs: %5f, batch_num: %d" % 61 | (tem_name,tem_label,float(predict), sum_outputs, batch_num)) 62 | total_mse = total_mse + math.pow(float(predict)-tem_label,2) 63 | total_mae = total_mae + abs(float(predict) - tem_label) 64 | j += 1 65 | batch_loss = 0 66 | batch_num = 0 67 | sum_outputs = 0 68 | tem_name = name 69 | tem_label = labels 70 | predict = sum_outputs / batch_num 71 | logger.info("%s test label is : %d ,predict is: %5f" % (name, labels, predict)) 72 | total_mse = total_mse + math.pow(float(predict)-int(labels),2) 73 | total_mae = total_mae + abs(float(predict) - tem_label) 74 | total_mse=math.sqrt(total_mse/j) 75 | total_mae = (total_mae / j).item() 76 | logger.info(total_mse) 77 | logger.info(total_mae) 78 | return total_mse,total_mae 79 | 80 | def figure(net,testloader,device): 81 | net.eval() 82 | tem_name="317_4" 83 | j=0 84 | with torch.no_grad(): 85 | for i,data in enumerate(testloader): 86 | images, labels, name = data 87 | if(tem_name!=name): 88 | j=0 89 | tem_name=name 90 | images=images.to(device) 91 | y,atty = net(images) 92 | root="../figure/y/"+str(int(labels))+"/"+str(name[0]) 93 | if not os.path.exists(root): 94 | os.makedirs(root) 95 | y_output = open(root+"/"+str(j)+".pkl", 'wb') 96 | pickle.dump(y, y_output) 97 | y_output.close() 98 | 99 | attroot = "../figure/atty/" + str(int(labels)) + "/" + str(name[0]) 100 | if not os.path.exists(attroot): 101 | os.makedirs(attroot) 102 | atty_output = 
open(attroot + "/" + str(j) + ".pkl", 'wb') 103 | pickle.dump(atty, atty_output) 104 | atty_output.close() 105 | j+=1 -------------------------------------------------------------------------------- /eval_avg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | import logging 6 | import os 7 | import pickle 8 | batch_size=64 9 | logging.basicConfig(level=logging.DEBUG) 10 | logger=logging.getLogger(__name__) 11 | 12 | 13 | def test(net,testloader,device): 14 | net.eval() 15 | batch_loss=0 16 | criterion = nn.MSELoss() 17 | with torch.no_grad(): 18 | for i,data in enumerate(testloader): 19 | images, labels = data 20 | images, labels=images.to(device),labels.to(device) 21 | outputs = net(images) 22 | outputs=outputs.squeeze(-1) 23 | outputs=outputs.to(device) 24 | loss = criterion(outputs, labels.float()) 25 | batch_loss += loss.item() 26 | batch_num=i+1 27 | mse=(batch_loss/batch_num)**0.5 28 | logger.info("test mse is: %5f" % mse) 29 | return mse 30 | 31 | def test_num(net, net2, testloader,device): 32 | net.eval() 33 | batch_loss=0 34 | j=1 35 | tem_name="316_1" 36 | tem_label=0 37 | total_mse=0 38 | total_mae=0 39 | batch_num=0 40 | sum_outputs = torch.tensor([0]) 41 | sum_outputs = sum_outputs.to(device) 42 | with torch.no_grad(): 43 | for i,data in enumerate(testloader): 44 | images, labels,name = data 45 | name=str(name).split("'")[1] 46 | # print(name) 47 | if (tem_name==name): 48 | images, labels = images.to(device), labels.to(device) 49 | weighted = net(images) 50 | if batch_num == 0: 51 | sum_weighted = weighted*0 52 | sum_weighted = sum_weighted + weighted 53 | # outputs = outputs.squeeze(-1) 54 | # outputs = outputs.to(device) 55 | # sum_outputs=sum_outputs+outputs 56 | # batch_loss += 1 57 | batch_num = batch_num + 1 58 | else: 59 | if batch_num == 0: 60 | # predict=sum_weighted 61 | predict = net2(images, sum_weighted) 62 | else: 63 | weighted_avg = sum_weighted/batch_num 64 | 65 | predict = net2(images, weighted_avg) 66 | # print(predict.size()) 67 | # sum_outputs = sum_outputs + predict 68 | logger.info("%s test label is : %d ,predict is: %5f" % 69 | (tem_name,tem_label,float(predict))) 70 | total_mse = total_mse + math.pow(float(predict)-tem_label,2) 71 | total_mae = total_mae + abs(float(predict) - tem_label) 72 | j += 1 73 | batch_loss = 0 74 | batch_num = 0 75 | sum_outputs = 0 76 | tem_name = name 77 | tem_label = labels 78 | # predict = sum_outputs / batch_num 79 | weighted_avg = sum_weighted/batch_num 80 | predict = net2(images, weighted_avg) 81 | logger.info("%s test label is : %d ,predict is: %5f" % (name, labels, predict)) 82 | total_mse = total_mse + math.pow(float(predict)-int(labels),2) 83 | total_mae = total_mae + abs(float(predict) - tem_label) 84 | total_mse=math.sqrt(total_mse/j) 85 | total_mae = (total_mae / j).item() 86 | logger.info(total_mse) 87 | logger.info(total_mae) 88 | return total_mse,total_mae 89 | 90 | def figure(net,testloader,device): 91 | net.eval() 92 | tem_name="317_4" 93 | j=0 94 | with torch.no_grad(): 95 | for i,data in enumerate(testloader): 96 | images, labels, name = data 97 | if(tem_name!=name): 98 | j=0 99 | tem_name=name 100 | images=images.to(device) 101 | y,atty = net(images) 102 | root="../figure/y/"+str(int(labels))+"/"+str(name[0]) 103 | if not os.path.exists(root): 104 | os.makedirs(root) 105 | y_output = open(root+"/"+str(j)+".pkl", 'wb') 106 | pickle.dump(y, y_output) 107 | y_output.close() 108 | 109 | 
attroot = "../figure/atty/" + str(int(labels)) + "/" + str(name[0]) 110 | if not os.path.exists(attroot): 111 | os.makedirs(attroot) 112 | atty_output = open(attroot + "/" + str(j) + ".pkl", 'wb') 113 | pickle.dump(atty, atty_output) 114 | atty_output.close() 115 | j+=1 116 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import logging 4 | import os 5 | from torchsummary import summary 6 | #from resnet import * 7 | from pathlib import Path 8 | from eval import test,test_num 9 | from TasNet import ConvTasNet 10 | from dc_crn import DCCRN 11 | from train import train 12 | from model import * 13 | from dataload import * 14 | classes_num=1 15 | logging.basicConfig(level=logging.DEBUG) 16 | logger=logging.getLogger(__name__) 17 | 18 | def load_net(net,model_pkl): 19 | logger.info("load:%s"%model_pkl) 20 | net.load_state_dict(torch.load("../pkl/"+model_pkl)) 21 | return net 22 | def count_parameters(model): 23 | parameters_sum = sum(p.numel() for p in model.parameters() if p.requires_grad) 24 | print(parameters_sum) 25 | 26 | 27 | def train_3s(data_root, save_path): 28 | torch.manual_seed(1234) 29 | torch.cuda.manual_seed(1234) 30 | torch.backends.cudnn.deterministic = True 31 | 32 | dev_root = os.path.join(data_root, 'dev') 33 | val_data=val_data_loader(dev_root, batch_size=64, shuffle=False) 34 | # val_data=val_data_loader('../audio_3s/dev/',batch_size=64,shuffle=False) 35 | # net = ConvTasNet(X=4,R=2) # 模型 36 | net = DCCRN(rnn_units=256,use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 37 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 38 | net = torch.nn.DataParallel(net, device_ids=[0]) #GPU配置 39 | net.to(device) 40 | train(data_root, net=net, epoch_num=100, trainloader="3s", valloader=val_data,batch_size=64, 41 | device=device,save_path=save_path,info_num=200,step_size=5, flag=-3) 42 | 43 | def test_3s(): 44 | #train_data=train_data_loader('../audio_3s/train/',batch_size=64,shuffle=False) 45 | test_data=test_data_loader('../audio_3s/test/',batch_size=1,shuffle=False) 46 | net = ConvTasNet(X=4,R=2) # 模型 47 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 48 | net = torch.nn.DataParallel(net, device_ids=[0,1]) #GPU配置 49 | net.to(device) 50 | # net = load_net(net, "Net_conv1d2_99.pkl") 51 | # test_num(net=net, testloader=test_data, device=device) 52 | for i in range(100,0,-1): 53 | path="3s_1d_"+str(i)+".pkl" 54 | my_file = Path("../pkl/"+path) 55 | if my_file.is_file(): 56 | net=load_net(net,path) 57 | test_num(net=net, testloader=test_data, device=device) 58 | 59 | def train_5s(): 60 | val_data=val_data_loader('../audio_5s/dev/',batch_size=8,shuffle=False) 61 | net = ConvTasNet(X=4,R=2) # 模型 62 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 63 | net = torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 64 | net.to(device) 65 | train(net=net, epoch_num=1, trainloader="5s", valloader=val_data,batch_size=1, 66 | device=device,save_path="best_relu",info_num=10,step_size=5) 67 | 68 | def test_5s(): 69 | test_data=test_data_loader('../audio_5s/test/',batch_size=1,shuffle=False) 70 | net = ConvTasNet(X=4,R=2) # 模型 71 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 72 | net = torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 73 | net.to(device) 74 | best_mse = 100 75 | best_mae = 100 76 | for i in range(100, 0, -1): 77 | 
path = "best_relu_" + str(i) + ".pkl" 78 | my_file = Path("../pkl/" + path) 79 | if my_file.is_file(): 80 | net = load_net(net, path) 81 | tem_mse,tem_mae = test_num(net=net, testloader=test_data, device=device) 82 | if (tem_mse < best_mse): 83 | best_mse = tem_mse 84 | if (tem_mae < best_mae): 85 | best_mae = tem_mae 86 | print(best_mse) 87 | print(best_mae) 88 | 89 | 90 | if __name__ == "__main__": 91 | data_root = '/data3/fancunhang/Depression/audio_good_without_move/AVEC2013_3s/' 92 | model_save_path = 'exp_0.002/' 93 | if not os.path.exists(model_save_path): 94 | os.mkdir(model_save_path) 95 | train_3s(data_root, model_save_path) 96 | 97 | 98 | -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import logging 4 | import os 5 | from torchsummary import summary 6 | #from resnet import * 7 | from pathlib import Path 8 | from eval import test,test_num 9 | # from TasNet import * 10 | from dc_crn import DCCRN 11 | from train import train 12 | from model import * 13 | from dataload_vad import * 14 | classes_num=1 15 | logging.basicConfig(level=logging.DEBUG) 16 | logger=logging.getLogger(__name__) 17 | 18 | def load_net(net,model_pkl): 19 | logger.info("load:%s"%model_pkl) 20 | net.load_state_dict(torch.load(model_pkl)) 21 | return net 22 | def count_parameters(model): 23 | parameters_sum = sum(p.numel() for p in model.parameters() if p.requires_grad) 24 | print(parameters_sum) 25 | 26 | 27 | def train_3s(): 28 | torch.manual_seed(1234) 29 | torch.cuda.manual_seed(1234) 30 | torch.backends.cudnn.deterministic = True 31 | 32 | val_data=val_data_loader('../audio_3s/dev/',batch_size=64,shuffle=False) 33 | net = ConvTasNet(X=4,R=2) # 模型 34 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 35 | net = torch.nn.DataParallel(net, device_ids=[0,1]) #GPU配置 36 | net.to(device) 37 | train(net=net, epoch_num=100, trainloader="3s", valloader=val_data,batch_size=64, 38 | device=device,save_path="3s_1d",info_num=200,step_size=5) 39 | 40 | def test_3s(data_root, model_path): 41 | # test_root = os.path.join(data_root, 'test') 42 | test_data=test_data_loader(data_root,batch_size=1,shuffle=False) 43 | net = DCCRN(rnn_units=256,use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 44 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 45 | net = torch.nn.DataParallel(net, device_ids=[0]) #GPU配置 46 | net.to(device) 47 | net = load_net(net, model_path) 48 | test_num(net=net, testloader=test_data, device=device) 49 | # for i in range(100,0,-1): 50 | # path="3s_1d_"+str(i)+".pkl" 51 | # my_file = Path("../pkl/"+path) 52 | # if my_file.is_file(): 53 | # net=load_net(net,path) 54 | # test_num(net=net, testloader=test_data, device=device) 55 | 56 | def train_5s(): 57 | val_data=val_data_loader('../audio_5s/dev/',batch_size=8,shuffle=False) 58 | net = ConvTasNet(X=4,R=2) # 模型 59 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 60 | net = torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 61 | net.to(device) 62 | train(net=net, epoch_num=1, trainloader="5s", valloader=val_data,batch_size=1, 63 | device=device,save_path="best_relu",info_num=10,step_size=5) 64 | 65 | def test_5s(): 66 | test_data=test_data_loader('../audio_5s/test/',batch_size=1,shuffle=False) 67 | net = ConvTasNet(X=4,R=2) # 模型 68 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 69 | net = 
torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 70 | net.to(device) 71 | best_mse = 100 72 | best_mae = 100 73 | for i in range(100, 0, -1): 74 | path = "best_relu_" + str(i) + ".pkl" 75 | my_file = Path("../pkl/" + path) 76 | if my_file.is_file(): 77 | net = load_net(net, path) 78 | tem_mse,tem_mae = test_num(net=net, testloader=test_data, device=device) 79 | if (tem_mse < best_mse): 80 | best_mse = tem_mse 81 | if (tem_mae < best_mae): 82 | best_mae = tem_mae 83 | print(best_mse) 84 | print(best_mae) 85 | 86 | 87 | if __name__ == "__main__": 88 | data_root = '/data3/fancunhang/Depression/audio_good_without_move/AVEC2013_3s/test/' 89 | # model_path = 'exp/best_model.pkl' 90 | model_path = 'exp_0.002/checkpoint/best_model_epoch_75.pkl' 91 | # model_path = 'exp/checkpoint/model_epoch_96.pkl' 92 | test_3s(data_root, model_path) 93 | 94 | 95 | -------------------------------------------------------------------------------- /main_test_avg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import logging 4 | import os 5 | from torchsummary import summary 6 | #from resnet import * 7 | from pathlib import Path 8 | from eval_avg import test,test_num 9 | # from TasNet import * 10 | from dc_crn_test_avg import DCCRN 11 | from test_avg import DCCRN_avg 12 | from train import train 13 | from model import * 14 | from dataload_vad import * 15 | classes_num=1 16 | logging.basicConfig(level=logging.DEBUG) 17 | logger=logging.getLogger(__name__) 18 | 19 | def load_net(net,model_pkl): 20 | logger.info("load:%s"%model_pkl) 21 | net.load_state_dict(torch.load(model_pkl)) 22 | return net 23 | def count_parameters(model): 24 | parameters_sum = sum(p.numel() for p in model.parameters() if p.requires_grad) 25 | print(parameters_sum) 26 | 27 | 28 | def train_3s(): 29 | torch.manual_seed(1234) 30 | torch.cuda.manual_seed(1234) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | val_data=val_data_loader('../audio_3s/dev/',batch_size=64,shuffle=False) 34 | net = ConvTasNet(X=4,R=2) # 模型 35 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 36 | net = torch.nn.DataParallel(net, device_ids=[0,1]) #GPU配置 37 | net.to(device) 38 | train(net=net, epoch_num=100, trainloader="3s", valloader=val_data,batch_size=64, 39 | device=device,save_path="3s_1d",info_num=200,step_size=5) 40 | 41 | def test_3s(data_root, model_path): 42 | # test_root = os.path.join(data_root, 'test') 43 | test_data=test_data_loader(data_root,batch_size=1,shuffle=False) 44 | net = DCCRN(rnn_units=256,use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 45 | net2 = DCCRN_avg(rnn_units=256,use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 47 | net = torch.nn.DataParallel(net, device_ids=[0]) #GPU配置 48 | net.to(device) 49 | net = load_net(net, model_path) 50 | 51 | net2 = torch.nn.DataParallel(net2, device_ids=[0]) #GPU配置 52 | net2.to(device) 53 | net2 = load_net(net2, model_path) 54 | test_num(net=net, net2=net2, testloader=test_data, device=device) 55 | # for i in range(100,0,-1): 56 | # path="3s_1d_"+str(i)+".pkl" 57 | # my_file = Path("../pkl/"+path) 58 | # if my_file.is_file(): 59 | # net=load_net(net,path) 60 | # test_num(net=net, testloader=test_data, device=device) 61 | 62 | def train_5s(): 63 | val_data=val_data_loader('../audio_5s/dev/',batch_size=8,shuffle=False) 64 | net = ConvTasNet(X=4,R=2) # 模型 65 | device = torch.device("cuda:1" if 
torch.cuda.is_available() else "cpu") 66 | net = torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 67 | net.to(device) 68 | train(net=net, epoch_num=1, trainloader="5s", valloader=val_data,batch_size=1, 69 | device=device,save_path="best_relu",info_num=10,step_size=5) 70 | 71 | def test_5s(): 72 | test_data=test_data_loader('../audio_5s/test/',batch_size=1,shuffle=False) 73 | net = ConvTasNet(X=4,R=2) # 模型 74 | device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") 75 | net = torch.nn.DataParallel(net, device_ids=[1]) #GPU配置 76 | net.to(device) 77 | best_mse = 100 78 | best_mae = 100 79 | for i in range(100, 0, -1): 80 | path = "best_relu_" + str(i) + ".pkl" 81 | my_file = Path("../pkl/" + path) 82 | if my_file.is_file(): 83 | net = load_net(net, path) 84 | tem_mse,tem_mae = test_num(net=net, testloader=test_data, device=device) 85 | if (tem_mse < best_mse): 86 | best_mse = tem_mse 87 | if (tem_mae < best_mae): 88 | best_mae = tem_mae 89 | print(best_mse) 90 | print(best_mae) 91 | 92 | 93 | if __name__ == "__main__": 94 | data_root = '/data3/fancunhang/Depression/audio_good_without_move/AVEC2013_3s/test/' 95 | # model_path = 'exp/best_model.pkl' 96 | model_path = 'exp_0.002/checkpoint/best_model_epoch_75.pkl' 97 | # model_path = 'exp/checkpoint/model_epoch_96.pkl' 98 | test_3s(data_root, model_path) 99 | 100 | 101 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import math 4 | import torch.utils.checkpoint as cp 5 | from torch.nn.utils import weight_norm 6 | import torch.nn.functional as F 7 | import torch 8 | import torch.nn as nn 9 | from torchsummary import summary 10 | 11 | logging.basicConfig(level=logging.DEBUG) 12 | logger=logging.getLogger(__name__) 13 | from collections import OrderedDict 14 | def _bn_function_factory(norm, relu, conv): 15 | def bn_function(*inputs): 16 | concated_features = torch.cat(inputs, 1) 17 | bottleneck_output = conv(relu(norm(concated_features))) 18 | return bottleneck_output 19 | 20 | return bn_function 21 | class _DenseLayer(nn.Module): 22 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False): 23 | super(_DenseLayer, self).__init__() 24 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 25 | self.add_module('relu1', nn.ReLU(inplace=True)), 26 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate, 27 | kernel_size=1, stride=1, bias=False)), 28 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 29 | self.add_module('relu2', nn.ReLU(inplace=True)), 30 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 31 | kernel_size=3, stride=1, padding=1, bias=False)), 32 | self.drop_rate = drop_rate 33 | self.efficient = efficient 34 | 35 | def forward(self, *prev_features): 36 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 37 | if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 38 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 39 | else: 40 | bottleneck_output = bn_function(*prev_features) 41 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 42 | if self.drop_rate > 0: 43 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 44 | 45 | return new_features 46 | class _Transition(nn.Sequential): 47 | def __init__(self, num_input_features, 
num_output_features): 48 | super(_Transition, self).__init__() 49 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 50 | self.add_module('relu', nn.ReLU(inplace=True)) 51 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 52 | kernel_size=1, stride=1, bias=False)) 53 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 54 | self.add_module('norm2', nn.BatchNorm2d(num_output_features)) # renamed from 'pool': registering a second module under 'pool' silently replaced the AvgPool2d above, so transitions never downsampled 55 | #self.add_module('pool', nn.BatchNorm2d(num_output_features)) 56 | class _DenseBlock(nn.Module): 57 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False): 58 | super(_DenseBlock, self).__init__() 59 | for i in range(num_layers): 60 | layer = _DenseLayer( 61 | num_input_features + i * growth_rate, 62 | growth_rate=growth_rate, 63 | bn_size=bn_size, 64 | drop_rate=drop_rate, 65 | efficient=efficient, 66 | ) 67 | self.add_module('denselayer%d' % (i + 1), layer) 68 | 69 | def forward(self, init_features): 70 | features = [init_features] 71 | for name, layer in self.named_children(): 72 | new_features = layer(*features) 73 | features.append(new_features) 74 | return torch.cat(features, 1) 75 | class DenseNet(nn.Module): 76 | r"""Densenet-BC model class, based on 77 | `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_ 78 | Args: 79 | growth_rate (int) - how many filters to add each layer (`k` in paper) 80 | block_config (list of 3 or 4 ints) - how many layers in each pooling block 81 | num_init_features (int) - the number of filters to learn in the first convolution layer 82 | bn_size (int) - multiplicative factor for number of bottle neck layers 83 | (i.e. bn_size * k features in the bottleneck layer) 84 | drop_rate (float) - dropout rate after each dense layer 85 | classes_num (int) - number of output units (1 for depression-score regression) 86 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 87 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 
88 | """ 89 | def __init__(self, growth_rate=3, block_config=(3, 6, 12), compression=0.5, 90 | num_init_features=6, bn_size=4, drop_rate=0.5, 91 | classes_num=1,input_channel=1, small_inputs=False, efficient=False): 92 | 93 | super(DenseNet, self).__init__() 94 | assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1' 95 | 96 | # First convolution 97 | if small_inputs: 98 | self.features = nn.Sequential(OrderedDict([ 99 | ('conv0', nn.Conv2d(input_channel, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)), ])) 100 | else: 101 | self.features = nn.Sequential(OrderedDict([ 102 | ('conv0', nn.Conv2d(input_channel, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),])) 103 | self.features.add_module('norm0', nn.BatchNorm2d(num_init_features)) 104 | self.features.add_module('relu0', nn.ReLU(inplace=True)) 105 | self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, 106 | ceil_mode=False)) 107 | 108 | # Each denseblock 109 | num_features = num_init_features 110 | for i, num_layers in enumerate(block_config): 111 | block = _DenseBlock( 112 | num_layers=num_layers, 113 | num_input_features=num_features, 114 | bn_size=bn_size, 115 | growth_rate=growth_rate, 116 | drop_rate=drop_rate, 117 | efficient=efficient, 118 | ) 119 | self.features.add_module('denseblock%d' % (i + 1), block) 120 | num_features = num_features + num_layers * growth_rate 121 | if i != len(block_config) - 1: 122 | trans = _Transition(num_input_features=num_features, 123 | num_output_features=int(num_features * compression)) 124 | self.features.add_module('transition%d' % (i + 1), trans) 125 | num_features = int(num_features * compression) 126 | 127 | # Final batch norm 128 | self.features.add_module('norm_final', nn.BatchNorm2d(num_features)) 129 | 130 | # Linear layer 131 | self.classifier = nn.Linear(num_features, classes_num) 132 | self.relu=nn.ReLU(inplace=True) 133 | self.adaptive_avg_pool2d=nn.AdaptiveAvgPool2d((1, 1)) 134 | # Initialization 135 | for name, param in self.named_parameters(): 136 | if 'conv' in name and 'weight' in name: 137 | n = param.size(0) * param.size(2) * param.size(3) 138 | param.data.normal_().mul_(math.sqrt(2. 
/ n)) 139 | elif 'norm' in name and 'weight' in name: 140 | param.data.fill_(1) 141 | elif 'norm' in name and 'bias' in name: 142 | param.data.fill_(0) 143 | elif 'classifier' in name and 'bias' in name: 144 | param.data.fill_(0) 145 | 146 | def forward(self, x): 147 | features = self.features(x) 148 | out = self.relu(features) 149 | out = self.adaptive_avg_pool2d(out) 150 | out = torch.flatten(out, 1) 151 | out = self.classifier(out) 152 | return out 153 | 154 | class SelfAttention(nn.Module): 155 | 156 | def __init__(self,len): 157 | super(SelfAttention, self).__init__() 158 | self.query = nn.Linear(len, len) # 128, 128 159 | self.key = nn.Linear(len, len) 160 | self.value = nn.Linear(len, len) 161 | 162 | def forward(self, x): 163 | x = x.to(torch.float32) 164 | #x=x.permute(0, 2, 1).contiguous() 165 | q, k, v = self.query(x), self.key(x), self.value(x) 166 | y=torch.bmm(q, k.permute(0, 2, 1).contiguous()) 167 | beta=F.softmax(y,dim=-1) # normalize attention weights over keys; dim=1 would normalize over queries instead 168 | y=torch.bmm(beta, v) 169 | #y = y.permute(0, 2, 1).contiguous() 170 | return y 171 | 172 | class Conv3d_Net(nn.Module): 173 | def __init__(self,channel,num,len): 174 | super(Conv3d_Net, self).__init__() 175 | self.bn1 = nn.BatchNorm2d(3) 176 | self.bn2 = nn.BatchNorm2d(32) 177 | self.bn3 = nn.BatchNorm2d(64) 178 | self.conv1 = nn.Conv2d(3,32,kernel_size=3,stride=1,padding=1) 179 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 180 | self.relu= nn.ReLU() 181 | self.avgpool1=nn.AdaptiveAvgPool2d((int(num/2),int(len/2))) # renamed from maxpool1: these are average pools 182 | self.avgpool2 = nn.AdaptiveAvgPool2d((int(num / 4), int(len / 4))) 183 | self.flatten = nn.Flatten(2) 184 | self.linear1 = nn.Linear(1240 if num==498 else 740,1) 185 | self.linear2 = nn.Linear(64, 1) 186 | self.dropout = nn.Dropout(p=0.5) 187 | 188 | def forward(self, x): #(3,498,40) 189 | x = x.float() 190 | y = self.bn1(x) 191 | y = self.relu(self.conv1(y)) 192 | y = self.dropout(y) 193 | y = self.bn2(y) 194 | y = self.avgpool1(y) 195 | y = self.relu(self.conv2(y)) 196 | y = self.dropout(y) 197 | y = self.bn3(y) 198 | y = self.avgpool2(y) 199 | y = self.flatten(y) 200 | y = self.linear1(y) 201 | y = torch.squeeze(y) 202 | y = self.linear2(y) 203 | return y 204 | 205 | class Conv2d_Net(nn.Module): 206 | def __init__(self,num,len): 207 | super(Conv2d_Net, self).__init__() 208 | self.bn1 = nn.BatchNorm2d(num) 209 | self.bn2 = nn.BatchNorm2d(16) 210 | self.bn3 = nn.BatchNorm2d(64) 211 | self.conv1 = nn.Conv2d(1,16,kernel_size=5,stride=2,padding=2) 212 | self.conv2 = nn.Conv2d(16, 64, kernel_size=5, stride=2, padding=2) 213 | self.relu= nn.ReLU() 214 | self.maxpool1=nn.AdaptiveMaxPool2d((int(num/4),int(len/4))) 215 | self.maxpool2 = nn.AdaptiveMaxPool2d((int(num / 16), int(len / 16))) 216 | self.flatten = nn.Flatten(2) 217 | self.linear1 = nn.Linear(185, 1) 218 | self.linear2 = nn.Linear(64, 1) 219 | self.dropout = nn.Dropout(p=0.5) 220 | 221 | def forward(self, x): #(498,40) 222 | x = x.float() 223 | y = self.bn1(x) 224 | y = y.unsqueeze(0) 225 | y = self.relu(self.conv1(y)) 226 | 227 | y = self.dropout(y) 228 | y = self.bn2(y) 229 | y = self.maxpool1(y) 230 | y = self.relu(self.conv2(y)) 231 | y = self.dropout(y) 232 | y = self.bn3(y) 233 | y = self.maxpool2(y) 234 | y = self.flatten(y) 235 | y = self.linear1(y) 236 | y = torch.squeeze(y) 237 | y = self.linear2(y) 238 | return y 239 | 240 | class Conv_regression(nn.Module): 241 | def __init__(self,channel,num,len): #(1,100,240) 242 | super(Conv_regression, self).__init__() 243 | self.conv1 = nn.Conv2d(channel,32,kernel_size=3,stride=1,padding=1) 244 | self.bn1 
= nn.BatchNorm2d(32) 245 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 246 | self.bn2 = nn.BatchNorm2d(64) 247 | self.relu= nn.ReLU() 248 | self.avgpool1=nn.AdaptiveAvgPool2d((int(num/2),int(len/2))) 249 | self.avgpool2 = nn.AdaptiveAvgPool2d((int(num / 4), int(len / 4))) 250 | self.flatten = nn.Flatten(2) 251 | self.linear1 = nn.Linear(int(num/4)*int(len/4),1) 252 | self.linear2 = nn.Linear(64, 1) 253 | self.dropout = nn.Dropout(p=0.5) 254 | 255 | def forward(self, x): #(3,498,40) 256 | x = x.float() 257 | y = self.relu(self.bn1(self.conv1(x))) 258 | #y = self.dropout(y) 259 | y = self.avgpool1(y) 260 | y = self.relu(self.bn2(self.conv2(y))) 261 | #y = self.dropout(y) 262 | y = self.avgpool2(y) 263 | y = self.flatten(y) 264 | y = self.linear1(y) 265 | y = torch.squeeze(y) 266 | y = self.linear2(y) 267 | return y 268 | 269 | class Conv_operation(nn.Module): 270 | def __init__(self): 271 | super(Conv_operation, self).__init__() 272 | self.bn1=nn.BatchNorm1d(1) 273 | self.con1d=nn.Conv1d(1, 240, kernel_size=240, stride=1, padding=1) 274 | self.con2d = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1) # restored: forward() below uses con2d/con2d2, so leaving these commented raises AttributeError 275 | self.con2d2 = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1) 276 | self.con1d2 = nn.Conv1d(240, 1, kernel_size=3, stride=1, padding=1) # in_channels changed from 32: after squeeze(1) the tensor has the 240 channels produced by con1d 277 | 278 | def forward(self, x): 279 | #y=self.bn1(x) 280 | y=self.con1d(x) 281 | y=y.unsqueeze(1) 282 | y=self.con2d(y) 283 | y=self.con2d2(y) 284 | y = y.squeeze(1) 285 | y = self.con1d2(y) 286 | return y 287 | 288 | class Sample_Net_dense(nn.Module): 289 | def __init__(self,len): 290 | super(Sample_Net_dense, self).__init__() 291 | self.densenet = DenseNet(input_channel=1) 292 | #self.conv_regression = Conv_regression(1, 100, int(len / 100)) 293 | self.conv_operation = Conv_operation() 294 | self.len = int(len / 100) 295 | self.bn = nn.BatchNorm1d(1) 296 | self.con2d = nn.Conv2d(1, int(len / 100), kernel_size=(1, int(len / 100)), stride=1, padding=0) 297 | 298 | def forward(self, x): 299 | x = self.bn(x) 300 | y = x.view(-1, 100, self.len) 301 | y = y.unsqueeze(1) 302 | y = self.con2d(y) 303 | y = y.permute(0, 3, 2, 1) 304 | # y = self.conv_operation(y) 305 | y = self.densenet(y) 306 | return y 307 | 308 | class Sample_Net_conv2d(nn.Module): 309 | def __init__(self,len): 310 | super(Sample_Net_conv2d, self).__init__() 311 | self.densenet=DenseNet(input_channel=1) 312 | self.conv_regression=Conv_regression(1,100,int(len/100)) 313 | self.conv_operation=Conv_operation() 314 | self.len=int(len/100) 315 | self.bn=nn.BatchNorm1d(1) 316 | self.con2d = nn.Conv2d(1,int(len/100),kernel_size=(1,int(len/100)), stride=1, padding=0) 317 | 318 | def forward(self, x): 319 | x = self.bn(x) 320 | y = x.view(-1, 100, self.len) 321 | y = y.unsqueeze(1) 322 | y = self.con2d(y) 323 | y = y.permute(0, 3, 2,1) 324 | y = self.conv_regression(y) 325 | return y 326 | 327 | class Sample_Net_conv1d(nn.Module): 328 | def __init__(self,sum_len): 329 | super(Sample_Net_conv1d, self).__init__() 330 | len = int(sum_len / 100) 331 | self.densenet=DenseNet(input_channel=1) 332 | self.conv_regression=Conv_regression(1,100,len) 333 | self.conv_operation=Conv_operation() 334 | self.len=int(sum_len/100) 335 | self.bn=nn.BatchNorm1d(1) 336 | self.con1d = nn.Conv1d(1,len,kernel_size=len, stride=len) 337 | 338 | def forward(self, x): 339 | #print(x.size()) 340 | x = self.bn(x) 341 | y = self.con1d(x) 342 | y = y.unsqueeze(1) 343 | y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 344 | y = self.conv_regression(y) 345 | return y 346 | 347 | class 
Sample_Net_conv1d2(nn.Module): 348 | def __init__(self,sum_len): 349 | super(Sample_Net_conv1d2, self).__init__() 350 | len = int(sum_len / 100) 351 | self.densenet=DenseNet(input_channel=1) 352 | self.conv_regression=Conv_regression(1,100,len) 353 | self.conv_operation=Conv_operation() 354 | self.len=int(sum_len/100) 355 | self.bn=nn.BatchNorm1d(1) 356 | self.relu=nn.ReLU() 357 | self.con1d = nn.Conv1d(1,64,kernel_size=len, stride=len) 358 | self.con1d2 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1) 359 | self.con1d3 = nn.Conv1d(128, len, kernel_size=3, stride=1, padding=1) 360 | 361 | def forward(self, x): 362 | #print(x.size()) 363 | x = self.bn(x) 364 | y = self.relu(self.con1d(x)) 365 | y = self.relu(self.con1d2(y)) 366 | y = self.relu(self.con1d3(y)) 367 | y = y.unsqueeze(1) 368 | y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 369 | y = self.conv_regression(y) 370 | return y 371 | 372 | class Sample_Net_conv1d5(nn.Module): 373 | def __init__(self,sum_len): 374 | super(Sample_Net_conv1d5, self).__init__() 375 | self.densenet=DenseNet(input_channel=1) 376 | self.conv_regression=Conv_regression(1,100,400) 377 | self.conv_operation=Conv_operation() 378 | self.len=int(sum_len/100) 379 | self.bn = nn.BatchNorm1d(1) 380 | self.bn1 = nn.BatchNorm1d(32) 381 | self.bn2 = nn.BatchNorm1d(64) 382 | self.bn3 = nn.BatchNorm1d(128) 383 | self.bn4 = nn.BatchNorm1d(256) 384 | self.bn5 = nn.BatchNorm1d(512) 385 | self.bn6 = nn.BatchNorm1d(100) 386 | self.relu=nn.ReLU() 387 | self.conv = nn.Conv1d(1,32,kernel_size=200, stride=100,padding=50) 388 | self.con1d1 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1) 389 | self.con1d2 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1) 390 | self.con1d3 = nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1) 391 | self.con1d4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1) 392 | self.con1d5 = nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1) 393 | self.con1d6 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1) 394 | self.con1d7 = nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1) 395 | self.con1d8 = nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1) 396 | self.con1d9 = nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1) 397 | self.con1d10 = nn.Conv1d(512, 100, kernel_size=3, stride=1, padding=1) 398 | 399 | def forward(self, x): 400 | #print(x.size()) 401 | x = self.bn(x) 402 | y = self.relu(self.bn1(self.conv(x))) 403 | y = self.relu(self.bn1(self.con1d1(y))) 404 | y = self.relu(self.bn2(self.con1d2(y))) 405 | y = self.relu(self.bn2(self.con1d3(y))) 406 | y = self.relu(self.bn3(self.con1d4(y))) 407 | y = self.relu(self.bn3(self.con1d5(y))) 408 | y = self.relu(self.bn4(self.con1d6(y))) 409 | y = self.relu(self.bn4(self.con1d7(y))) 410 | y = self.relu(self.bn5(self.con1d8(y))) 411 | y = self.relu(self.bn5(self.con1d9(y))) 412 | y = self.relu(self.bn6(self.con1d10(y))) 413 | y = y.unsqueeze(1) 414 | #y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 415 | y = self.conv_regression(y) 416 | return y 417 | 418 | class Sample_Net_conv1d6(nn.Module): 419 | def __init__(self,sum_len): 420 | super(Sample_Net_conv1d6, self).__init__() 421 | self.densenet=DenseNet(input_channel=1) 422 | self.conv_regression=Conv_regression(1,400,512) 423 | self.conv_operation=Conv_operation() 424 | self.len=int(sum_len/100) 425 | self.bn = nn.BatchNorm1d(1) 426 | self.bn1 = nn.BatchNorm1d(32) 427 | self.bn2 = nn.BatchNorm1d(64) 428 | self.bn3 = nn.BatchNorm1d(128) 429 | self.bn4 = nn.BatchNorm1d(256) 430 | self.bn5 = 
nn.BatchNorm1d(512) 431 | self.relu=nn.ReLU() 432 | self.conv = nn.Conv1d(1,32,kernel_size=200, stride=100,padding=50) 433 | self.con1d1 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1) 434 | self.con1d2 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1) 435 | self.con1d3 = nn.Conv1d(64, 64, kernel_size=3, stride=1, padding=1) 436 | self.con1d4 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1) 437 | self.con1d5 = nn.Conv1d(128, 128, kernel_size=3, stride=1, padding=1) 438 | self.con1d6 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1) 439 | self.con1d7 = nn.Conv1d(256, 256, kernel_size=3, stride=1, padding=1) 440 | self.con1d8 = nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1) 441 | self.con1d9 = nn.Conv1d(512, 512, kernel_size=3, stride=1, padding=1) 442 | 443 | def forward(self, x): 444 | #print(x.size()) 445 | x = self.bn(x) 446 | y = self.relu(self.bn1(self.conv(x))) 447 | y = self.relu(self.bn1(self.con1d1(y))) 448 | y = self.relu(self.bn2(self.con1d2(y))) 449 | y = self.relu(self.bn2(self.con1d3(y))) 450 | y = self.relu(self.bn3(self.con1d4(y))) 451 | y = self.relu(self.bn3(self.con1d5(y))) 452 | y = self.relu(self.bn4(self.con1d6(y))) 453 | y = self.relu(self.bn4(self.con1d7(y))) 454 | y = self.relu(self.bn5(self.con1d8(y))) 455 | y = self.relu(self.bn5(self.con1d9(y))) 456 | y = y.unsqueeze(1) 457 | y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 458 | y = self.conv_regression(y) 459 | return y 460 | 461 | # class Sample_Net_conv1d6(nn.Module): 462 | # def __init__(self,sum_len): 463 | # super(Sample_Net_conv1d6, self).__init__() 464 | # len = int(sum_len / 100) 465 | # self.densenet=DenseNet(input_channel=1) 466 | # self.conv_regression=Conv_regression(1,100,len) 467 | # self.conv_operation=Conv_operation() 468 | # self.len=512 469 | # self.bn=nn.BatchNorm1d(1) 470 | # self.bn1 = nn.BatchNorm1d(32) 471 | # self.bn2 = nn.BatchNorm1d(64) 472 | # self.bn3 = nn.BatchNorm1d(128) 473 | # self.bn4 = nn.BatchNorm1d(256) 474 | # self.bn5 = nn.BatchNorm1d(512) 475 | # self.bn6 = nn.BatchNorm1d(len) 476 | # self.relu=nn.ReLU() 477 | # self.con1d = nn.Conv1d(1,32,kernel_size=400, stride=200,padding=100) 478 | # 479 | # self.con1d1 = nn.Conv1d(32, 32, kernel_size=3, stride=1, padding=1) 480 | # self.con1d2 = nn.Conv1d(32, 64, kernel_size=3, stride=1, padding=1) 481 | # self.rescon1 = nn.Conv1d(32,64, kernel_size=1, stride=1, padding=0) 482 | # 483 | # self.con1d3 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1) 484 | # self.con1d4 = nn.Conv1d(128, 256, kernel_size=3, stride=1, padding=1) 485 | # self.rescon2 = nn.Conv1d(64, 256, kernel_size=1, stride=1, padding=0) 486 | # 487 | # self.con1d5 = nn.Conv1d(256, 512, kernel_size=3, stride=1, padding=1) 488 | # self.con1d6 = nn.Conv1d(512, len, kernel_size=3, stride=1, padding=1) 489 | # self.rescon3 = nn.Conv1d(256, len, kernel_size=1, stride=1, padding=0) 490 | # self.dropout=nn.Dropout(0.5) 491 | # #self.con1d7 = nn.Conv1d(32, len, kernel_size=3, stride=1, padding=1) 492 | # 493 | # def forward(self, x): 494 | # #print(x.size()) 495 | # x = self.bn(x) 496 | # y = self.relu(self.bn1(self.con1d(x))) 497 | # 498 | # res = y 499 | # y = self.relu(self.bn1(self.con1d1(y))) 500 | # y = self.relu(self.bn2(self.con1d2(y))) 501 | # res = self.relu((self.rescon1(res))) 502 | # y=y+res 503 | # 504 | # res=y 505 | # y = self.relu(self.bn3(self.con1d3(y))) 506 | # y = self.relu(self.bn4(self.con1d4(y))) 507 | # res = self.relu((self.rescon2(res))) 508 | # y=y+res 509 | # 510 | # res = y 511 | # y = 
self.relu(self.bn5(self.con1d5(y))) 512 | # y = self.relu(self.bn6(self.con1d6(y))) 513 | # res = self.relu((self.rescon3(res))) 514 | # y = y + res 515 | # 516 | # #y = self.relu((self.con1d7(y))) 517 | # 518 | # y = y.unsqueeze(1) 519 | # y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 520 | # y = self.conv_regression(y) 521 | # return y 522 | 523 | class Chomp1d(nn.Module): 524 | def __init__(self, chomp_size): 525 | super(Chomp1d, self).__init__() 526 | self.chomp_size = chomp_size 527 | 528 | def forward(self, x): 529 | return x[:, :, :-self.chomp_size].contiguous() 530 | 531 | 532 | class TemporalBlock(nn.Module): 533 | def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): 534 | super(TemporalBlock, self).__init__() 535 | self.conv1 = weight_norm(nn.Conv1d(n_inputs, n_outputs, kernel_size, 536 | stride=stride, padding=padding, dilation=dilation)) 537 | self.chomp1 = Chomp1d(padding) 538 | self.relu1 = nn.ReLU() 539 | self.dropout1 = nn.Dropout(dropout) 540 | 541 | self.conv2 = weight_norm(nn.Conv1d(n_outputs, n_outputs, kernel_size, 542 | stride=stride, padding=padding, dilation=dilation)) 543 | self.chomp2 = Chomp1d(padding) 544 | self.relu2 = nn.ReLU() 545 | self.dropout2 = nn.Dropout(dropout) 546 | 547 | self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1, 548 | self.conv2, self.chomp2, self.relu2, self.dropout2) 549 | self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None 550 | self.relu = nn.ReLU() 551 | self.init_weights() 552 | 553 | def init_weights(self): 554 | self.conv1.weight.data.normal_(0, 0.01) 555 | self.conv2.weight.data.normal_(0, 0.01) 556 | if self.downsample is not None: 557 | self.downsample.weight.data.normal_(0, 0.01) 558 | 559 | def forward(self, x): 560 | out = self.net(x) 561 | res = x if self.downsample is None else self.downsample(x) 562 | return self.relu(out + res) 563 | 564 | 565 | class TemporalConvNet(nn.Module): 566 | def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2): 567 | super(TemporalConvNet, self).__init__() 568 | layers = [] 569 | num_levels = len(num_channels) 570 | for i in range(num_levels): 571 | dilation_size = 2 ** i 572 | in_channels = num_inputs if i == 0 else num_channels[i-1] 573 | out_channels = num_channels[i] 574 | layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size, 575 | padding=(kernel_size-1) * dilation_size, dropout=dropout)] 576 | 577 | self.network = nn.Sequential(*layers) 578 | 579 | def forward(self, x): 580 | return self.network(x) 581 | 582 | 583 | class TCN(nn.Module): 584 | def __init__(self,sum_len): 585 | super(TCN, self).__init__() 586 | len = int(sum_len / 100) 587 | self.tcn=TemporalConvNet(1,[32,64,128,64,32,1]) 588 | self.conv_regression=Conv_regression(1,100,len) 589 | self.con1d = nn.Conv1d(1,512,kernel_size=len, stride=(int)(len/2)) 590 | 591 | 592 | def forward(self, x): 593 | y=self.tcn(x) 594 | y=self.con1d(y) 595 | y = y.unsqueeze(1) 596 | y = y.permute(0,1, 3,2).contiguous() #(b,1,100,240) 597 | y = self.conv_regression(y) 598 | return y 599 | 600 | 601 | if __name__ == "__main__": 602 | net=Conv_regression(1, 64, 7999) 603 | torch.cuda.set_device(5) 604 | # a=torch.randn(1,1,40000) 605 | # b=net(a) 606 | net=net.cuda() 607 | summary(net,(1,64,7999)) 608 | -------------------------------------------------------------------------------- /network_selfattention: 
-------------------------------------------------------------------------------- 1 | inputs: torch.Size([32, 243, 256]) 2 | att_weights: torch.Size([32, 256, 1]) 3 | weights: torch.Size([32, 243, 1]) 4 | attentions: torch.Size([32, 243]) 5 | weighted: torch.Size([32, 243, 256]) 6 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import logging 3 | import torch.nn as nn 4 | import torch 5 | from torchsummary import summary 6 | 7 | logging.basicConfig(level=logging.DEBUG) 8 | logger=logging.getLogger(__name__) 9 | 10 | class BasicBlock(nn.Module): 11 | def __init__(self,in_channel,out_channel): 12 | super(BasicBlock, self).__init__() 13 | if(in_channel==out_channel): 14 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, padding=1) 15 | self.downsample = None 16 | else: 17 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=5, stride=4, padding=2) 18 | self.downsample = nn.Sequential( 19 | nn.Conv1d(in_channel,out_channel,kernel_size=1, stride=4, padding=0), 20 | nn.BatchNorm1d(out_channel) 21 | ) 22 | self.bn = nn.BatchNorm1d(out_channel) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = nn.Conv1d(out_channel, out_channel, kernel_size=3, stride=1, padding=1) 25 | 26 | def forward(self, x): 27 | res = x 28 | out = self.conv1(x) 29 | out = self.bn(out) 30 | out = self.relu(out) 31 | out = self.conv2(out) 32 | out = self.bn(out) 33 | if self.downsample is not None: 34 | res = self.downsample(x) 35 | out += res 36 | out = self.relu(out) 37 | return out 38 | 39 | class BasicBlock_avg(nn.Module): 40 | def __init__(self,in_channel,out_channel): 41 | super(BasicBlock_avg, self).__init__() 42 | if(in_channel==out_channel): 43 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, padding=1) 44 | self.downsample = None 45 | else: 46 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=5, stride=4, padding=1) 47 | self.downsample = nn.Sequential( 48 | nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, padding=0), 49 | nn.AvgPool1d(4), 50 | nn.BatchNorm1d(out_channel) 51 | ) 52 | self.bn = nn.BatchNorm1d(out_channel) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = nn.Conv1d(out_channel, out_channel, kernel_size=3, stride=1, padding=1) 55 | 56 | def forward(self, x): 57 | res = x 58 | out = self.conv1(x) 59 | out = self.bn(out) 60 | out = self.relu(out) 61 | out = self.conv2(out) 62 | out = self.bn(out) 63 | if self.downsample is not None: 64 | res = self.downsample(x) 65 | out += res 66 | out = self.relu(out) 67 | return out 68 | 69 | class BasicBlock_max(nn.Module): 70 | def __init__(self,in_channel,out_channel): 71 | super(BasicBlock_max, self).__init__() 72 | if(in_channel==out_channel): 73 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=3, stride=1, padding=1) 74 | self.downsample = None 75 | else: 76 | self.conv1 = nn.Conv1d(in_channel, out_channel, kernel_size=5, stride=4, padding=1) 77 | self.downsample = nn.Sequential( 78 | nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, padding=0), 79 | nn.MaxPool1d(4), 80 | nn.BatchNorm1d(out_channel) 81 | ) 82 | self.bn = nn.BatchNorm1d(out_channel) 83 | self.relu = nn.ReLU(inplace=True) 84 | self.conv2 = nn.Conv1d(out_channel, out_channel, kernel_size=3, stride=1, padding=1) 85 | 86 | 87 | 88 | def forward(self, x): 89 | res = x 90 | out = self.conv1(x) 91 | out = self.bn(out) 92 | out = 
self.relu(out) 93 | out = self.conv2(out) 94 | out = self.bn(out) 95 | if self.downsample is not None: 96 | res = self.downsample(x) 97 | out += res 98 | out = self.relu(out) 99 | return out 100 | 101 | class Conv_regression(nn.Module): 102 | def __init__(self,channel,num,len): #(1,100,240) 103 | super(Conv_regression, self).__init__() 104 | self.conv1 = nn.Conv2d(channel,32,kernel_size=3,stride=1,padding=1) 105 | self.bn1 = nn.BatchNorm2d(32) 106 | self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1) 107 | self.bn2 = nn.BatchNorm2d(64) 108 | self.relu= nn.ReLU() 109 | self.avgpool1=nn.AdaptiveAvgPool2d((int(num/2),int(len/2))) 110 | self.avgpool2 = nn.AdaptiveAvgPool2d((int(num / 4), int(len / 4))) 111 | self.flatten = nn.Flatten(2) 112 | self.linear1 = nn.Linear(int(num/4)*int(len/4),1) 113 | self.linear2 = nn.Linear(64, 1) 114 | self.dropout = nn.Dropout(p=0.5) 115 | 116 | def forward(self, x): #(3,498,40) 117 | x = x.float() 118 | y = self.relu(self.bn1(self.conv1(x))) 119 | #y = self.dropout(y) 120 | y = self.avgpool1(y) 121 | y = self.relu(self.bn2(self.conv2(y))) 122 | #y = self.dropout(y) 123 | y = self.avgpool2(y) 124 | y = self.flatten(y) 125 | y = self.linear1(y) 126 | y = torch.squeeze(y) 127 | y = self.linear2(y) 128 | return y 129 | 130 | class Resnet(nn.Module): 131 | def __init__(self): 132 | super(Resnet, self).__init__() 133 | self.conv1 = nn.Conv1d(1,200,kernel_size=300, stride=200,padding=100) 134 | self.block = BasicBlock(1,32) 135 | self.block2 = BasicBlock(32, 32) 136 | self.block3 = BasicBlock(32, 64) 137 | self.block4 = BasicBlock(64, 64) 138 | self.block5 = BasicBlock(64, 128) 139 | self.block6 = BasicBlock(128, 128) 140 | self.block7 = BasicBlock(128, 256) 141 | self.block8 = BasicBlock(256, 256) 142 | self.conv_regression=Conv_regression(1,128,625) 143 | self.avgpooling = nn.AdaptiveAvgPool1d(1) 144 | self.flatten = nn.Flatten() 145 | self.linear = nn.Linear(128,1) 146 | 147 | def forward(self, x): 148 | y = self.block(x) 149 | y = self.block2(y) 150 | y = self.block3(y) 151 | y = self.block4(y) 152 | y = self.block5(y) 153 | y = self.block6(y) 154 | #y = self.block7(y) 155 | #y = self.block8(y) 156 | y = y.unsqueeze(1) 157 | #y = y.permute(0, 1, 3, 2).contiguous() # (b,1,100,240) 158 | y = self.conv_regression(y) 159 | return y 160 | 161 | class Resnet_avg(nn.Module): 162 | def __init__(self): 163 | super(Resnet_avg, self).__init__() 164 | self.conv1 = nn.Conv1d(1,200,kernel_size=300, stride=200,padding=100) 165 | self.block = BasicBlock_avg(1,32) 166 | self.block2 = BasicBlock_avg(32, 32) 167 | self.block3 = BasicBlock_avg(32, 64) 168 | self.block4 = BasicBlock_avg(64, 64) 169 | self.block5 = BasicBlock_avg(64, 128) 170 | self.block6 = BasicBlock_avg(128, 128) 171 | self.block7 = BasicBlock_avg(128, 256) 172 | self.block8 = BasicBlock_avg(256, 256) 173 | self.block9 = BasicBlock_avg(256, 512) 174 | self.block10 = BasicBlock_avg(512, 512) 175 | self.avgpooling = nn.AdaptiveAvgPool1d(1) 176 | self.flatten = nn.Flatten() 177 | self.linear = nn.Linear(512,1) 178 | 179 | def forward(self, x): 180 | y = self.block(x) 181 | y = self.block2(y) 182 | y = self.block3(y) 183 | y = self.block4(y) 184 | y = self.block5(y) 185 | y = self.block6(y) 186 | y = self.block7(y) 187 | y = self.block8(y) 188 | y = self.block9(y) 189 | y = self.block10(y) 190 | y = self.avgpooling(y) 191 | y = self.flatten(y) 192 | y = self.linear(y) 193 | return y 194 | 195 | class Resnet_max(nn.Module): 196 | def __init__(self): 197 | super(Resnet_max, self).__init__() 198 | 
self.conv1 = nn.Conv1d(1,200,kernel_size=300, stride=200,padding=100) 199 | self.block = BasicBlock_max(1,32) 200 | self.block2 = BasicBlock_max(32, 32) 201 | self.block3 = BasicBlock_max(32, 64) 202 | self.block4 = BasicBlock_max(64, 64) 203 | self.block5 = BasicBlock_max(64, 128) 204 | self.block6 = BasicBlock_max(128, 128) 205 | self.block7 = BasicBlock_max(128, 256) 206 | self.block8 = BasicBlock_max(256, 256) 207 | self.block9 = BasicBlock_max(256, 512) 208 | self.block10 = BasicBlock_max(512, 512) 209 | self.avgpooling = nn.AdaptiveAvgPool1d(1) 210 | self.flatten = nn.Flatten() 211 | self.linear = nn.Linear(512,1) 212 | 213 | def forward(self, x): 214 | y = self.block(x) 215 | y = self.block2(y) 216 | y = self.block3(y) 217 | y = self.block4(y) 218 | y = self.block5(y) 219 | y = self.block6(y) 220 | y = self.block7(y) 221 | y = self.block8(y) 222 | y = self.block9(y) 223 | y = self.block10(y) 224 | y = self.avgpooling(y) 225 | y = self.flatten(y) 226 | y = self.linear(y) 227 | return y 228 | 229 | if __name__ == "__main__": 230 | net=Resnet() 231 | a=torch.randn(1,1,40000) 232 | b=net(a) 233 | #net=net.cuda() 234 | #summary(net,(1,40000)) 235 | 236 | -------------------------------------------------------------------------------- /result: -------------------------------------------------------------------------------- 1 | ******************lr 0.002****************** 2 | best_model_epoch_8 3 | INFO:eval:9.24515682715933 4 | INFO:eval:7.132399082183838 5 | 6 | best_model_epoch_12 7 | INFO:eval:9.568244989165777 8 | INFO:eval:7.3114013671875 9 | 10 | best_model_epoch_20 11 | INFO:eval:9.537200322188992 12 | INFO:eval:7.163089752197266 13 | 14 | best_model_epoch_59 15 | INFO:eval:9.54351271387428 16 | INFO:eval:6.976365566253662 17 | 18 | best_model_epoch_75 19 | INFO:eval:9.28349272885641 20 | INFO:eval:6.7915120124816895 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | ############resnet############## 31 | best_model_epoch_58 32 | INFO:eval:9.41985929916302 33 | INFO:eval:7.188505172729492 34 | 35 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | gpu_id="2" 6 | 7 | CUDA_VISIBLE_DEVICES=$gpu_id python -u main.py 8 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import wave 4 | #from model import * 5 | import torch 6 | import torch.nn as nn 7 | from scipy import signal 8 | 9 | def setup_seed(seed): 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | np.random.seed(seed) 13 | #torch.random.seed(seed) 14 | torch.backends.cudnn.deterministic = True 15 | 16 | #con1d = nn.Conv1d(240, 240, kernel_size=100, stride=1, padding=0) 17 | 18 | class Sample_Net_conv(nn.Module): 19 | def __init__(self): 20 | super(Sample_Net_conv, self).__init__() 21 | #self.densenet=DenseNet(input_channel=1) 22 | #self.conv_regression=Conv_regression(1,100,240) 23 | #self.conv_operation=Conv_operation() 24 | #self.len=len 25 | self.bn=nn.BatchNorm1d(1) 26 | self.con2d = nn.Conv2d(1,240,kernel_size=(1,240), stride=1, padding=0) 27 | 28 | def forward(self, x): 29 | x=self.bn(x) 30 | y=x.view(-1,100,240) 31 | y = y.unsqueeze(1) 32 | y=self.con2d(y) 33 | y = y.squeeze(-1) 34 | y = y.permute(0,2,1) 35 | 36 | #y = self.conv_operation(y) 37 | #y = 
self.conv_regression(y) 38 | return y 39 | 40 | a=torch.randn(1,1,24000) 41 | #print(a) 42 | net=Sample_Net_conv() 43 | b=net(a) 44 | print(b.size()) 45 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eu 4 | 5 | gpu_id="1" 6 | 7 | # CUDA_VISIBLE_DEVICES=$gpu_id python -u main_test.py 8 | # CUDA_VISIBLE_DEVICES=$gpu_id python -u test_attention_visual.py 9 | CUDA_VISIBLE_DEVICES=$gpu_id python -u main_test_avg.py 10 | -------------------------------------------------------------------------------- /test_attention_visual.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import logging 4 | import os 5 | from torchsummary import summary 6 | import numpy as np # needed for np.save below (previously a commented-out resnet import) 7 | from pathlib import Path 8 | from eval import test,test_num 9 | from dc_crn_test_attention import DCCRN 10 | from train import train 11 | from model import * 12 | from dataload_vad import * 13 | classes_num=1 14 | logging.basicConfig(level=logging.DEBUG) 15 | logger=logging.getLogger(__name__) 16 | 17 | def test_attention_vis(net,testloader,device): 18 | net.eval() 19 | batch_loss=0 20 | j=1 21 | sum_outputs = torch.tensor([0]) 22 | sum_outputs = sum_outputs.to(device) 23 | count = 0 24 | count1 = 0 25 | with torch.no_grad(): 26 | for i,data in enumerate(testloader): 27 | images, labels,name = data 28 | if name[0] == '220_3': 29 | count1 += 1 30 | images, labels = images.to(device), labels.to(device) 31 | stats, weighted = net(images) 32 | weighted_path = os.path.join('attention_vis_220_3/', str(count1)+'_weighted.npy') 33 | # stats_path = os.path.join('attention_vis_316_1/', str(i)+'_stats.npy') 34 | # np.save(stats_path, stats.cpu()) 35 | np.save(weighted_path, weighted.cpu()) 36 | print(count1, name[0]) 37 | # if name[0] == '237_1': 38 | # count += 1 39 | # images, labels = images.to(device), labels.to(device) 40 | # stats, weighted = net(images) 41 | # weighted_path = os.path.join('attention_vis_237_1/', str(count)+'_weighted.npy') 42 | # # stats_path = os.path.join('attention_vis_246_1/', str(i)+'_stats.npy') 43 | # # np.save(stats_path, stats.cpu()) 44 | # np.save(weighted_path, weighted.cpu()) 45 | # print(count, name[0]) 46 | 47 | # return total_mse,total_mae 48 | 49 | def load_net(net,model_pkl): 50 | logger.info("load:%s"%model_pkl) 51 | net.load_state_dict(torch.load(model_pkl)) 52 | return net 53 | def count_parameters(model): 54 | parameters_sum = sum(p.numel() for p in model.parameters() if p.requires_grad) 55 | print(parameters_sum) 56 | 57 | def test_3s(data_root, model_path): 58 | # test_root = os.path.join(data_root, 'test') 59 | test_data=test_data_loader(data_root,batch_size=1,shuffle=False) 60 | net = DCCRN(rnn_units=256,use_clstm=True,kernel_num=[32, 64, 128, 256, 256,256]) 61 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 62 | net = torch.nn.DataParallel(net, device_ids=[0]) # GPU configuration 63 | net.to(device) 64 | net = load_net(net, model_path) 65 | test_attention_vis(net=net, testloader=test_data, device=device) 66 | # for i in range(100,0,-1): 67 | # path="3s_1d_"+str(i)+".pkl" 68 | # my_file = Path("../pkl/"+path) 69 | # if my_file.is_file(): 70 | # net=load_net(net,path) 71 | # test_num(net=net, testloader=test_data, device=device) 72 | 73 | 74 | if __name__ == "__main__": 75 | data_root = 
'/data3/fancunhang/Depression/audio_good_without_move/AVEC2013_3s/test/' 76 | # model_path = 'exp/best_model.pkl' 77 | model_path = 'exp_0.002/checkpoint/best_model_epoch_75.pkl' 78 | # model_path = 'exp/checkpoint/model_epoch_96.pkl' 79 | test_3s(data_root, model_path) 80 | 81 | 82 | -------------------------------------------------------------------------------- /test_avg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import os 5 | import sys 6 | # from show import show_params, show_model 7 | import torch.nn.functional as F 8 | from conv_stft import ConvSTFT, ConviSTFT 9 | 10 | from complexnn import ComplexConv2d, ComplexConvTranspose2d, NavieComplexLSTM, complex_cat, ComplexBatchNorm 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | class SelfAttention(nn.Module): 16 | def __init__(self, hidden_size, mean_only=False): 17 | super(SelfAttention, self).__init__() 18 | 19 | #self.output_size = output_size 20 | self.hidden_size = hidden_size 21 | self.att_weights = nn.Parameter(torch.Tensor(1, hidden_size),requires_grad=True) 22 | 23 | self.mean_only = mean_only 24 | 25 | init.kaiming_uniform_(self.att_weights) 26 | 27 | def forward(self, inputs): 28 | 29 | batch_size = inputs.size(0) 30 | weights = torch.bmm(inputs, self.att_weights.permute(1, 0).unsqueeze(0).repeat(batch_size, 1, 1)) 31 | 32 | if inputs.size(0)==1: 33 | attentions = F.softmax(torch.tanh(weights),dim=1) 34 | weighted = torch.mul(inputs, attentions.expand_as(inputs)) 35 | else: 36 | attentions = F.softmax(torch.tanh(weights.squeeze()),dim=1) 37 | weighted = torch.mul(inputs, attentions.unsqueeze(2).expand_as(inputs)) 38 | 39 | if self.mean_only: 40 | return weighted.sum(1) 41 | else: 42 | noise = 1e-5*torch.randn(weighted.size()) 43 | 44 | if inputs.is_cuda: 45 | noise = noise.to(inputs.device) 46 | avg_repr, std_repr = weighted.sum(1), (weighted+noise).std(1) 47 | 48 | representations = torch.cat((avg_repr,std_repr),1) 49 | 50 | return representations 51 | 52 | class SELayer(nn.Module): 53 | def __init__(self, channel, reduction=16): 54 | super(SELayer, self).__init__() 55 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 56 | self.fc = nn.Sequential( 57 | nn.Linear(channel, channel // reduction, bias=False), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(channel // reduction, channel, bias=False), 60 | nn.Sigmoid() 61 | ) 62 | 63 | def forward(self, x): 64 | b, c, _, _ = x.size() 65 | y = self.avg_pool(x).view(b, c) 66 | y = self.fc(y).view(b, c, 1, 1) 67 | return x * y.expand_as(x) 68 | 69 | class SEBasicBlock(nn.Module): 70 | expansion = 1 71 | 72 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 73 | base_width=64, dilation=1, norm_layer=None, 74 | *, reduction=16): 75 | super(SEBasicBlock, self).__init__() 76 | self.conv1 = conv3x3(inplanes, planes, stride) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.conv2 = conv3x3(planes, planes, 1) 80 | self.bn2 = nn.BatchNorm2d(planes) 81 | self.se = SELayer(planes, reduction) 82 | self.downsample = downsample 83 | self.stride = stride 84 | 85 | def forward(self, x): 86 | residual = x 87 | out = self.conv1(x) 88 | out = self.bn1(out) 89 | out = self.relu(out) 90 | 91 | out = self.conv2(out) 92 | out = self.bn2(out) 93 | out = self.se(out) 94 | 95 | if self.downsample is not None: 96 | residual = 
self.downsample(x) 97 | 98 | out += residual 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | class Conv_regression(nn.Module): 104 | def __init__(self,channel, length): #(64,7999) 105 | super(Conv_regression, self).__init__() 106 | self.conv1 = nn.Conv1d(channel,512,kernel_size=3,stride=1,padding=1) 107 | self.bn1 = nn.BatchNorm1d(512) 108 | self.conv2 = nn.Conv1d(512, 256, kernel_size=3, stride=1, padding=1) 109 | self.bn2 = nn.BatchNorm1d(256) 110 | self.conv3 = nn.Conv1d(256, 128, kernel_size=3, stride=1, padding=1) 111 | self.bn3 = nn.BatchNorm1d(128) 112 | self.relu= nn.ReLU() 113 | self.avgpool1=nn.AdaptiveAvgPool1d(int(length/3)) 114 | self.avgpool2 = nn.AdaptiveAvgPool1d( int(length / 9)) 115 | self.avgpool3 = nn.AdaptiveAvgPool1d( int(length / 27)) 116 | self.linear1 = nn.Linear(int(length / 27),1) 117 | self.linear2 = nn.Linear(128, 1) 118 | 119 | 120 | def forward(self, x): 121 | y = self.relu(self.bn1(self.conv1(x))) 122 | # print('y1: ', y.size()) 123 | y = self.avgpool1(y) 124 | # print('y2: ', y.size()) 125 | y = self.relu(self.bn2(self.conv2(y))) 126 | # print('y3: ', y.size()) 127 | y = self.avgpool2(y) 128 | # print('y4: ', y.size()) 129 | y = self.relu(self.bn3(self.conv3(y))) 130 | # print('y5: ', y.size()) 131 | y = self.avgpool3(y) 132 | # print('y6: ', y.size()) 133 | y = self.linear1(self.relu(y)) 134 | y = torch.squeeze(y) 135 | y = self.linear2(self.relu(y)) 136 | return y 137 | 138 | class Conv_regression_selfattention(nn.Module): 139 | def __init__(self,channel): #(64,7999) 140 | super(Conv_regression_selfattention, self).__init__() 141 | self.conv5 = nn.Conv2d(channel, 256, kernel_size=(4, 3), stride=(1, 1), padding=(0, 1), bias=False) 142 | self.bn5 = nn.BatchNorm2d(256) 143 | 144 | self.activation = nn.ReLU() 145 | 146 | self.attention = SelfAttention(256) 147 | 148 | self.fc = nn.Linear(256 * 2, 128) 149 | self.fc_mu = nn.Linear(128, 1) 150 | 151 | 152 | def forward(self, x): 153 | # print('x1: ', x.size()) 154 | x = self.conv5(x) 155 | # print('x2: ', x.size()) 156 | x = self.activation(self.bn5(x)).squeeze(2) 157 | # print('x3: ', x.size()) 158 | 159 | stats = self.attention(x.permute(0, 2, 1).contiguous()) 160 | # print('stats: ', stats.size()) 161 | 162 | feat = self.fc(stats) 163 | # print('x4: ', feat.size()) 164 | 165 | mu = self.fc_mu(feat) 166 | # print('x5: ', mu.size()) 167 | return mu 168 | 169 | 170 | class PreActBlock(nn.Module): 171 | '''Pre-activation version of the BasicBlock.''' 172 | expansion = 1 173 | 174 | def __init__(self, in_planes, planes, stride, *args, **kwargs): 175 | super(PreActBlock, self).__init__() 176 | self.bn1 = nn.BatchNorm2d(in_planes) 177 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 178 | self.bn2 = nn.BatchNorm2d(planes) 179 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 180 | 181 | if stride != 1 or in_planes != self.expansion*planes: 182 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 183 | 184 | def forward(self, x): 185 | out = F.relu(self.bn1(x)) 186 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 187 | out = self.conv1(out) 188 | out = self.conv2(F.relu(self.bn2(out))) 189 | out += shortcut 190 | return out 191 | 192 | 193 | class PreActBottleneck(nn.Module): 194 | '''Pre-activation version of the original Bottleneck module.''' 195 | expansion = 4 196 | 197 | def __init__(self, in_planes, planes, stride, *args, 
**kwargs): 198 | super(PreActBottleneck, self).__init__() 199 | self.bn1 = nn.BatchNorm2d(in_planes) 200 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 201 | self.bn2 = nn.BatchNorm2d(planes) 202 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 203 | self.bn3 = nn.BatchNorm2d(planes) 204 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 205 | 206 | if stride != 1 or in_planes != self.expansion*planes: 207 | self.shortcut = nn.Sequential(nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)) 208 | 209 | def forward(self, x): 210 | out = F.relu(self.bn1(x)) 211 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 212 | out = self.conv1(out) 213 | out = self.conv2(F.relu(self.bn2(out))) 214 | out = self.conv3(F.relu(self.bn3(out))) 215 | out += shortcut 216 | return out 217 | 218 | def conv1x1(in_planes, out_planes, stride=1): 219 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 220 | 221 | # RESNET_CONFIGS = {'18': [[2, 2, 2, 2], PreActBlock], 222 | # '28': [[3, 4, 6, 3], PreActBlock], 223 | # '34': [[3, 4, 6, 3], PreActBlock], 224 | # '50': [[3, 4, 6, 3], PreActBottleneck], 225 | # '101': [[3, 4, 23, 3], PreActBottleneck] 226 | # } 227 | RESNET_CONFIGS = {'18': [[2, 2, 2, 2], SEBasicBlock]} 228 | 229 | class DCCRN_avg(nn.Module): 230 | 231 | def __init__( 232 | self, 233 | rnn_layers=2, 234 | rnn_units=128, 235 | win_len=400, 236 | win_inc=100, 237 | fft_len=512, 238 | win_type='hanning', 239 | use_clstm=False, 240 | use_cbn = False, 241 | kernel_size=5, 242 | kernel_num=[16,32,64,128,256,256], 243 | resnet_type='18' 244 | ): 245 | ''' 246 | 247 | rnn_layers: the number of lstm layers in the crn, 248 | rnn_units: for clstm, rnn_units = real+imag 249 | 250 | ''' 251 | 252 | super(DCCRN_avg, self).__init__() 253 | 254 | # for fft 255 | self.win_len = win_len 256 | self.win_inc = win_inc 257 | self.fft_len = fft_len 258 | self.win_type = win_type 259 | 260 | input_dim = win_len 261 | 262 | self.rnn_units = rnn_units 263 | self.input_dim = input_dim 264 | self.hidden_layers = rnn_layers 265 | self.kernel_size = kernel_size 266 | #self.kernel_num = [2, 8, 16, 32, 128, 128, 128] 267 | #self.kernel_num = [2, 16, 32, 64, 128, 256, 256] 268 | self.kernel_num = [2]+kernel_num 269 | self.use_clstm = use_clstm 270 | 271 | #bidirectional=True 272 | bidirectional=False 273 | fac = 2 if bidirectional else 1 274 | 275 | 276 | fix=True 277 | self.fix = fix 278 | self.stft = ConvSTFT(self.win_len, self.win_inc, fft_len, self.win_type, 'complex', fix=fix) 279 | 280 | # resnet 281 | self.in_planes = 16 282 | enc_dim = 256 283 | layers, block = RESNET_CONFIGS[resnet_type] 284 | self._norm_layer = nn.BatchNorm2d 285 | 286 | self.conv1 = nn.Conv2d(2, 16, kernel_size=(9, 3), stride=(3, 1), padding=(1, 1), bias=False) 287 | self.bn1 = nn.BatchNorm2d(16) 288 | self.activation = nn.ReLU() 289 | 290 | self.layer1 = self._make_layer(block, 64, layers[0], stride=1) 291 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 292 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 293 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 294 | 295 | self.conv5 = nn.Conv2d(512 * block.expansion, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 296 | bias=False) 297 | self.bn5 = nn.BatchNorm2d(256) 298 | self.conv6 = nn.Conv2d(256, 256, kernel_size=(6, 3), stride=(1, 1), padding=(0, 1), 299 | 
bias=False) 300 | self.bn6 = nn.BatchNorm2d(256) 301 | self.fc = nn.Linear(256 * 2, enc_dim) 302 | self.fc_mu = nn.Linear(enc_dim, 1) 303 | 304 | self.initialize_params() 305 | self.attention = SelfAttention(256) 306 | 307 | def initialize_params(self): 308 | for layer in self.modules(): 309 | if isinstance(layer, torch.nn.Conv2d): 310 | init.kaiming_normal_(layer.weight, a=0, mode='fan_out') 311 | elif isinstance(layer, torch.nn.Linear): 312 | init.kaiming_uniform_(layer.weight) 313 | elif isinstance(layer, torch.nn.BatchNorm2d) or isinstance(layer, torch.nn.BatchNorm1d): 314 | layer.weight.data.fill_(1) 315 | layer.bias.data.zero_() 316 | 317 | def _make_layer(self, block, planes, num_blocks, stride=1): 318 | norm_layer = self._norm_layer 319 | downsample = None 320 | if stride != 1 or self.in_planes != planes * block.expansion: 321 | downsample = nn.Sequential(conv1x1(self.in_planes, planes * block.expansion, stride), 322 | norm_layer(planes * block.expansion)) 323 | layers = [] 324 | layers.append(block(self.in_planes, planes, stride, downsample, 1, 64, 1, norm_layer)) 325 | self.in_planes = planes * block.expansion 326 | for _ in range(1, num_blocks): 327 | layers.append( 328 | block(self.in_planes, planes, 1, groups=1, base_width=64, dilation=False, norm_layer=norm_layer)) 329 | 330 | return nn.Sequential(*layers) 331 | 332 | 333 | def forward(self, inputs, weighted): 334 | # print('input: ', inputs.size()) 335 | specs = self.stft(inputs) 336 | real = specs[:,:self.fft_len//2+1] 337 | imag = specs[:,self.fft_len//2+1:] 338 | spec_mags = torch.sqrt(real**2+imag**2+1e-8) 339 | spec_mags = spec_mags 340 | spec_phase = torch.atan2(imag, real) 341 | spec_phase = spec_phase 342 | cspecs = torch.stack([real,imag],1) 343 | cspecs = cspecs[:,:,1:] 344 | # print('cspecs: ', cspecs.size()) 345 | ''' 346 | means = torch.mean(cspecs, [1,2,3], keepdim=True) 347 | std = torch.std(cspecs, [1,2,3], keepdim=True ) 348 | normed_cspecs = (cspecs-means)/(std+1e-8) 349 | out = normed_cspecs 350 | ''' 351 | 352 | # print('cspecs: ', cspecs.size()) 353 | x = self.conv1(cspecs) 354 | # print('x2: ', x.size()) 355 | x = self.activation(self.bn1(x)) 356 | x = self.layer1(x) 357 | # print('x3: ', x.size()) 358 | x = self.layer2(x) 359 | # print('x4: ', x.size()) 360 | x = self.layer3(x) 361 | # print('x5: ', x.size()) 362 | x = self.layer4(x) 363 | # print('layer4: ', x.size()) 364 | x = self.bn5(self.conv5(x)) 365 | # print('conv5: ', x.size()) 366 | x = self.bn6(self.conv6(x)) 367 | x = self.activation(x).squeeze(2) 368 | # print('x8: ', x.size()) 369 | 370 | stats = self.attention(x.permute(0, 2, 1).contiguous()) 371 | # print('stats: ', stats.size()) 372 | 373 | noise = 1e-5*torch.randn(weighted.size()) 374 | if inputs.is_cuda: 375 | noise = noise.to(inputs.device) 376 | avg_repr, std_repr = weighted.sum(1), (weighted+noise).std(1) 377 | stats = torch.cat((avg_repr,std_repr),1) 378 | 379 | feat = self.fc(stats) 380 | 381 | mu = self.fc_mu(feat) 382 | 383 | 384 | # y=self.Conv_regression_selfattention(out) 385 | return mu 386 | 387 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import time 6 | import math 7 | import logging 8 | from dataload import * 9 | 10 | logging.basicConfig(level=logging.DEBUG) 11 | logger=logging.getLogger(__name__) 12 | 13 | def asMinutes(s): 14 | m = 
math.floor(s / 60) 15 | s -= m * 60 16 | return '%dm %ds' % (m, s) 17 | def timeSince(since): 18 | now = time.time() 19 | s = now - since 20 | return '%s' % (asMinutes(s)) 21 | 22 | # training loop 23 | def train(data_root, net,epoch_num,trainloader,batch_size,valloader,device=None,save_path=None,info_num=200,step_size=2, flag=-3): 24 | net.train() 25 | train_root = os.path.join(data_root, 'train') 26 | model_save_path = os.path.join(save_path, 'checkpoint') 27 | if not os.path.exists(model_save_path): 28 | os.mkdir(model_save_path) 29 | if trainloader == "3s": 30 | train_data = train_data_loader(train_root, batch_size=batch_size, shuffle=True,flag=flag) 31 | best_loss=1000 32 | criterion = nn.MSELoss() # loss function 33 | optimizer = optim.Adam(net.parameters(), lr=0.002,weight_decay=0.01) # optimizer 34 | # optimizer = optim.Adam(net.parameters(), lr=0.0003,weight_decay=0.01) # optimizer 35 | #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1, last_epoch=-1) 36 | start = time.time() 37 | for epoch in range(epoch_num): 38 | net.train() 39 | if trainloader=="5s": 40 | train_data = train_data_loader('../audio_5s/train/', batch_size=batch_size, shuffle=False,flag=3) 41 | running_loss = 0.0 42 | for i, data in enumerate(train_data, 0): 43 | inputs, labels =data 44 | # print('dataload input: ', inputs.size()) 45 | inputs,labels=inputs.to(device),labels.to(device) 46 | optimizer.zero_grad() 47 | # print('dataload input: ', inputs.size()) 48 | # print('dataload input.float(): ', inputs.float().size()) 49 | outputs = net(inputs.float()) 50 | outputs=outputs.squeeze() 51 | loss = criterion(outputs, labels.float()) 52 | loss.backward() 53 | optimizer.step() 54 | running_loss += loss.item() 55 | if i % info_num == info_num-1: 56 | loss_mean=(running_loss/info_num)**0.5 # RMSE over the last info_num batches 57 | logger.info('[%d, %5d] %s loss: %.3f' %(epoch + 1, i + 1,timeSince(start), loss_mean)) 58 | running_loss = 0.0 59 | #scheduler.step() 60 | torch.save(net.state_dict(), model_save_path + "/model_epoch_" + str(epoch + 1) + ".pkl") 61 | val_loss = validate(valloader, net,device) 62 | if val_loss < best_loss: 63 | best_loss = val_loss 64 | torch.save(net.state_dict(), model_save_path + "/best_model_epoch_" + str(epoch + 1) + ".pkl") 65 | 66 | logger.info('Best loss: %5f'%best_loss) 67 | logger.info('Finished Training') 68 | 69 | 70 | def validate(val_loader, model,device): 71 | # switch the model to evaluation mode 72 | model.eval() 73 | batch_loss=0 74 | criterion = nn.MSELoss() 75 | with torch.no_grad(): 76 | for i, data in enumerate(val_loader): 77 | images, labels = data 78 | images = images.to(device) 79 | labels = labels.to(device) 80 | outputs = model(images) 81 | outputs = outputs.squeeze(-1) 82 | outputs = outputs.to(device) 83 | loss = criterion(outputs, labels.float()) 84 | batch_loss += loss.item() 85 | batch_num = i + 1 86 | rmse = (batch_loss / batch_num) ** 0.5 # root of the mean per-batch MSE, i.e. an RMSE estimate (was misleadingly named mse) 87 | logger.info("eval rmse is: %5f" % rmse) 88 | return rmse --------------------------------------------------------------------------------
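For orientation, a minimal end-to-end sketch of how the pieces above fit together (hypothetical: main.py is not shown in this dump, so the data root, epoch count, and batch size are illustrative assumptions; DCCRN and test_data_loader are called with the signatures used in test_attention_visual.py, and train() with the signature defined in train.py):

    import os
    import torch
    from dc_crn import DCCRN
    from dataload import *  # assumed to expose test_data_loader as dataload_vad does
    from train import train

    data_root = '/path/to/AVEC2013_3s'  # placeholder; test_attention_visual.py shows a real path
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = DCCRN(rnn_units=256, use_clstm=True,
                kernel_num=[32, 64, 128, 256, 256, 256]).to(device)
    # validation reuses the test-style loader; train() builds the training loader
    # itself from data_root when trainloader == "3s"
    val_data = test_data_loader(os.path.join(data_root, 'dev'), batch_size=1, shuffle=False)
    train(data_root, net, epoch_num=100, trainloader="3s", batch_size=32,
          valloader=val_data, device=device, save_path='exp_0.002')

The save_path above matches the exp_0.002/checkpoint/ directory referenced by test_attention_visual.py; the 'dev' subdirectory is an assumption about the data layout.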