├── README.md
├── module.py
├── models1.py
└── main1.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FDAN
Code for the paper "MULTI-MODAL SPEECH EMOTION RECOGNITION USING FEATURE DISTRIBUTION ADAPTATION NETWORK".

The remaining files will be made public after the paper is accepted.

# Recognition results (%) of different models on the SAVEE dataset

![SAVEE results](https://github.com/shaokai1209/shaokai1209/blob/main/ICASSP%202023%20SAVEE.png)

# Recognition results (%) of different models on the MELD dataset

![MELD results](https://github.com/shaokai1209/shaokai1209/blob/main/ICASSP%202023%20MELD.png)

References: [[20](https://www.sciencedirect.com/science/article/abs/pii/S0950705120306766)] [[21](https://www.sciencedirect.com/science/article/abs/pii/S0167639322000954)] [[22](https://ieeexplore.ieee.org/abstract/document/9674867)] [[23](https://ieeexplore.ieee.org/abstract/document/9745163)]
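
# How to run (example)

`main1.py` is the training entry point. The command below is only a sketch of the command-line interface defined in `main1.py`: the directory names are placeholders, and the expected data layout depends on the companion modules `data_loader`, `utils`, `backbones` and `transfer_losses`, which are not yet part of this release.

```
python main1.py --data_dir ./data --src_domain source --tgt_domain target --src_val source_val \
                --backbone resnet50 --transfer_loss mmd --batch_size 32 --n_epoch 100
```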
--------------------------------------------------------------------------------
/module.py:
--------------------------------------------------------------------------------
import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class Residual(nn.Module):
    """Adds a skip connection around an arbitrary sub-module."""
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped sub-module."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP with GELU activation and dropout."""
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Standard multi-head self-attention over a token sequence."""
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class CrossAttention(nn.Module):
    """Multi-head attention in which only the first token of the input acts as the query
    and attends over the whole sequence, returning a single updated token."""
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x_qkv):
        b, n, _, h = *x_qkv.shape, self.heads

        k = self.to_k(x_qkv)
        k = rearrange(k, 'b n (h d) -> b h n d', h = h)

        v = self.to_v(x_qkv)
        v = rearrange(v, 'b n (h d) -> b h n d', h = h)

        # only the first token is used as the query
        q = self.to_q(x_qkv[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out
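
# Minimal shape check (a sketch; the sizes below are illustrative only):
# the first token of the input acts as the query, so CrossAttention returns a
# single updated token per sequence, while Attention keeps the full sequence.
if __name__ == "__main__":
    x = torch.randn(2, 17, 256)                                # (batch, tokens, dim)
    print(CrossAttention(256, heads=2, dim_head=32)(x).shape)  # torch.Size([2, 1, 256])
    print(Attention(256, heads=2, dim_head=32)(x).shape)       # torch.Size([2, 17, 256])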
--------------------------------------------------------------------------------
/models1.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from transfer_losses import TransferLoss
import backbones
from module import Attention, PreNorm, FeedForward, CrossAttention

class TransferNet(nn.Module):
    def __init__(self, num_class, base_net='resnet34', transfer_loss='mmd', use_bottleneck=True,
                 bottleneck_width=256, max_iter=1000, **kwargs):
        super(TransferNet, self).__init__()
        self.num_class = num_class
        self.base_network = backbones.get_backbone(base_net)
        self.use_bottleneck = use_bottleneck
        self.transfer_loss = transfer_loss
        if self.use_bottleneck:
            bottleneck_list = [
                nn.Linear(self.base_network.output_num(), bottleneck_width),
                nn.ReLU()
            ]
            self.bottleneck_layer = nn.Sequential(*bottleneck_list)
            feature_dim = bottleneck_width
        else:
            feature_dim = self.base_network.output_num()

        self.classifier_layer = nn.Linear(feature_dim, num_class)
        transfer_loss_args = {
            "loss_type": self.transfer_loss,
            "max_iter": max_iter,
            "num_class": num_class
        }
        self.adapt_loss = TransferLoss(**transfer_loss_args)
        self.criterion = torch.nn.CrossEntropyLoss()
        # cross-attention between the source and target feature "sequences"
        self.cross_attn_layers = nn.ModuleList([])
        for _ in range(1):
            self.cross_attn_layers.append(nn.ModuleList([
                nn.Linear(256, 256),
                nn.Linear(256, 256),
                PreNorm(256, CrossAttention(256, heads=2, dim_head=32, dropout=0)),
                nn.Linear(256, 256),
                nn.Linear(256, 256),
                PreNorm(256, CrossAttention(256, heads=2, dim_head=32, dropout=0)),
            ]))

    def forward(self, source, target, source_label, target_label):
        source = self.base_network(source)
        target = self.base_network(target)
        if self.use_bottleneck:
            source = self.bottleneck_layer(source)
            target = self.bottleneck_layer(target)
        # cross-attention: treat each batch as a token sequence of shape (1, batch, feature_dim),
        # with the first sample playing the role of the class token
        source = source[None, :]
        target = target[None, :]
        xs = source
        xl = target
        for f_sl, g_ls, cross_attn_s, f_ls, g_sl, cross_attn_l in self.cross_attn_layers:
            small_class = xs[:, 0]
            x_small = xs[:, 1:]
            large_class = xl[:, 0]
            x_large = xl[:, 1:]

            # Cross Attn for target: the target class token attends to the source tokens
            cal_q = f_ls(large_class.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_small), dim=1)
            cal_out = cal_q + cross_attn_l(cal_qkv)
            cal_out = g_sl(cal_out)
            xl = torch.cat((cal_out, x_large), dim=1)

            # Cross Attn for source: the source class token attends to the target tokens
            cal_q = f_sl(small_class.unsqueeze(1))
            cal_qkv = torch.cat((cal_q, x_large), dim=1)
            cal_out = cal_q + cross_attn_s(cal_qkv)
            cal_out = g_ls(cal_out)
            xs = torch.cat((cal_out, x_small), dim=1)
        source = xs.squeeze()  # (batch, feature_dim)
        target = xl.squeeze()
        # classification
        source_clf = self.classifier_layer(source)
        target_clf = self.classifier_layer(target)
        clf_loss = self.criterion(source_clf, source_label) + self.criterion(target_clf, target_label)
        # transfer
        kwargs = {}
        if self.transfer_loss == "lmmd":
            kwargs['source_label'] = source_label
            # LMMD expects per-class weights for the target batch; the ground-truth
            # target labels are used directly as a one-hot matrix
            kwargs['target_logits'] = torch.nn.functional.one_hot(
                target_label, num_classes=self.num_class).float()

        transfer_loss = self.adapt_loss(source, target, **kwargs)
        return clf_loss, transfer_loss

    def get_parameters(self, initial_lr=1.0):
        params = [
            {'params': self.base_network.parameters(), 'lr': 0.1 * initial_lr},
            {'params': self.classifier_layer.parameters(), 'lr': 1.0 * initial_lr},
            # the cross-attention layers must also be passed to the optimizer
            {'params': self.cross_attn_layers.parameters(), 'lr': 1.0 * initial_lr},
        ]
        if self.use_bottleneck:
            params.append(
                {'params': self.bottleneck_layer.parameters(), 'lr': 1.0 * initial_lr}
            )
        return params

    def predict(self, x):
        features = self.base_network(x)
        x = self.bottleneck_layer(features)
        clf = self.classifier_layer(x)
        return clf

    def epoch_based_processing(self, *args, **kwargs):
        if self.transfer_loss == "daan":
            self.adapt_loss.loss_func.update_dynamic_factor(*args, **kwargs)
        else:
            pass
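
# Usage sketch (the companion modules `backbones`, `transfer_losses`, `data_loader`
# and `utils` are not part of this release, so the lines below are illustrative only):
# main1.py builds the model as
#     model = TransferNet(n_class, transfer_loss='mmd', base_net='resnet50', max_iter=max_iter)
# and, for every batch, combines the two returned losses as
#     clf_loss, transfer_loss = model(data_source, data_target, label_source, label_target)
#     loss = clf_loss + transfer_loss_weight * transfer_loss
# while predict(x) classifies from backbone + bottleneck features alone, skipping the
# cross-attention branch.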
help="Early stopping") 38 | parser.add_argument('--epoch_based_training', type=str2bool, default=False, help="Epoch-based training / Iteration-based training") 39 | parser.add_argument("--n_iter_per_epoch", type=int, default=20, help="Used in Iteration-based training") 40 | 41 | # optimizer related 42 | parser.add_argument('--lr', type=float, default=1e-3) 43 | parser.add_argument('--momentum', type=float, default=0.9) 44 | parser.add_argument('--weight_decay', type=float, default=5e-4) 45 | 46 | # learning rate scheduler related 47 | parser.add_argument('--lr_gamma', type=float, default=0.0003) 48 | parser.add_argument('--lr_decay', type=float, default=0.75) 49 | parser.add_argument('--lr_scheduler', type=str2bool, default=True) 50 | 51 | # transfer related 52 | parser.add_argument('--transfer_loss_weight', type=float, default=10) 53 | parser.add_argument('--transfer_loss', type=str, default='mmd') 54 | return parser 55 | 56 | def set_random_seed(seed=0): 57 | # seed setting 58 | random.seed(seed) 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | torch.cuda.manual_seed(seed) 62 | torch.backends.cudnn.deterministic = True 63 | torch.backends.cudnn.benchmark = False 64 | 65 | def load_data(args): 66 | ''' 67 | src_domain, tgt_domain data to load 68 | ''' 69 | folder_src = os.path.join(args.data_dir, args.src_domain) 70 | folder_tgt = os.path.join(args.data_dir, args.tgt_domain) 71 | folder_val = os.path.join(args.data_dir, args.src_val) 72 | source_loader, n_class = data_loader.load_data( 73 | folder_src, args.batch_size, infinite_data_loader=not args.epoch_based_training, train=True, num_workers=args.num_workers) 74 | target_train_loader, _ = data_loader.load_data( 75 | folder_tgt, args.batch_size, infinite_data_loader=not args.epoch_based_training, train=True, num_workers=args.num_workers) 76 | source_test_loader, _ = data_loader.load_data( 77 | folder_val, args.batch_size, infinite_data_loader=False , train=False, num_workers=args.num_workers) 78 | return source_loader, target_train_loader, source_test_loader, n_class 79 | 80 | def get_model(args): 81 | model = models1.TransferNet( 82 | args.n_class, transfer_loss=args.transfer_loss, base_net=args.backbone, max_iter=args.max_iter, use_bottleneck=args.use_bottleneck).to(args.device) 83 | return model 84 | 85 | def get_optimizer(model, args): 86 | initial_lr = args.lr if not args.lr_scheduler else 1.0 87 | params = model.get_parameters(initial_lr=initial_lr) 88 | optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False) 89 | return optimizer 90 | 91 | def get_scheduler(optimizer, args): 92 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: args.lr * (1. 

def compute_uar(y_true, y_pred):
    # unweighted average recall: the mean of the per-class recalls
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    unique_labels = np.unique(y_true)  # the unique label classes
    recalls = []
    for label in unique_labels:
        true_positives = np.sum((y_true == label) & (y_pred == label))
        possible_positives = np.sum(y_true == label)
        recall = true_positives / possible_positives if possible_positives > 0 else 0
        recalls.append(recall)

    uar = np.mean(recalls)
    return uar

def test(model, source_test_loader, args):
    model.eval()
    test_loss = utils.AverageMeter()
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    len_source_dataset = len(source_test_loader.dataset)
    pred_label = []
    true_label = []
    with torch.no_grad():
        for data, source in source_test_loader:
            data, source = data.to(args.device), source.to(args.device)
            s_output = model.predict(data)
            loss = criterion(s_output, source)
            test_loss.update(loss.item())
            pred = torch.max(s_output, 1)[1]
            pred_label.extend(pred.tolist())
            true_label.extend(source.tolist())
            correct += torch.sum(pred == source)
    acc = 100. * correct / len_source_dataset
    uar = compute_uar(true_label, pred_label)
    print(len_source_dataset)
    return acc, uar, test_loss.avg, pred_label, true_label

def train(source_loader, target_train_loader, source_test_loader, model, optimizer, lr_scheduler, args):
    len_source_loader = len(source_loader)
    len_target_loader = len(target_train_loader)
    n_batch = min(len_source_loader, len_target_loader)
    if n_batch == 0:
        n_batch = args.n_iter_per_epoch

    iter_source, iter_target = iter(source_loader), iter(target_train_loader)

    best_acc = 0
    best_uar = 0
    stop = 0
    log = []
    final_pred_label = []
    final_true_label = []
    for e in range(1, args.n_epoch + 1):
        model.train()
        train_loss_clf = utils.AverageMeter()
        train_loss_transfer = utils.AverageMeter()
        train_loss_total = utils.AverageMeter()
        model.epoch_based_processing(n_batch)

        if max(len_target_loader, len_source_loader) != 0:
            iter_source, iter_target = iter(source_loader), iter(target_train_loader)

        for _ in range(n_batch):
            data_source, label_source = next(iter_source)
            data_target, label_target = next(iter_target)
            data_source, label_source = data_source.to(args.device), label_source.to(args.device)
            data_target, label_target = data_target.to(args.device), label_target.to(args.device)
            clf_loss, transfer_loss = model(data_source, data_target, label_source, label_target)
            loss = clf_loss + args.transfer_loss_weight * transfer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler:
                lr_scheduler.step()

            train_loss_clf.update(clf_loss.item())
            train_loss_transfer.update(transfer_loss.item())
            train_loss_total.update(loss.item())

        log.append([train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg])

        info = 'Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, total_loss: {:.4f}'.format(
            e, args.n_epoch, train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg)
        # Test
        stop += 1
        test_acc, uar, test_loss, pred_label, true_label = test(model, source_test_loader, args)
        info += ', test_loss {:.4f}, test_acc: {:.4f}, test_uar: {:.4f}'.format(test_loss, test_acc, uar)
        np_log = np.array(log, dtype=float)
        np.savetxt('train_log.csv', np_log, delimiter=',', fmt='%.6f')
        if best_acc < test_acc:
            best_acc = test_acc
            best_uar = uar
            final_pred_label = pred_label
            final_true_label = true_label
            stop = 0
        if args.early_stop > 0 and stop >= args.early_stop:
            print(info)
            break
        print(info)
    final_true_label = [str(i) + '\n' for i in final_true_label]
    with open("true_label.txt", "w") as f:
        f.writelines(final_true_label)
    final_pred_label = [str(i) + '\n' for i in final_pred_label]
    with open("pred_label.txt", "w") as f1:
        f1.writelines(final_pred_label)
    print('Transfer result: {:.4f}, {:.4f}'.format(best_acc, best_uar))

def main():
    parser = get_parser()
    args = parser.parse_args()
    setattr(args, "device", torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    print(args)
    set_random_seed(args.seed)
    source_loader, target_train_loader, source_test_loader, n_class = load_data(args)
    setattr(args, "n_class", n_class)
    if args.epoch_based_training:
        setattr(args, "max_iter", args.n_epoch * min(len(source_loader), len(target_train_loader)))
    else:
        setattr(args, "max_iter", args.n_epoch * args.n_iter_per_epoch)
    model = get_model(args)
    optimizer = get_optimizer(model, args)

    if args.lr_scheduler:
        scheduler = get_scheduler(optimizer, args)
    else:
        scheduler = None
    train(source_loader, target_train_loader, source_test_loader, model, optimizer, scheduler, args)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------