├── Abbreviation.png ├── DEP-Former ├── model │ ├── attention.py │ ├── depformer.py │ ├── embed.py │ └── encoder.py ├── my_dataset.py ├── train.py └── utils.py ├── P2.pdf ├── P2_00.png └── README.md /Abbreviation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QLUTEmoTechCrew/DEP-Former/a882bb3a4d2adf884f49c3310773368469f58ebe/Abbreviation.png -------------------------------------------------------------------------------- /DEP-Former/model/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from scipy.stats import spearmanr 5 | import torch.nn.functional as F 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | class CrossAttention(nn.Module): 10 | def __init__(self, d_k, d_v, d_model, n_heads, dropout, mix=False): 11 | super(CrossAttention, self).__init__() 12 | self.d_k = d_k 13 | self.d_v = d_v 14 | self.d_model = d_model 15 | self.n_heads = n_heads 16 | self.mix = mix 17 | 18 | self.W_Q = nn.Linear(self.d_model, self.d_k * self.n_heads, bias=False) 19 | self.W_K = nn.Linear(self.d_model, self.d_k * self.n_heads, bias=False) 20 | self.W_V = nn.Linear(self.d_model, self.d_v * self.n_heads, bias=False) 21 | self.fc = nn.Linear(self.n_heads * self.d_v, self.d_model, bias=False) 22 | 23 | self.norm = nn.LayerNorm(self.d_model) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | def forward(self, input_Face, input_Voice, input_Adapter, attn_mask): 27 | residual, batch_size = input_Face.clone(), input_Face.size(0) 28 | Q = ( 29 | self.W_Q(input_Face) 30 | .view(batch_size, -1, self.n_heads, self.d_k) 31 | .transpose(1, 2) 32 | ) 33 | K = ( 34 | self.W_K(input_Voice) 35 | .view(batch_size, -1, self.n_heads, self.d_k) 36 | .transpose(1, 2) 37 | ) 38 | V = ( 39 | self.W_V(input_Adapter) 40 | .view(batch_size, -1, self.n_heads, self.d_v) 41 | .transpose(1, 2) 42 | ) 43 | 44 | scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt( 45 | self.d_k 46 | ) 47 | 48 | if attn_mask is not None: 49 | attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 50 | scores.masked_fill_( 51 | attn_mask.type(torch.bool), 52 | torch.from_numpy(np.array(-np.inf)).to(device), 53 | ) 54 | 55 | attn = nn.Softmax(dim=-1)(scores) 56 | context = torch.matmul(attn, V) 57 | 58 | context = context.transpose(1, 2) 59 | if self.mix: 60 | context = context.transpose(1, 2) 61 | context = context.reshape(batch_size, -1, self.n_heads * self.d_v) 62 | output = self.dropout(self.fc(context)) 63 | return self.norm(output + residual) 64 | 65 | 66 | class ProbAttention(nn.Module): 67 | def __init__(self, d_k, d_v, d_model, n_heads, c, dropout, index, mix=False): 68 | super(ProbAttention, self).__init__() 69 | self.d_k = d_k 70 | self.d_v = d_v 71 | self.d_model = d_model 72 | self.n_heads = n_heads 73 | self.c = c 74 | self.mix = mix 75 | 76 | self.W_Q = nn.Linear(self.d_model, self.d_k * self.n_heads, bias=False) 77 | self.W_K = nn.Linear(self.d_model, self.d_k * self.n_heads, bias=False) 78 | self.W_V = nn.Linear(self.d_model, self.d_v * self.n_heads, bias=False) 79 | self.fc = nn.Linear(self.n_heads * self.d_v, self.d_model, bias=False) 80 | 81 | self.norm = nn.LayerNorm(self.d_model) 82 | self.dropout = nn.Dropout(dropout) 83 | self.ind = index 84 | 85 | self.alpha = nn.Sequential(nn.Linear(3072, 1), nn.Sigmoid()) 86 | 87 | def _prob_QK(self, Q, K, sample_k, n_top, ind): 88 | B, H, L_K, E = K.shape 89 | _, _, 
L_Q, _ = Q.shape 90 | 91 | K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E).clone() 92 | index_sample = torch.randint( 93 | 0, L_K, (L_Q, sample_k) 94 | ) 95 | K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :] 96 | 97 | Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze() 98 | M = torch.max(Q_K_sample, dim=-1).values - torch.mean(Q_K_sample, dim=-1) 99 | M_top_index = torch.topk(M, n_top, dim=-1).indices 100 | 101 | M_top_index_out = 0 102 | 103 | if self.ind == 1 and ind is not None: 104 | 105 | Q_a = Q.reshape(Q.size(0), -1) 106 | attention_weights = self.alpha(Q_a) 107 | _, top_idx = torch.topk(attention_weights.squeeze(), int(0.5 * B)) 108 | _, low_idx = torch.topk(attention_weights.squeeze(), int(0.5 * B), largest=False) 109 | index = torch.cat([top_idx, low_idx], dim=0) 110 | M_top_index_top = M_top_index[top_idx, :, :] 111 | ind_low = ind[low_idx, :, :] 112 | 113 | Q_sample_low = Q[ 114 | torch.arange(B)[top_idx, None, None], 115 | torch.arange(H)[None, :, None], 116 | ind_low, 117 | :, 118 | ] 119 | Q_sample_top = Q[ 120 | torch.arange(B)[low_idx, None, None], 121 | torch.arange(H)[None, :, None], 122 | M_top_index_top, 123 | :, 124 | ] 125 | Q_sample = torch.cat([Q_sample_top, Q_sample_low], dim=0) 126 | Q_sample = Q_sample[index] 127 | 128 | else: 129 | Q_sample = Q[ 130 | torch.arange(B)[:, None, None], 131 | torch.arange(H)[None, :, None], 132 | M_top_index, 133 | :, 134 | ] 135 | 136 | return Q_sample, M_top_index, M_top_index_out 137 | 138 | def _get_initial_context(self, V, L_Q, attn_mask): 139 | B, H, L_V, D = V.shape 140 | if attn_mask is None: 141 | V_sum = V.mean(dim=-2) 142 | contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, D).clone() 143 | else: 144 | contex = V.cumsum(dim=-2) # 累积和 145 | return contex 146 | 147 | def forward(self, input_Q, input_K, input_V, attn_mask, ind=None): 148 | residual, batch_size = input_Q.clone(), input_Q.size(0) 149 | L_K, L_Q = input_K.size(1), input_Q.size(1) 150 | 151 | u_k = min(int(self.c * np.log(L_K)), L_Q) 152 | u_q = min(int(self.c * np.log(L_Q)), L_Q) 153 | 154 | Q = ( 155 | self.W_Q(input_Q) 156 | .view(batch_size, -1, self.n_heads, self.d_k) 157 | .transpose(1, 2) 158 | ) 159 | K = ( 160 | self.W_K(input_K) 161 | .view(batch_size, -1, self.n_heads, self.d_k) 162 | .transpose(1, 2) 163 | ) 164 | V = ( 165 | self.W_V(input_V) 166 | .view(batch_size, -1, self.n_heads, self.d_v) 167 | .transpose(1, 2) 168 | ) 169 | 170 | Q_sample, index, M_top_index_out = self._prob_QK(Q, K, sample_k=u_k, n_top=u_q, ind=ind) 171 | scores = torch.matmul(Q_sample, K.transpose(-1, -2)) / np.sqrt(self.d_k) 172 | 173 | if attn_mask is not None: 174 | attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1) 175 | 176 | attn_mask = attn_mask[ 177 | torch.arange(batch_size)[:, None, None], 178 | torch.arange(self.n_heads)[None, :, None], 179 | index, 180 | :, 181 | ] 182 | scores.masked_fill_( 183 | attn_mask.type(torch.bool), 184 | torch.from_numpy(np.array(-np.inf)).to(device), 185 | ) 186 | 187 | attn = nn.Softmax(dim=-1)(scores) 188 | values = torch.matmul(attn, V) 189 | 190 | context = self._get_initial_context(V, L_Q, attn_mask) 191 | context[ 192 | torch.arange(batch_size)[:, None, None], 193 | torch.arange(self.n_heads)[None, :, None], 194 | index, 195 | :, 196 | ] = values 197 | 198 | context = context.transpose(1, 2) 199 | if self.mix: 200 | context = context.transpose(1, 2) 201 | context = context.reshape(batch_size, -1, self.n_heads * self.d_v) 202 | 203 | output = 
self.dropout(self.fc(context)) 204 | return self.norm(output + residual), index, M_top_index_out 205 | 206 | 207 | -------------------------------------------------------------------------------- /DEP-Former/model/depformer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | from model.encoder import Encoder 4 | from model.attention import CrossAttention 5 | 6 | 7 | class Depformer(nn.Module): 8 | def __init__( 9 | self, 10 | d_k=64, 11 | d_v=64, 12 | d_model=224, 13 | d_ff=32, 14 | n_heads=8, 15 | e_layer=3, 16 | d_layer=2, 17 | e_stack=3, 18 | d_feature=256, 19 | d_mark=224, 20 | dropout=0.1, 21 | c=5, 22 | ): 23 | super(Depformer, self).__init__() 24 | 25 | self.encoder = Encoder( 26 | d_k=d_k, 27 | d_v=d_v, 28 | d_model=d_model, 29 | d_ff=d_ff, 30 | n_heads=n_heads, 31 | n_layer=e_layer, 32 | n_stack=e_stack, 33 | d_feature=d_feature, 34 | d_mark=d_mark, 35 | dropout=dropout, 36 | c=c, 37 | index=0, 38 | ) 39 | self.encoder_fv = Encoder( 40 | d_k=d_k, 41 | d_v=d_v, 42 | d_model=d_model, 43 | d_ff=d_ff, 44 | n_heads=n_heads, 45 | n_layer=e_layer, 46 | n_stack=e_stack, 47 | d_feature=128, 48 | d_mark=d_mark, 49 | dropout=dropout, 50 | c=c, 51 | index=1, 52 | ) 53 | self.cross = CrossAttention(d_k=d_k, 54 | d_v=d_v, 55 | d_model=d_model, 56 | n_heads=n_heads, 57 | dropout=dropout, ) 58 | 59 | self.projection = nn.Linear(d_model, d_feature, bias=True) 60 | self.fc = nn.Linear(4032, 2) 61 | 62 | self.LN = nn.LayerNorm(1000) 63 | self.silu = nn.SiLU() 64 | self.fc1 = nn.Linear(224 * 224, 1000) 65 | self.fc2 = nn.Linear(1000, 1000) 66 | self.fc3 = nn.Linear(1000, 128) 67 | self.fc4 = nn.Linear(4032, 2) 68 | 69 | def Adapter(self, face_data, voice_data): 70 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 71 | face_data = face_data.view(face_data.shape[0], face_data.shape[1], -1).to(device) 72 | face_data = self.fc1(face_data).to(device) 73 | face_data = self.LN(face_data).to(device) 74 | face_data = self.fc3(face_data).to(device) 75 | face_data = self.silu(face_data).to(device) 76 | 77 | voice_data = voice_data.to(device) 78 | voice_data = self.fc2(voice_data).to(device) 79 | voice_data = self.LN(voice_data).to(device) 80 | voice_data = self.fc3(voice_data).to(device) 81 | voice_data = self.silu(voice_data).to(device) 82 | return face_data, voice_data 83 | 84 | def forward(self, face, voice, label): 85 | face = face.permute(0, 3, 1, 2) 86 | face_data, voice_data = self.Adapter(face, voice) 87 | enc_x = torch.cat((face_data, voice_data), dim=2) 88 | 89 | enc_outputs, index, M_out = self.encoder(enc_x) 90 | enc_outputs_face, index_f, F_out = self.encoder_fv(face_data, index) 91 | enc_outputs_voice, index_v, V_out = self.encoder_fv(voice_data, index) 92 | 93 | # torch.set_printoptions(threshold=float('inf')) 94 | # print("Ma_out", M_out) 95 | # print("Face_out", F_out) 96 | # print("Voice_out", V_out) 97 | # torch.set_printoptions(threshold=1000) 98 | 99 | # if index is not None: 100 | # for i in range(len(M_out)): 101 | # if label[i] == 1: 102 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_dep_m.txt" 103 | # s = M_out[i, :, :10] 104 | # with open(output_file_path, 'a') as f: 105 | # f.write('mul') 106 | # torch.set_printoptions(threshold=float('inf')) 107 | # f.write(str(s)) 108 | # torch.set_printoptions(threshold=1000) 109 | # if label[i] == 0: 110 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_nor_m.txt" 111 | # s = M_out[i, :, :10] 112 | # 
with open(output_file_path, 'a') as f: 113 | # f.write('mul') 114 | # torch.set_printoptions(threshold=float('inf')) 115 | # f.write(str(s)) 116 | # torch.set_printoptions(threshold=1000) 117 | # 118 | # if index_f is not None: 119 | # for i in range(len(F_out)): 120 | # if label[i] == 1: 121 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_dep_f.txt" 122 | # s = F_out[i, :, :10] 123 | # with open(output_file_path, 'a') as f: 124 | # f.write('f') 125 | # torch.set_printoptions(threshold=float('inf')) 126 | # f.write(str(s)) 127 | # torch.set_printoptions(threshold=1000) 128 | # if label[i] == 0: 129 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_nor_f.txt" 130 | # s = F_out[i, :, :10] 131 | # with open(output_file_path, 'a') as f: 132 | # f.write('f') 133 | # torch.set_printoptions(threshold=float('inf')) 134 | # f.write(str(s)) 135 | # torch.set_printoptions(threshold=1000) 136 | # 137 | # if index_v is not None: 138 | # for i in range(len(V_out)): 139 | # if label[i] == 1: 140 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_dep_v.txt" 141 | # s = V_out[i, :, :10] 142 | # with open(output_file_path, 'a') as f: 143 | # f.write('v') 144 | # torch.set_printoptions(threshold=float('inf')) 145 | # f.write(str(s)) 146 | # torch.set_printoptions(threshold=1000) 147 | # if label[i] == 0: 148 | # output_file_path = "/mnt/public/home/wangqx/Yejiayu/depmul2/output_nor_v.txt" 149 | # s = V_out[i, :, :10] 150 | # with open(output_file_path, 'a') as f: 151 | # f.write('v') 152 | # torch.set_printoptions(threshold=float('inf')) 153 | # f.write(str(s)) 154 | # torch.set_printoptions(threshold=1000) 155 | 156 | cross_out = self.cross(enc_outputs_face, enc_outputs_voice, enc_outputs, None) 157 | 158 | enc_outputs = self.fc(enc_outputs.view(enc_outputs.shape[0], -1)) 159 | enc_outputs_face = self.fc4(enc_outputs_face.view(enc_outputs_face.shape[0], -1)) 160 | enc_outputs_voice = self.fc4(enc_outputs_voice.view(enc_outputs_voice.shape[0], -1)) 161 | cross_out = self.fc4(cross_out.view(cross_out.shape[0], -1)) 162 | outputs = enc_outputs + 0.1 * cross_out + 0.1 * enc_outputs_face + 0.1 * enc_outputs_voice 163 | 164 | return outputs 165 | 166 | 167 | -------------------------------------------------------------------------------- /DEP-Former/model/embed.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class PositionalEmbedding(nn.Module): 8 | def __init__(self, d_model, max_len=5000): 9 | super(PositionalEmbedding, self).__init__() 10 | pe = torch.zeros(max_len, d_model) 11 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 12 | div_term = torch.exp( 13 | torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) 14 | ) 15 | pe[:, 0::2] = torch.sin(position * div_term) 16 | pe[:, 1::2] = torch.cos(position * div_term) 17 | self.register_buffer("pe", pe) 18 | 19 | def forward(self, x): 20 | x = self.pe[: x.size(1), :] 21 | return x 22 | 23 | 24 | class TimeFeatureEmbedding(nn.Module): 25 | def __init__(self, d_mark, d_model): 26 | super(TimeFeatureEmbedding, self).__init__() 27 | self.embed = nn.Linear(d_mark, d_model) 28 | 29 | def forward(self, x): 30 | return self.embed(x) 31 | 32 | 33 | class TokenEmbedding(nn.Module): 34 | def __init__(self, d_feature, d_model): 35 | super(TokenEmbedding, self).__init__() 36 | self.Conv = nn.Conv1d( 37 | in_channels=d_feature, 38 | out_channels=d_model, 39 | 
kernel_size=(3,), 40 | padding=(1,), 41 | stride=(1,), 42 | padding_mode="circular", 43 | ) 44 | 45 | nn.init.kaiming_normal_( 46 | self.Conv.weight, mode="fan_in", nonlinearity="leaky_relu" 47 | ) 48 | 49 | def forward(self, x): 50 | x = self.Conv(x.permute(0, 2, 1)).transpose(1, 2) 51 | return x 52 | 53 | 54 | class DataEmbedding(nn.Module): 55 | def __init__(self, d_feature, d_mark, d_model, dropout=0.1): 56 | super(DataEmbedding, self).__init__() 57 | 58 | self.value_embedding = TokenEmbedding(d_feature=d_feature, d_model=d_model) 59 | self.position_embedding = PositionalEmbedding(d_model=d_model) 60 | self.time_embedding = TimeFeatureEmbedding(d_mark=d_mark, d_model=d_model) 61 | 62 | self.dropout = nn.Dropout(p=dropout) 63 | 64 | def forward(self, x): 65 | x = ( 66 | self.value_embedding(x) 67 | + self.position_embedding(x) 68 | ) 69 | return self.dropout(x) 70 | -------------------------------------------------------------------------------- /DEP-Former/model/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from model.attention import ProbAttention 6 | 7 | from model.embed import DataEmbedding 8 | 9 | 10 | class ConvLayer(nn.Module): 11 | def __init__(self, c_in): 12 | super(ConvLayer, self).__init__() 13 | self.downConv = nn.Conv1d( 14 | in_channels=c_in, 15 | out_channels=c_in, 16 | kernel_size=(3,), 17 | padding=(1,), 18 | stride=(1,), 19 | padding_mode="circular", 20 | ) 21 | self.norm = nn.BatchNorm1d(c_in) 22 | self.activation = nn.ELU() 23 | self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) 24 | 25 | def forward(self, x): 26 | if isinstance(x, tuple): 27 | x = x[0] 28 | else: 29 | x = x 30 | x = self.downConv(x.permute(0, 2, 1)) 31 | x = self.norm(x) 32 | x = self.activation(x) 33 | x = self.maxPool(x) 34 | x = x.transpose(1, 2) 35 | return x 36 | 37 | 38 | class EncoderLayer(nn.Module): 39 | def __init__(self, d_k, d_v, d_model, d_ff, n_heads, c, dropout, index): 40 | super(EncoderLayer, self).__init__() 41 | self.attention = ProbAttention(d_k, d_v, d_model, n_heads, c, dropout, index, mix=False) 42 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=(1,)) 43 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=(1,)) 44 | self.norm = nn.LayerNorm(d_model) 45 | self.dropout = nn.Dropout(dropout) 46 | self.activation = F.gelu 47 | 48 | def forward(self, x, ind=None, attn_mask=None): 49 | x, index, M_top_index_out = self.attention(x, x, x, attn_mask=attn_mask, ind=ind) 50 | residual = x.clone() 51 | x = self.dropout(self.activation(self.conv1(x.permute(0, 2, 1)))) 52 | y = self.dropout(self.conv2(x).permute(0, 2, 1)) 53 | return self.norm(residual + y), index, M_top_index_out 54 | 55 | 56 | class EncoderLayer_fv(nn.Module): 57 | def __init__(self, d_k, d_v, d_model, d_ff, n_heads, c, dropout, index): 58 | super(EncoderLayer_fv, self).__init__() 59 | self.attention = ProbAttention(d_k, d_v, d_model, n_heads, c, dropout, index, mix=False) 60 | self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=(1,)) 61 | self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=(1,)) 62 | self.norm = nn.LayerNorm(d_model) 63 | self.dropout = nn.Dropout(dropout) 64 | self.activation = F.gelu 65 | 66 | def forward(self, x, ind=None, attn_mask=None): 67 | x, index, M_top_index_out = self.attention(x, x, x, attn_mask=attn_mask, ind=ind) 68 | residual = x.clone() 69 | x = 
self.dropout(self.activation(self.conv1(x.permute(0, 2, 1)))) 70 | y = self.dropout(self.conv2(x).permute(0, 2, 1)) 71 | return self.norm(residual + y) 72 | 73 | 74 | class Encoder(nn.Module): 75 | def __init__( 76 | self, 77 | d_k, 78 | d_v, 79 | d_model, 80 | d_ff, 81 | n_heads, 82 | n_layer, 83 | n_stack, 84 | d_feature, 85 | d_mark, 86 | dropout, 87 | c, 88 | index, 89 | ): 90 | super(Encoder, self).__init__() 91 | 92 | self.embedding = DataEmbedding(d_feature, d_mark, d_model, dropout) 93 | 94 | self.stacks = nn.ModuleList() 95 | for i in range(n_stack): 96 | stack = nn.Sequential() 97 | stack.add_module( 98 | "elayer" + str(i) + "0", 99 | EncoderLayer(d_k, d_v, d_model, d_ff, n_heads, c, dropout, index), 100 | ) 101 | 102 | for j in range(n_layer - i - 1): 103 | stack.add_module("clayer" + str(i) + str(j + 1), ConvLayer(d_model)) 104 | stack.add_module( 105 | "elayer" + str(i) + str(j + 1), 106 | EncoderLayer(d_k, d_v, d_model, d_ff, n_heads, c, dropout, index), 107 | ) 108 | 109 | self.stacks.append(stack) 110 | self.norm = nn.LayerNorm(d_model) 111 | 112 | self.index = index 113 | self.en_fv = EncoderLayer_fv(d_k, d_v, d_model, d_ff, n_heads, c, dropout, index) 114 | 115 | def forward(self, enc_x, ind=None): 116 | x = self.embedding(enc_x) 117 | out = [] 118 | for i, stack in enumerate(self.stacks): 119 | inp_len = x.shape[1] // (2 ** i) 120 | y = x[:, -inp_len:, :] 121 | y, index, M_top_index_out = stack(y) 122 | y = self.norm(y) 123 | 124 | if self.index == 1: 125 | y1 = self.en_fv(y, ind) 126 | y1 = self.norm(y1) 127 | y = y + y1 128 | out.append(y) 129 | out = torch.cat(out, -2) 130 | 131 | return out, index, M_top_index_out 132 | -------------------------------------------------------------------------------- /DEP-Former/my_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from torch.utils.data import Dataset 4 | import logging 5 | import torch 6 | from torch.utils import data 7 | from torchvision import transforms, datasets 8 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 9 | import os 10 | from PIL import Image 11 | import numpy as np 12 | import albumentations as alb 13 | import random 14 | from torch.utils.data.dataloader import default_collate 15 | 16 | 17 | class MyDataSet(Dataset): 18 | 19 | def __init__(self, root1, root2): 20 | face = os.listdir(root1) 21 | self.face = [os.path.join(root1, k) for k in face] 22 | voice = os.listdir(root2) 23 | self.voice = [os.path.join(root2, k) for k in voice] 24 | 25 | def __getitem__(self, index): 26 | face = self.face[index] 27 | voice = self.voice[index] 28 | label = face[-8] 29 | if label == 'D': 30 | label = 1 31 | else: 32 | label = 0 33 | 34 | face_data = torch.from_numpy(np.load(face, allow_pickle=True).astype(float)).float() 35 | voice_data = torch.from_numpy(np.load(voice, allow_pickle=True).astype(float)).float() 36 | 37 | return face_data[:, :, 0: 24], voice_data[0: 24, :], label 38 | 39 | def __len__(self): 40 | return len(self.face) 41 | 42 | def normalization(self, data): 43 | _range = np.max(data) - np.min(data) 44 | return (data - np.min(data)) / _range 45 | 46 | 47 | if __name__ == "__main__": 48 | path1 = r'D:\Codedemo\dep-mul\mul\data\face' 49 | path2 = r'D:\Codedemo\dep-mul\mul\data\voic' 50 | print(MyDataSet(path1, path2)) 51 | -------------------------------------------------------------------------------- /DEP-Former/train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import random 5 | 6 | import torch 7 | import torch.optim as optim 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | from my_dataset import MyDataSet 11 | from model.depformer import Depformer 12 | from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate 13 | 14 | import warnings 15 | 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | def setup_seed(seed): 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | torch.backends.cudnn.enabled = False 29 | 30 | 31 | def main(args): 32 | setup_seed(3407) 33 | device = torch.device(args.device if torch.cuda.is_available() else "cpu") 34 | print(f"using {device} device.") 35 | 36 | if not os.path.exists("./weights"): 37 | os.makedirs("./weights") 38 | 39 | tb_writer = SummaryWriter() 40 | 41 | train_images_path_fac = r"" 42 | val_images_path_fac = r"" 43 | 44 | train_images_path_au = r"" 45 | val_images_path_au = r"" 46 | 47 | batch_size = args.batch_size 48 | 49 | train_dataset = MyDataSet(train_images_path_fac, train_images_path_au) 50 | 51 | val_dataset = MyDataSet(val_images_path_fac, val_images_path_au) 52 | 53 | nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers 54 | print('Using {} dataloader workers per process'.format(nw)) 55 | train_loader = torch.utils.data.DataLoader(train_dataset, 56 | batch_size=batch_size, 57 | shuffle=True, 58 | pin_memory=True, 59 | drop_last=False) 60 | 61 | val_loader = torch.utils.data.DataLoader(val_dataset, 62 | batch_size=batch_size, 63 | shuffle=True, 64 | pin_memory=True, 65 | drop_last=False) 66 | 67 | model = Depformer().to(device) # multimodal DEP-Former model from model/depformer.py 68 | 69 | pg = get_params_groups(model, weight_decay=args.wd) 70 | optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd) 71 | lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs, 72 | warmup=True, warmup_epochs=1) 73 | 74 | best_acc = 0.
75 | for epoch in range(args.epochs): 76 | # train 77 | train_loss, train_acc = train_one_epoch(model=model, 78 | optimizer=optimizer, 79 | data_loader=train_loader, 80 | device=device, 81 | epoch=epoch, 82 | lr_scheduler=lr_scheduler) 83 | 84 | # validate 85 | val_loss, val_acc = evaluate(model=model, 86 | data_loader=val_loader, 87 | device=device, 88 | epoch=epoch) 89 | 90 | tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"] 91 | tb_writer.add_scalar(tags[0], train_loss, epoch) 92 | tb_writer.add_scalar(tags[1], train_acc, epoch) 93 | tb_writer.add_scalar(tags[2], val_loss, epoch) 94 | tb_writer.add_scalar(tags[3], val_acc, epoch) 95 | tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch) 96 | 97 | if best_acc < val_acc: 98 | torch.save(model.state_dict(), "./weights/best_model.pth") 99 | best_acc = val_acc 100 | 101 | 102 | if __name__ == '__main__': 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument('--num_classes', type=int, default=2) 105 | parser.add_argument('--epochs', type=int, default=300) 106 | parser.add_argument('--batch-size', type=int, default=32) 107 | parser.add_argument('--lr', type=float, default=0.001) 108 | parser.add_argument('--wd', type=float, default=5e-2) 109 | 110 | setup_seed(3407) 111 | 112 | parser.add_argument('--data-path', type=str, 113 | default="") 114 | 115 | parser.add_argument('--weights', type=str, 116 | default=r'', 117 | help='initial weights path') 118 | 119 | parser.add_argument('--freeze-layers', type=bool, default=False) 120 | parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)') 121 | 122 | opt = parser.parse_args() 123 | 124 | main(opt) 125 | -------------------------------------------------------------------------------- /DEP-Former/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import pickle 5 | import random 6 | import math 7 | import logging 8 | import torch 9 | from tqdm import tqdm 10 | import numpy as np 11 | 12 | import matplotlib.pyplot as plt 13 | from sklearn import metrics 14 | from sklearn.metrics import f1_score 15 | from sklearn.metrics import precision_score 16 | from sklearn.metrics import recall_score 17 | from sklearn import metrics 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def read_split_data(root: str, val_rate: float = 0.2): 23 | random.seed(3401) 24 | assert os.path.exists(root), "dataset root: {} does not exist.".format(root) 25 | 26 | flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))] 27 | flower_class.sort() 28 | class_indices = dict((k, v) for v, k in enumerate(flower_class)) 29 | json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4) 30 | with open('class_indices.json', 'w') as json_file: 31 | json_file.write(json_str) 32 | 33 | train_images_path = [] 34 | train_images_label = [] 35 | val_images_path = [] 36 | val_images_label = [] 37 | every_class_num = [] 38 | supported = [".jpg", ".JPG", ".png", ".PNG"] 39 | 40 | for cla in flower_class: 41 | cla_path = os.path.join(root, cla) 42 | 43 | images = [os.path.join(root, cla, i) for i in os.listdir(cla_path) 44 | if os.path.splitext(i)[-1] in supported] 45 | 46 | image_class = class_indices[cla] 47 | 48 | every_class_num.append(len(images)) 49 | 50 | val_path = random.sample(images, k=int(len(images) * val_rate)) 51 | 52 | for img_path in images: 53 | if img_path in val_path: 54 | 
val_images_path.append(img_path) 55 | val_images_label.append(image_class) 56 | else: 57 | train_images_path.append(img_path) 58 | train_images_label.append(image_class) 59 | 60 | print("{} images were found in the dataset.".format(sum(every_class_num))) 61 | print("{} images for training.".format(len(train_images_path))) 62 | print("{} images for validation.".format(len(val_images_path))) 63 | assert len(train_images_path) > 0, "not find data for train." 64 | assert len(val_images_path) > 0, "not find data for eval" 65 | 66 | plot_image = False 67 | if plot_image: 68 | 69 | plt.bar(range(len(flower_class)), every_class_num, align='center') 70 | 71 | plt.xticks(range(len(flower_class)), flower_class) 72 | for i, v in enumerate(every_class_num): 73 | plt.text(x=i, y=v + 5, s=str(v), ha='center') 74 | plt.xlabel('image class') 75 | plt.ylabel('number of images') 76 | plt.title('flower class distribution') 77 | plt.show() 78 | 79 | return train_images_path, train_images_label, val_images_path, val_images_label 80 | 81 | 82 | def plot_data_loader_image(data_loader): 83 | batch_size = data_loader.batch_size 84 | plot_num = min(batch_size, 4) 85 | 86 | json_path = './class_indices.json' 87 | assert os.path.exists(json_path), json_path + " does not exist." 88 | json_file = open(json_path, 'r') 89 | class_indices = json.load(json_file) 90 | 91 | for data in data_loader: 92 | images, labels = data 93 | for i in range(plot_num): 94 | img = images[i].numpy().transpose(1, 2, 0) 95 | img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255 96 | label = labels[i].item() 97 | plt.subplot(1, plot_num, i + 1) 98 | plt.xlabel(class_indices[str(label)]) 99 | plt.xticks([]) 100 | plt.yticks([]) 101 | plt.imshow(img.astype('uint8')) 102 | plt.show() 103 | 104 | 105 | def write_pickle(list_info: list, file_name: str): 106 | with open(file_name, 'wb') as f: 107 | pickle.dump(list_info, f) 108 | 109 | 110 | def read_pickle(file_name: str) -> list: 111 | with open(file_name, 'rb') as f: 112 | info_list = pickle.load(f) 113 | return info_list 114 | 115 | 116 | def train_one_epoch(model, optimizer, data_loader, device, epoch, lr_scheduler): 117 | model.train() 118 | loss_function = torch.nn.CrossEntropyLoss() 119 | accu_loss = torch.zeros(1).to(device) 120 | accu_num = torch.zeros(1).to(device) 121 | optimizer.zero_grad() 122 | 123 | sample_num = 0 124 | data_loader = tqdm(data_loader, file=sys.stdout) 125 | 126 | for step, data in enumerate(data_loader): 127 | face, voice, labels = data 128 | sample_num += face.shape[0] 129 | 130 | pred = model(face, voice, labels).to(device) 131 | pred_classes = torch.max(pred, dim=1)[1] 132 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 133 | 134 | loss = loss_function(pred, labels.to(device)) 135 | loss.backward() 136 | accu_loss += loss.detach() 137 | 138 | data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}, lr: {:.5f}".format( 139 | epoch, 140 | accu_loss.item() / (step + 1), 141 | accu_num.item() / sample_num, 142 | optimizer.param_groups[0]["lr"] 143 | ) 144 | 145 | if not torch.isfinite(loss): 146 | print('WARNING: non-finite loss, ending training ', loss) 147 | sys.exit(1) 148 | 149 | optimizer.step() 150 | optimizer.zero_grad() 151 | # update lr 152 | lr_scheduler.step() 153 | 154 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 155 | 156 | 157 | @torch.no_grad() 158 | def evaluate(model, data_loader, device, epoch): 159 | loss_function = torch.nn.CrossEntropyLoss() 160 | 161 | model.eval() 162 | 163 | accu_num = 
torch.zeros(1).to(device) 164 | accu_loss = torch.zeros(1).to(device) 165 | 166 | sample_num = 0 167 | data_loader = tqdm(data_loader, file=sys.stdout) 168 | all_preds, all_label = [], [] 169 | for step, data in enumerate(data_loader): 170 | face, voice, labels = data 171 | sample_num += face.shape[0] 172 | 173 | pred = model(face, voice, labels).to(device) 174 | pred_classes = torch.max(pred, dim=1)[1] 175 | 176 | accu_num += torch.eq(pred_classes, labels.to(device)).sum() 177 | loss = loss_function(pred, labels.to(device)) 178 | accu_loss += loss 179 | data_loader.desc = "[valid epoch {}] loss: {:.3f}, acc: {:.3f}".format( 180 | epoch, 181 | accu_loss.item() / (step + 1), 182 | accu_num.item() / sample_num 183 | ) 184 | if len(all_preds) == 0: 185 | all_preds.append(pred_classes.detach().cpu().numpy()) 186 | all_label.append(labels.detach().cpu().numpy()) 187 | else: 188 | all_preds[0] = np.append( 189 | all_preds[0], pred_classes.detach().cpu().numpy(), axis=0 190 | ) 191 | all_label[0] = np.append( 192 | all_label[0], labels.detach().cpu().numpy(), axis=0 193 | ) 194 | 195 | all_preds, all_label = all_preds[0], all_label[0] 196 | report = metrics.classification_report(all_label, all_preds, target_names=['0', '1'], digits=4) 197 | logger.info(report) 198 | print(report) 199 | print("AUC:{:.4f}".format(metrics.roc_auc_score(all_label, all_preds))) 200 | print("F1-Score:{:.4f}".format(f1_score(all_label, all_preds))) 201 | return accu_loss.item() / (step + 1), accu_num.item() / sample_num 202 | 203 | 204 | def create_lr_scheduler(optimizer, 205 | num_step: int, 206 | epochs: int, 207 | warmup=True, 208 | warmup_epochs=1, 209 | warmup_factor=1e-3, 210 | end_factor=1e-6): 211 | assert num_step > 0 and epochs > 0 212 | if warmup is False: 213 | warmup_epochs = 0 214 | 215 | def f(x): 216 | if warmup is True and x <= (warmup_epochs * num_step): 217 | alpha = float(x) / (warmup_epochs * num_step) 218 | return warmup_factor * (1 - alpha) + alpha 219 | else: 220 | current_step = (x - warmup_epochs * num_step) 221 | cosine_steps = (epochs - warmup_epochs) * num_step 222 | return ((1 + math.cos(current_step * math.pi / cosine_steps)) / 2) * (1 - end_factor) + end_factor 223 | 224 | return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f) 225 | 226 | 227 | def get_params_groups(model: torch.nn.Module, weight_decay: float = 1e-5): 228 | parameter_group_vars = {"decay": {"params": [], "weight_decay": weight_decay}, 229 | "no_decay": {"params": [], "weight_decay": 0.}} 230 | 231 | parameter_group_names = {"decay": {"params": [], "weight_decay": weight_decay}, 232 | "no_decay": {"params": [], "weight_decay": 0.}} 233 | 234 | for name, param in model.named_parameters(): 235 | if not param.requires_grad: 236 | continue 237 | 238 | if len(param.shape) == 1 or name.endswith(".bias"): 239 | group_name = "no_decay" 240 | else: 241 | group_name = "decay" 242 | 243 | parameter_group_vars[group_name]["params"].append(param) 244 | parameter_group_names[group_name]["params"].append(name) 245 | 246 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 247 | return list(parameter_group_vars.values()) 248 | -------------------------------------------------------------------------------- /P2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/QLUTEmoTechCrew/DEP-Former/a882bb3a4d2adf884f49c3310773368469f58ebe/P2.pdf -------------------------------------------------------------------------------- /P2_00.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/QLUTEmoTechCrew/DEP-Former/a882bb3a4d2adf884f49c3310773368469f58ebe/P2_00.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DEP-Former 2 | DEP-Former: Multimodal Depression Recognition Based on Facial Expressions and Audio Features via Emotional Changes
3 | 4 | # Run 5 | Open the DEP-Former folder and run `python train.py`. Before training, fill in the four dataset directory variables (`train_images_path_fac`, `val_images_path_fac`, `train_images_path_au`, `val_images_path_au`) inside `main()` in `train.py`; they are empty strings by default (see the sketch at the end of this section).
6 | Required environment: Python 3.8 or higher, with PyTorch and the other packages imported by the code (NumPy, SciPy, scikit-learn, torchvision, albumentations, tqdm, matplotlib) installed.
7 | For technical issues, please contact: yejiayu97@outlook.com.
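A minimal sketch of the configuration step mentioned above: `train.py` reads the face and audio feature directories from four variables inside `main()`, which ship as empty strings. The directory names below are placeholders for your local copies of the data, not an official layout.

```python
# Inside main() in train.py: point these at your local feature directories.
# The paths below are placeholders -- substitute wherever you unpacked the data.
train_images_path_fac = r"/path/to/train/face"   # training facial-feature .npy files
val_images_path_fac = r"/path/to/val/face"       # validation facial-feature .npy files
train_images_path_au = r"/path/to/train/voice"   # training audio-feature .npy files
val_images_path_au = r"/path/to/val/voice"       # validation audio-feature .npy files
```

Hyperparameters such as `--epochs`, `--batch-size`, `--lr`, `--wd`, and `--device` can be overridden on the command line; their defaults are defined in the argparse block at the bottom of `train.py`.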
8 | 9 | # Data 10 | This is the official data link for DEP-Former:
11 | https://pan.baidu.com/s/1qMTpcw5na1gOfSq0ysujHg
12 | For the access password, please contact: wangqx@qlu.edu.cn.
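Once downloaded, the face and voice features are consumed through `MyDataSet` in `my_dataset.py`, which pairs the `.npy` files from a face directory and a voice directory and derives the binary label from the face file name (the character at index `-8` is `'D'` for depressed samples). A minimal usage sketch, with placeholder paths:

```python
# Minimal sketch: wrap the downloaded .npy feature directories with MyDataSet.
# face_dir / voice_dir are placeholders for wherever the data was unpacked.
from torch.utils.data import DataLoader
from my_dataset import MyDataSet

face_dir = r"/path/to/data/face"    # directory of facial-feature .npy files
voice_dir = r"/path/to/data/voice"  # directory of audio-feature .npy files

dataset = MyDataSet(face_dir, voice_dir)
face, voice, label = dataset[0]     # tensors sliced to 24 frames, label in {0, 1}
print(face.shape, voice.shape, label)

loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
for face_batch, voice_batch, labels in loader:
    ...  # fed to Depformer(face_batch, voice_batch, labels) during training
```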
13 | 14 | # Model 15 | ![DEP-Former model architecture](https://github.com/QLUTEmoTechCrew/DEP-Former/blob/main/P2_00.png) 16 | 17 | # Abbreviation 18 | ![Abbreviations used in DEP-Former](https://github.com/QLUTEmoTechCrew/DEP-Former/blob/main/Abbreviation.png) 19 | --------------------------------------------------------------------------------