├── .gitignore
├── data
│   ├── common.py
│   └── ecc.py
├── README.md
├── model
│   ├── baseline_gru.py
│   ├── attention.py
│   ├── graph.py
│   ├── tacotron.py
│   ├── baseline.py
│   └── proposed.py
├── save.py
├── test.py
├── hparams.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | enhanced
3 | segmented*/
4 | save
5 | *.sw*
6 | *.tar*
7 | 
--------------------------------------------------------------------------------
/data/common.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | class Collate:
4 |     def __init__(self, device=None):
5 |         self.device = device
6 | 
7 |     def __call__(self, batch):
8 |         length = len(batch[0])
9 |         output = [[] for i in range(length)]
10 | 
11 |         for data in batch:
12 |             for i, j in enumerate(data):
13 |                 if not torch.is_tensor(j):
14 |                     j = torch.from_numpy(j)
15 |                 output[i].append(j if self.device is None else j.to(self.device))
16 | 
17 |         return tuple(output)
18 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Inferring Speaking Styles from Multi-modal Conversational Context by Multi-scale Relational Graph Convolutional Networks
2 | ====
3 | 
4 | The source code for [Inferring Speaking Styles from Multi-modal Conversational Context by Multi-scale Relational Graph Convolutional Networks](https://dl.acm.org/doi/10.1145/3503161.3547831), published at ACM Multimedia 2022.
5 | 
6 | This project also contains a re-implementation of [Enhancing Speaking Styles in Conversational Text-to-Speech Synthesis with Graph-based Multi-modal Context Modeling](https://ieeexplore.ieee.org/abstract/document/9747837/), published at ICASSP 2022.
7 | 
8 | The source code for the multi-scale speaking style enhanced FastSpeech 2 is available at [thuhcsi/mst-fastspeech2](https://github.com/thuhcsi/mst-fastspeech2).
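Usage
----

A minimal sketch of how the scripts in this repository are meant to be invoked, assuming the preprocessed `segmented-train` / `segmented-test` / `segmented-test1` directories (containing the `.txt`, `.bert.npy`, `.sbert.npy`, `.gst.npy`, `.wst.npy` and `.gst_only.npy` files expected by `data/ecc.py`) are already in place; run names, timestamps and epoch numbers below are placeholders:

```
# (optional) extract missing BERT word/sentence embeddings for a split
python -m data.ecc segmented-train

# train the proposed model (checkpoints are written to save/proposed-<timestamp>/)
python train.py --model proposed --gpu 0 --train_path segmented-train --test_path segmented-test

# predict global and word-level style tokens with a saved checkpoint
python test.py --model proposed --gpu 0 --load_model save/proposed-<timestamp>/epoch<N> --test_path segmented-test1
```

`test.py` writes its predictions back into the test directory as `<utterance>.p_gst.npy` and `<utterance>.p_wst.npy`; pass `--model baseline` or `--model baseline_gru` to run the re-implemented baselines instead.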
9 | -------------------------------------------------------------------------------- /model/baseline_gru.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .baseline import GlobalEncoder, FakeMST, pad_sequence 5 | 6 | class BaselineGRU(nn.Module): 7 | 8 | def __init__(self, hparams): 9 | super().__init__() 10 | self.gru = nn.GRU(hparams.gru.input_dim, hparams.gru.dim, 1, batch_first=True) 11 | self.global_encoder = GlobalEncoder(hparams.global_encoder) 12 | self.global_linear = nn.Linear(hparams.global_linear.input_dim, hparams.global_linear.output_dim) 13 | self.softmax = nn.Softmax(dim=-1) 14 | self.mse = nn.MSELoss() 15 | 16 | def forward(self, length, speaker, bert, history_gst, sbert): 17 | batch_size = len(length) 18 | 19 | history_sbert = torch.stack([i[:-1] for i in sbert]) 20 | history_global_features, _ = self.gru(history_sbert) 21 | history_global_features = history_global_features[:,-1] 22 | 23 | current_length = [i[-1] for i in length] 24 | current_length = torch.stack(current_length) 25 | current_bert = [i[-1] for i in bert] 26 | current_bert = pad_sequence(current_bert) 27 | current_sbert = [i[-1] for i in sbert] 28 | current_sbert = torch.stack(current_sbert) 29 | current_global_feature = self.global_encoder(current_bert, current_length.cpu()) 30 | current_global_feature = current_global_feature[:,-1] 31 | 32 | current_speaker = torch.stack([i[-1] for i in speaker]) 33 | current_speaker = nn.functional.one_hot(current_speaker, num_classes=len(speaker[0])) 34 | current_global_feature = torch.cat([history_global_features, current_sbert, current_global_feature, current_speaker], dim=-1) 35 | 36 | current_gst = self.global_linear(current_global_feature) 37 | current_gst = current_gst.contiguous().view(batch_size, 4, 10) 38 | current_gst = self.softmax(current_gst) 39 | current_gst = current_gst.contiguous().view(batch_size, 40) 40 | return current_gst 41 | 42 | def gst_loss(self, p_gst, gst): 43 | return self.mse(p_gst, gst) 44 | 45 | if __name__ == '__main__': 46 | from data.ecc import ECC 47 | from data.common import Collate 48 | from hparams import baseline_gru 49 | 50 | device = 'cpu' 51 | data_loader = torch.utils.data.DataLoader(ECC('segmented-train'), batch_size=2, shuffle=True, collate_fn=Collate(device)) 52 | 53 | model = BaselineGRU(baseline_gru) 54 | fake = FakeMST(baseline_gru.fake_mst) 55 | model.to(device) 56 | 57 | for batch in data_loader: 58 | length, speaker, bert, gst, wst, gst_only, sbert = batch 59 | history_gst = [i[:-1] for i in gst_only] 60 | predicted_gst = model(length, speaker, bert, history_gst, sbert) 61 | print(predicted_gst.shape) 62 | 63 | current_length = [i[-1] for i in length] 64 | current_wst = [i[-1, :l] for i, l in zip(wst, current_length)] 65 | current_bert = [i[-1] for i in bert] 66 | predicted_gst, predicted_wst = fake(current_length, current_bert, predicted_gst.detach()) 67 | print([i.shape for i in predicted_wst]) 68 | break 69 | -------------------------------------------------------------------------------- /save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import torch 5 | import numpy as np 6 | from datetime import datetime 7 | from pathlib import Path 8 | 9 | from tornado.log import enable_pretty_logging 10 | from tornado.options import options 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | import matplotlib 14 | matplotlib.use("Agg") 
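# select the non-interactive Agg backend before pylab is imported; figures are only
# rasterized into numpy arrays for TensorBoard and never shown on screen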
15 | import matplotlib.pylab as plt 16 | 17 | def save_figure_to_numpy(fig): 18 | # save it to a numpy array. 19 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 20 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 21 | return data 22 | 23 | def plot_alignment_to_numpy(alignment, info=None): 24 | fig, ax = plt.subplots(figsize=(6, 4)) 25 | im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none') 26 | fig.colorbar(im, ax=ax) 27 | #xlabel = 'Decoder timestep' 28 | #if info is not None: 29 | # xlabel += '\n\n' + info 30 | #plt.xlabel(xlabel) 31 | #plt.ylabel('Encoder timestep') 32 | plt.tight_layout() 33 | 34 | fig.canvas.draw() 35 | data = save_figure_to_numpy(fig) 36 | plt.close() 37 | return data 38 | 39 | class Save: 40 | 41 | def __init__(self, name='noname'): 42 | 43 | self.name = name + datetime.now().strftime("-%Y%m%d-%H%M%S") 44 | self.path = Path('save') / self.name 45 | self.path.mkdir(parents=True, exist_ok=True) 46 | 47 | self.logger = logging.getLogger(self.name) 48 | options.logging = 'debug' 49 | enable_pretty_logging(options=options, logger=self.logger) 50 | 51 | self.writer = SummaryWriter(self.path) 52 | 53 | def save_log(self, stage, epoch, batch, step, loss): 54 | self.logger.info('[%s] %s epoch %d batch %d step %d loss %f', self.name, stage, epoch, batch, step, loss) 55 | self.writer.add_scalar(f"{stage}/loss", loss, step) 56 | 57 | def save_parameters(self, hparams): 58 | self.writer.add_text("hparams", json.dumps(hparams, indent=2)) 59 | 60 | def save_model(self, model, filename): 61 | torch.save(model.state_dict(), os.path.join(self.path, filename)) 62 | 63 | def save_boundary(self, stage, step, p_boundary, boundary, shape): 64 | figure = np.zeros(shape) 65 | 66 | for i in range(boundary.shape[0]): 67 | for j, k in [(p_boundary[i][0], 0.7), (p_boundary[i][1], 0.7), (boundary[i][0], 1), (boundary[i][1], 1)]: 68 | try: 69 | if j >= 0: 70 | #print(int(100 * j), i, k) 71 | figure[int(100 * j), i] = k 72 | except: 73 | pass 74 | 75 | self.writer.add_image(f"{stage}/boundary", plot_alignment_to_numpy(figure), step, dataformats='HWC') 76 | 77 | def save_attention(self, stage, step, w1, w2): 78 | self.writer.add_image(f"{stage}/w1", plot_alignment_to_numpy(w1.T.data.cpu().numpy()), step, dataformats='HWC') 79 | self.writer.add_image(f"{stage}/w2", plot_alignment_to_numpy(w2.T.data.cpu().numpy()), step, dataformats='HWC') 80 | 81 | if __name__ == '__main__': 82 | save = Save() 83 | save.logger.info('Test') 84 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import argparse 4 | import numpy as np 5 | from tqdm import tqdm 6 | from data.ecc import ECC 7 | from data.common import Collate 8 | from pathlib import Path 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', default=-1, type=int) 12 | parser.add_argument('--name', default=None) 13 | parser.add_argument('--load_model', default=None) 14 | parser.add_argument('--test_path', default='segmented-test1') 15 | parser.add_argument('--model', default='proposed', choices=['baseline', 'baseline_gru', 'proposed']) 16 | args = parser.parse_args() 17 | 18 | if args.gpu < 0: 19 | device = "cpu" 20 | else: 21 | device = "cuda:%d" % args.gpu 22 | 23 | if args.model == 'baseline': 24 | from hparams import baseline as hparams 25 | from model.baseline import Baseline, FakeMST 26 | model = 
Baseline(hparams) 27 | fake = FakeMST(hparams.fake_mst) 28 | load_model = Path(args.load_model) 29 | fake.load_state_dict(torch.load(load_model.parent / f'fake_{load_model.name}', map_location='cpu')) 30 | fake.to(device) 31 | elif args.model == 'baseline_gru': 32 | from hparams import baseline_gru as hparams 33 | from model.baseline_gru import BaselineGRU 34 | from model.baseline import FakeMST 35 | model = BaselineGRU(hparams) 36 | fake = FakeMST(hparams.fake_mst) 37 | load_model = Path(args.load_model) 38 | fake.load_state_dict(torch.load(load_model.parent / f'fake_{load_model.name}', map_location='cpu')) 39 | fake.to(device) 40 | elif args.model == 'proposed': 41 | from hparams import proposed as hparams 42 | from model.proposed import Proposed 43 | model = Proposed(hparams) 44 | 45 | test_dataset = ECC(args.test_path, chunk_size=hparams.input.chunk_size) 46 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size, shuffle=False, collate_fn=Collate(device)) 47 | 48 | model.load_state_dict(torch.load(args.load_model, map_location='cpu')) 49 | model.to(device) 50 | 51 | with torch.no_grad(): 52 | predicted_gst = [] 53 | predicted_wst = [] 54 | predicted_gst_only = [] 55 | for data in tqdm(test_dataloader): 56 | length, speaker, bert, gst, wst, gst_only, sbert = data 57 | current_length = [i[-1] for i in length] 58 | 59 | if args.model == 'baseline': 60 | history_gst_only = [i[:-1] for i in gst_only] 61 | current_bert = [i[-1] for i in bert] 62 | 63 | predicted_gst_only.append(model(length, speaker, bert, history_gst_only)) 64 | _predicted_gst, _predicted_wst = fake(current_length, current_bert, predicted_gst_only[-1].detach()) 65 | if args.model == 'baseline_gru': 66 | history_gst_only = [i[:-1] for i in gst_only] 67 | current_bert = [i[-1] for i in bert] 68 | 69 | predicted_gst_only.append(model(length, speaker, bert, history_gst_only, sbert)) 70 | _predicted_gst, _predicted_wst = fake(current_length, current_bert, predicted_gst_only[-1].detach()) 71 | if args.model == 'proposed': 72 | history_gst = [i[:-1] for i in gst] 73 | history_wst = [i[:-1] for i in wst] 74 | 75 | _predicted_gst, _predicted_wst = model(length, speaker, bert, history_gst, history_wst) 76 | 77 | predicted_gst.append(_predicted_gst) 78 | predicted_wst += _predicted_wst 79 | 80 | if args.model in ['baseline', 'baseline_gru']: 81 | predicted_gst_only = torch.cat(predicted_gst_only, dim=0) 82 | 83 | predicted_gst = torch.cat(predicted_gst, dim=0) 84 | 85 | current_utterances = [chunk[-1] for chunk in test_dataset.chunks] 86 | for i in range(len(current_utterances)): 87 | key = current_utterances[i].wav.stem 88 | if args.model in ['baseline', 'baseline_gru']: 89 | np.save(test_dataset.path / f'{key}.p_gst_only.npy', predicted_gst_only[i].cpu().numpy()) 90 | np.save(test_dataset.path / f'{key}.p_gst.npy', predicted_gst[i].cpu().numpy()) 91 | np.save(test_dataset.path / f'{key}.p_wst.npy', predicted_wst[i].cpu().numpy()) 92 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class BahdanauAttention(nn.Module): 5 | def __init__(self, query_dim, key_dim, value_dim, attention_dim): 6 | super(BahdanauAttention, self).__init__() 7 | self.q_layer = nn.Linear(query_dim, attention_dim) 8 | self.k_layer = nn.Linear(key_dim, attention_dim) 9 | self.o_layer = nn.Linear(attention_dim, 1) 10 | self.tanh = nn.Tanh() 11 | 
self.softmax = nn.Softmax(dim=1) 12 | 13 | def forward(self, query, key, value): 14 | query = torch.unsqueeze(query, 1) 15 | score = self.o_layer(self.tanh(self.q_layer(query) + self.k_layer(key))) 16 | attention_weights = self.softmax(score) 17 | context_vector = torch.bmm(attention_weights.transpose(1, 2), value) 18 | context_vector = torch.squeeze(context_vector) 19 | attention_weights = torch.squeeze(attention_weights) 20 | return context_vector, attention_weights 21 | 22 | class BidirectionalAttention(nn.Module): 23 | 24 | def __init__(self, k1_dim, k2_dim, v1_dim, v2_dim, attention_dim): 25 | super().__init__() 26 | self.k1_layer = nn.Linear(k1_dim, attention_dim) 27 | self.k2_layer = nn.Linear(k2_dim, attention_dim) 28 | self.score_layer = nn.Linear(attention_dim, 1) 29 | self.softmax1 = nn.Softmax(dim=-1) 30 | self.softmax2 = nn.Softmax(dim=-1) 31 | 32 | def forward(self, k1, k2, v1, v2, k1_lengths=None, k2_lengths=None): 33 | k1 = self.k1_layer(k1) 34 | k2 = self.k2_layer(k2) 35 | score = torch.bmm(k1, k2.transpose(1, 2)) 36 | 37 | if not k1_lengths is None or not k2_lengths is None: 38 | mask = torch.zeros(score.shape, dtype=torch.int).detach().to(score.device) 39 | for i, l in enumerate(k1_lengths): 40 | mask[i,l:,:] += 1 41 | for i, l in enumerate(k2_lengths): 42 | mask[i,:,l:] += 1 43 | mask = mask == 1 44 | score = score.clone().masked_fill_(mask, -float('inf')) 45 | 46 | w1 = self.softmax1(score.transpose(1, 2)) 47 | w2 = self.softmax2(score) 48 | 49 | o1 = torch.bmm(w1, v1) 50 | o2 = torch.bmm(w2, v2) 51 | 52 | w1 = [i[:l2, :l1] for i, l1, l2 in zip(w1, k1_lengths, k2_lengths)] 53 | w2 = [i[:l1, :l2] for i, l1, l2 in zip(w2, k1_lengths, k2_lengths)] 54 | score = [i[:l1, :l2] for i, l1, l2 in zip(score, k1_lengths, k2_lengths)] 55 | 56 | return o1, o2, w1, w2, score 57 | 58 | class BidirectionalAdditiveAttention(nn.Module): 59 | 60 | def __init__(self, k1_dim, k2_dim, v1_dim, v2_dim, attention_dim): 61 | super().__init__() 62 | self.k1_layer = nn.Linear(k1_dim, attention_dim) 63 | self.k2_layer = nn.Linear(k2_dim, attention_dim) 64 | self.score_layer = nn.Linear(attention_dim, 1) 65 | self.tanh = nn.Tanh() 66 | self.softmax1 = nn.Softmax(dim=-1) 67 | self.softmax2 = nn.Softmax(dim=-1) 68 | 69 | def forward(self, k1, k2, v1, v2, k1_lengths=None, k2_lengths=None): 70 | k1 = self.k1_layer(k1).repeat(k2.shape[1], 1, 1, 1).permute(1,2,0,3) 71 | k2 = self.k2_layer(k2).repeat(k1.shape[1], 1, 1, 1).permute(1,0,2,3) 72 | score = self.score_layer(self.tanh(k1 + k2)).squeeze(-1) 73 | 74 | if k1_lengths or k2_lengths: 75 | mask = torch.zeros(score.shape, dtype=torch.int).detach().to(score.device) 76 | for i, l in enumerate(k1_lengths): 77 | mask[i,l:,:] += 1 78 | for i, l in enumerate(k2_lengths): 79 | mask[i,:,l:] += 1 80 | mask = mask == 1 81 | score = score.masked_fill_(mask, -float('inf')) 82 | 83 | w1 = self.softmax1(score.transpose(1, 2)) 84 | w2 = self.softmax2(score) 85 | 86 | o1 = torch.bmm(w1, v1) 87 | o2 = torch.bmm(w2, v2) 88 | 89 | w1 = [i[:l2, :l1] for i, l1, l2 in zip(w1, k1_lengths, k2_lengths)] 90 | w2 = [i[:l1, :l2] for i, l1, l2 in zip(w2, k1_lengths, k2_lengths)] 91 | score = [i[:l1, :l2] for i, l1, l2 in zip(score, k1_lengths, k2_lengths)] 92 | 93 | return o1, o2, w1, w2, score 94 | -------------------------------------------------------------------------------- /model/graph.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 
from torch import Tensor 6 | from torch.nn import Parameter 7 | from torch.nn import Parameter as Param 8 | from torch_scatter import scatter 9 | from torch_sparse import SparseTensor, masked_select_nnz, matmul 10 | 11 | from torch_geometric.typing import Adj, OptTensor 12 | 13 | from torch_geometric.nn.conv.rgcn_conv import RGCNConv, masked_edge_index 14 | 15 | def masked_edge_weight(edge_weight, edge_mask): 16 | return edge_weight[edge_mask] 17 | 18 | class RGCNConv(RGCNConv): 19 | 20 | def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], edge_index: Adj, edge_type: OptTensor = None, edge_weight = None): 21 | # Convert input features to a pair of node features or node indices. 22 | x_l: OptTensor = None 23 | if isinstance(x, tuple): 24 | x_l = x[0] 25 | else: 26 | x_l = x 27 | if x_l is None: 28 | x_l = torch.arange(self.in_channels_l, device=self.weight.device) 29 | 30 | x_r: Tensor = x_l 31 | if isinstance(x, tuple): 32 | x_r = x[1] 33 | 34 | size = (x_l.size(0), x_r.size(0)) 35 | 36 | if isinstance(edge_index, SparseTensor): 37 | edge_type = edge_index.storage.value() 38 | assert edge_type is not None 39 | 40 | # propagate_type: (x: Tensor) 41 | out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device) 42 | 43 | weight = self.weight 44 | for i in range(self.num_relations): 45 | tmp = masked_edge_index(edge_index, edge_type == i) 46 | tmp2 = masked_edge_weight(edge_weight, edge_type == i) 47 | if tmp2.shape[0] == 0: 48 | continue 49 | 50 | if x_l.dtype == torch.long: 51 | out += self.propagate(tmp, x=weight[i, x_l], size=size, edge_weight=tmp2) 52 | else: 53 | h = self.propagate(tmp, x=x_l, size=size, edge_weight=tmp2) 54 | out = out + (h @ weight[i]) 55 | 56 | root = self.root 57 | if root is not None: 58 | out += root[x_r] if x_r.dtype == torch.long else x_r @ root 59 | 60 | if self.bias is not None: 61 | out += self.bias 62 | 63 | return out 64 | 65 | def message(self, x_j: Tensor, edge_weight) -> Tensor: 66 | return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j 67 | 68 | class RGCNConv_FG(RGCNConv): 69 | 70 | def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]], edge_index: Adj, edge_type: OptTensor = None, edge_weight = None): 71 | # Convert input features to a pair of node features or node indices. 
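        # Unlike RGCNConv above, this fine-grained variant keeps the per-token axis:
        # `out` is (num_nodes, seq_len, out_channels) and message() mixes neighbour
        # tokens with a full attention matrix via torch.bmm instead of a scalar weight.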
72 | x_l: OptTensor = None 73 | if isinstance(x, tuple): 74 | x_l = x[0] 75 | else: 76 | x_l = x 77 | if x_l is None: 78 | x_l = torch.arange(self.in_channels_l, device=self.weight.device) 79 | 80 | x_r: Tensor = x_l 81 | if isinstance(x, tuple): 82 | x_r = x[1] 83 | 84 | size = (x_l.size(0), x_r.size(0)) 85 | 86 | if isinstance(edge_index, SparseTensor): 87 | edge_type = edge_index.storage.value() 88 | assert edge_type is not None 89 | 90 | # propagate_type: (x: Tensor) 91 | out = torch.zeros(x_r.size(0), x_r.size(1), self.out_channels, device=x_r.device) 92 | 93 | weight = self.weight 94 | for i in range(self.num_relations): 95 | tmp = masked_edge_index(edge_index, edge_type == i) 96 | tmp2 = masked_edge_weight(edge_weight, edge_type == i) 97 | if tmp2.shape[0] == 0: 98 | continue 99 | 100 | if x_l.dtype == torch.long: 101 | out += self.propagate(tmp, x=weight[i, x_l], size=size, edge_weight=tmp2) 102 | else: 103 | h = self.propagate(tmp, x=x_l, size=size, edge_weight=tmp2) 104 | out = out + (h @ weight[i]) 105 | 106 | root = self.root 107 | if root is not None: 108 | out += root[x_r] if x_r.dtype == torch.long else x_r @ root 109 | 110 | if self.bias is not None: 111 | out += self.bias 112 | 113 | return out 114 | 115 | def message(self, x_j: Tensor, edge_weight) -> Tensor: 116 | return x_j if edge_weight is None else torch.bmm(edge_weight, x_j) 117 | -------------------------------------------------------------------------------- /model/tacotron.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from torch import nn 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | class Prenet(nn.Module): 7 | 8 | def __init__(self, in_dim, sizes=[256, 128]): 9 | super(Prenet, self).__init__() 10 | in_sizes = [in_dim] + sizes[:-1] 11 | self.layers = nn.ModuleList( 12 | [nn.Linear(in_size, out_size) 13 | for (in_size, out_size) in zip(in_sizes, sizes)]) 14 | self.relu = nn.ReLU() 15 | self.dropout = nn.Dropout(0.5) 16 | 17 | def forward(self, inputs): 18 | for linear in self.layers: 19 | inputs = self.dropout(self.relu(linear(inputs))) 20 | return inputs 21 | 22 | class BatchNormConv1d(nn.Module): 23 | def __init__(self, in_dim, out_dim, kernel_size, stride, padding, 24 | activation=None): 25 | super(BatchNormConv1d, self).__init__() 26 | self.conv1d = nn.Conv1d(in_dim, out_dim, 27 | kernel_size=kernel_size, 28 | stride=stride, padding=padding, bias=False) 29 | self.bn = nn.BatchNorm1d(out_dim) 30 | self.activation = activation 31 | 32 | def forward(self, x): 33 | x = self.conv1d(x) 34 | if self.activation is not None: 35 | x = self.activation(x) 36 | return self.bn(x) 37 | 38 | class Highway(nn.Module): 39 | def __init__(self, in_size, out_size): 40 | super(Highway, self).__init__() 41 | self.H = nn.Linear(in_size, out_size) 42 | self.H.bias.data.zero_() 43 | self.T = nn.Linear(in_size, out_size) 44 | self.T.bias.data.fill_(-1) 45 | self.relu = nn.ReLU() 46 | self.sigmoid = nn.Sigmoid() 47 | 48 | def forward(self, inputs): 49 | H = self.relu(self.H(inputs)) 50 | T = self.sigmoid(self.T(inputs)) 51 | return H * T + inputs * (1.0 - T) 52 | 53 | class CBHG(nn.Module): 54 | """CBHG module: a recurrent neural network composed of: 55 | - 1-d convolution banks 56 | - Highway networks + residual connections 57 | - Bidirectional gated recurrent units 58 | """ 59 | 60 | def __init__(self, in_dim, K=16, projections=[128, 128]): 61 | super(CBHG, self).__init__() 62 | self.in_dim = in_dim 63 | self.relu = nn.ReLU() 
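        # bank of K 1-D convolutions with kernel sizes 1..K; forward() concatenates
        # their outputs along the channel axis before the projection convolutions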
64 | self.conv1d_banks = nn.ModuleList( 65 | [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1, 66 | padding=k // 2, activation=self.relu) 67 | for k in range(1, K + 1)]) 68 | self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) 69 | 70 | in_sizes = [K * in_dim] + projections[:-1] 71 | activations = [self.relu] * (len(projections) - 1) + [None] 72 | self.conv1d_projections = nn.ModuleList( 73 | [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, 74 | padding=1, activation=ac) 75 | for (in_size, out_size, ac) in zip( 76 | in_sizes, projections, activations)]) 77 | 78 | self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False) 79 | self.highways = nn.ModuleList( 80 | [Highway(in_dim, in_dim) for _ in range(4)]) 81 | 82 | self.gru = nn.GRU( 83 | in_dim, in_dim, 1, batch_first=True, bidirectional=True) 84 | 85 | def forward(self, inputs, input_lengths=None): 86 | # (B, T_in, in_dim) 87 | x = inputs 88 | 89 | # Needed to perform conv1d on time-axis 90 | # (B, in_dim, T_in) 91 | if x.size(-1) == self.in_dim: 92 | x = x.transpose(1, 2) 93 | 94 | T = x.size(-1) 95 | 96 | # (B, in_dim*K, T_in) 97 | # Concat conv1d bank outputs 98 | x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1) 99 | assert x.size(1) == self.in_dim * len(self.conv1d_banks) 100 | x = self.max_pool1d(x)[:, :, :T] 101 | 102 | for conv1d in self.conv1d_projections: 103 | x = conv1d(x) 104 | 105 | # (B, T_in, in_dim) 106 | # Back to the original shape 107 | x = x.transpose(1, 2) 108 | 109 | if x.size(-1) != self.in_dim: 110 | x = self.pre_highway(x) 111 | 112 | # Residual connection 113 | x += inputs 114 | for highway in self.highways: 115 | x = highway(x) 116 | 117 | if input_lengths is not None: 118 | x = nn.utils.rnn.pack_padded_sequence( 119 | x, input_lengths, batch_first=True, enforce_sorted=False) 120 | 121 | # (B, T_in, in_dim*2) 122 | outputs, _ = self.gru(x) 123 | 124 | if input_lengths is not None: 125 | outputs, _ = nn.utils.rnn.pad_packed_sequence( 126 | outputs, batch_first=True) 127 | 128 | return outputs 129 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | class Options(dict): 4 | 5 | def __getitem__(self, key): 6 | if not key in self.keys(): 7 | self.__setitem__(key, Options()) 8 | return super().__getitem__(key) 9 | 10 | def __getattr__(self, attr): 11 | if not attr in self.keys(): 12 | self[attr] = Options() 13 | return self[attr] 14 | 15 | def __setattr__(self, attr, value): 16 | self[attr] = value 17 | 18 | def __delattr__(self, attr): 19 | del self[attr] 20 | 21 | def __deepcopy__(self, memo=None): 22 | new = Options() 23 | for key in self.keys(): 24 | new[key] = deepcopy(self[key]) 25 | return new 26 | 27 | baseline = Options() 28 | baseline.max_epochs = 50 29 | baseline.batch_size = 32 30 | baseline.learning_rate = 1e-4 31 | 32 | baseline.input.chunk_size = 6 33 | baseline.input.bert_dim = 768 34 | baseline.input.gst_dim = 40 35 | baseline.input.wst_dim = 40 36 | 37 | baseline.global_encoder.input_dim = baseline.input.bert_dim 38 | baseline.global_encoder.prenet.sizes = [256, 128] 39 | baseline.global_encoder.cbhg.dim = 128 40 | baseline.global_encoder.cbhg.K = 16 41 | baseline.global_encoder.cbhg.projections = [128, 128] 42 | baseline.global_encoder.output_dim = baseline.global_encoder.cbhg.dim * 2 43 | 44 | baseline.dialogue_gcn.length = baseline.input.chunk_size - 1 45 | 
baseline.dialogue_gcn.global_feature_dim = baseline.global_encoder.output_dim + baseline.input.gst_dim 46 | baseline.dialogue_gcn.global_attention.input_dim = baseline.dialogue_gcn.global_feature_dim 47 | baseline.dialogue_gcn.global_attention.dim = 128 48 | baseline.dialogue_gcn.rgcn.dim = 128 49 | baseline.dialogue_gcn.gcn.dim = 128 50 | baseline.dialogue_gcn.output_dim = baseline.dialogue_gcn.gcn.dim + baseline.dialogue_gcn.global_feature_dim 51 | 52 | baseline.global_attention.dim = 128 53 | baseline.global_attention.query_dim = baseline.global_encoder.output_dim + baseline.input.chunk_size 54 | baseline.global_attention.key_dim = baseline.dialogue_gcn.output_dim 55 | baseline.global_attention.value_dim = baseline.dialogue_gcn.output_dim 56 | 57 | baseline.global_linear.input_dim = baseline.global_attention.value_dim + baseline.global_attention.query_dim 58 | baseline.global_linear.output_dim = baseline.input.gst_dim 59 | 60 | baseline.fake_mst.global_linear.input_dim = baseline.input.gst_dim 61 | baseline.fake_mst.global_linear.output_dim = baseline.input.gst_dim 62 | baseline.fake_mst.local_linear.input_dim = baseline.input.gst_dim + baseline.input.bert_dim 63 | baseline.fake_mst.local_linear.output_dim = baseline.input.wst_dim 64 | 65 | baseline_gru = deepcopy(baseline) 66 | baseline_gru.input.sbert_dim = 768 67 | baseline_gru.gru.input_dim = baseline_gru.input.sbert_dim 68 | baseline_gru.gru.dim = 256 69 | baseline_gru.global_linear.input_dim = baseline_gru.gru.dim + baseline_gru.input.sbert_dim + baseline_gru.global_encoder.output_dim + baseline.input.chunk_size 70 | 71 | proposed = deepcopy(baseline) 72 | 73 | proposed.global_encoder.input_dim = proposed.input.bert_dim + proposed.input.wst_dim 74 | proposed.local_encoder = deepcopy(proposed.global_encoder) 75 | 76 | proposed.dialogue_gcn.local_feature_dim = proposed.local_encoder.output_dim 77 | proposed.dialogue_gcn.local_attention.dim = 128 78 | proposed.dialogue_gcn.local_attention.k1_dim = proposed.dialogue_gcn.local_feature_dim 79 | proposed.dialogue_gcn.local_attention.k2_dim = proposed.dialogue_gcn.local_feature_dim 80 | proposed.dialogue_gcn.local_attention.v1_dim = proposed.dialogue_gcn.local_attention.k1_dim 81 | proposed.dialogue_gcn.local_attention.v2_dim = proposed.dialogue_gcn.local_attention.k2_dim 82 | proposed.dialogue_gcn.output_dim = proposed.dialogue_gcn.gcn.dim + proposed.dialogue_gcn.local_feature_dim 83 | proposed.dialogue_gcn.global_output_dim = proposed.dialogue_gcn.gcn.dim + proposed.dialogue_gcn.global_feature_dim 84 | 85 | proposed.post_global_encoder = deepcopy(proposed.global_encoder) 86 | proposed.post_global_encoder.input_dim = proposed.dialogue_gcn.output_dim 87 | 88 | #proposed.global_attention.key_dim = proposed.post_global_encoder.output_dim + proposed.dialogue_gcn.global_output_dim 89 | proposed.global_attention.key_dim = proposed.post_global_encoder.output_dim + proposed.dialogue_gcn.global_feature_dim 90 | proposed.global_attention.value_dim = proposed.global_attention.key_dim 91 | 92 | proposed.local_attention.dim = 128 93 | proposed.local_attention.k1_dim = proposed.local_encoder.output_dim 94 | proposed.local_attention.k2_dim = proposed.dialogue_gcn.output_dim 95 | proposed.local_attention.v1_dim = proposed.local_attention.k1_dim 96 | proposed.local_attention.v2_dim = proposed.local_attention.k2_dim 97 | 98 | proposed.global_linear.input_dim = proposed.global_attention.value_dim + proposed.global_attention.query_dim 99 | 100 | proposed.local_linear.input_dim = 
proposed.local_attention.v2_dim + proposed.local_encoder.output_dim 101 | proposed.local_linear.output_dim = proposed.input.wst_dim 102 | -------------------------------------------------------------------------------- /data/ecc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pathlib import Path 4 | from torch.nn.utils.rnn import pad_sequence 5 | 6 | def sort_key(i): 7 | return [int(k) for k in i.stem.split('.')[0].split('-')[:-1]] 8 | 9 | class Utterance: 10 | 11 | def __init__(self, text): 12 | with open(text) as f: 13 | self.text = f.readline() 14 | self.speaker = text.stem[-1] 15 | self.wav = text.parent / f'{text.stem}.wav' 16 | self.bert = text.parent / f'{text.stem}.bert.npy' 17 | self.sbert = text.parent / f'{text.stem}.sbert.npy' 18 | self.gst = text.parent / f'{text.stem}.gst.npy' 19 | self.wst = text.parent / f'{text.stem}.wst.npy' 20 | self.gst_only = text.parent / f'{text.stem}.gst_only.npy' 21 | 22 | class ECC(torch.utils.data.Dataset): 23 | 24 | def __init__(self, segmented, chunk_size=6): 25 | super().__init__() 26 | self.path = Path(segmented) 27 | self.chunk_size = chunk_size 28 | 29 | texts = sorted([i for i in self.path.rglob('*.txt')], key=sort_key) 30 | 31 | self.conversations = [] 32 | previous = None 33 | for i in texts: 34 | current = sort_key(i)[-1] 35 | 36 | if current - 1 != previous: 37 | self.conversations.append([]) 38 | self.conversations[-1].append(Utterance(i)) 39 | previous = current 40 | 41 | self.conversations = [[j for j in i if j.gst.exists() ] for i in self.conversations] 42 | self.conversations = [i for i in self.conversations if len(i) >= chunk_size] 43 | self.chunks = [i[j:j+chunk_size] for i in self.conversations for j in range(len(i)-chunk_size)] 44 | 45 | def __len__(self): 46 | return len(self.chunks) 47 | 48 | def __getitem__(self, index): 49 | speaker = [] 50 | length = [] 51 | bert = [] 52 | sbert = [] 53 | gst = [] 54 | wst = [] 55 | gst_only = [] 56 | speaker_cache = '' 57 | for i in self.chunks[index]: 58 | if not i.speaker in speaker_cache: 59 | speaker_cache += i.speaker 60 | speaker.append(speaker_cache.find(i.speaker)) 61 | bert.append(torch.as_tensor(np.load(i.bert))) 62 | sbert.append(torch.as_tensor(np.load(i.sbert))) 63 | length.append(bert[-1].shape[0]) 64 | gst.append(torch.as_tensor(np.load(i.gst))) 65 | wst.append(torch.as_tensor(np.load(i.wst))) 66 | gst_only.append(torch.as_tensor(np.load(i.gst_only))) 67 | speaker = np.array(speaker) 68 | length = np.array(length) 69 | bert = pad_sequence(bert, batch_first=True) 70 | sbert = torch.stack(sbert) 71 | gst = torch.stack(gst) 72 | wst = pad_sequence(wst, batch_first=True) 73 | gst_only = torch.stack(gst_only) 74 | return length, speaker, bert, gst, wst, gst_only, sbert 75 | 76 | def process_bert(model, tokenizer, utterance: Utterance): 77 | text = ''.join([i for i in utterance.text.lower() if i in "abcedfghijklmnopqrstuvwxyz' "]) 78 | words = text.split(' ') 79 | words = [i for i in words if i != ''] 80 | inputs = tokenizer(text, return_tensors="pt") 81 | outputs = model(**inputs) 82 | sbert = outputs.pooler_output[0].detach().numpy() 83 | bert = outputs.last_hidden_state[0][1:-1].detach().numpy() 84 | result = [] 85 | start = 0 86 | for word in words: 87 | subwords = tokenizer.tokenize(word) 88 | if len(subwords) > 1: 89 | result.append(np.mean(bert[start:start+len(subwords)], axis=0, keepdims=False)) 90 | elif len(subwords) == 1: 91 | result.append(bert[start]) 92 | start += len(subwords) 93 
| try: 94 | np.save(utterance.bert, np.stack(result)) 95 | np.save(utterance.sbert, sbert) 96 | except: 97 | print(utterance.text, utterance.bert) 98 | 99 | if __name__ == '__main__': 100 | import sys 101 | from functools import partial 102 | from tqdm.contrib.concurrent import process_map, thread_map 103 | 104 | dataset = ECC(sys.argv[1], chunk_size=6) 105 | print(len(dataset.conversations), len(dataset.chunks)) 106 | 107 | if not dataset.chunks[0][0].bert.exists(): 108 | from transformers import AutoTokenizer, AutoModel 109 | model = AutoModel.from_pretrained("bert-base-uncased") 110 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 111 | #process_bert(model, tokenizer, [i for chunk in dataset.chunks for i in chunk][0]) 112 | thread_map(partial(process_bert, model, tokenizer), [i for chunk in dataset.chunks for i in chunk]) 113 | 114 | from .common import Collate 115 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=Collate('cuda:0'), drop_last=True) 116 | for batch in data_loader: 117 | for _list in batch: 118 | print([i.shape for i in _list]) 119 | print([i.dtype for i in _list]) 120 | break 121 | -------------------------------------------------------------------------------- /model/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.nn.conv import GraphConv 4 | 5 | from .tacotron import Prenet, CBHG 6 | from .graph import RGCNConv 7 | from .attention import BahdanauAttention 8 | 9 | def pad_sequence(sequences, **kwargs): 10 | return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, **kwargs) 11 | 12 | class GlobalEncoder(nn.Module): 13 | 14 | def __init__(self, hparams): 15 | super().__init__() 16 | self.prenet = Prenet(hparams.input_dim, sizes=hparams.prenet.sizes) 17 | self.cbhg = CBHG(hparams.cbhg.dim, K=hparams.cbhg.K, projections=hparams.cbhg.projections) 18 | 19 | def forward(self, inputs, input_lengths=None): 20 | x = self.prenet(inputs) 21 | x = self.cbhg(x, input_lengths) 22 | return x 23 | 24 | class DialogueGCN(nn.Module): 25 | 26 | def __init__(self, hparams): 27 | super().__init__() 28 | self.global_attention = BahdanauAttention(hparams.global_attention.input_dim, hparams.global_attention.input_dim, hparams.global_attention.input_dim, hparams.global_attention.dim) 29 | self.rgcn = RGCNConv(hparams.global_feature_dim, hparams.rgcn.dim, 2 * hparams.length ** 2) 30 | self.gcn = GraphConv(hparams.rgcn.dim, hparams.gcn.dim) 31 | 32 | self.edges = [(i, j) for i in range(hparams.length) for j in range(hparams.length)] 33 | edge_types = [[f'{i}{j}0', f'{i}{j}1'] for i in range(hparams.length) for j in range(hparams.length)] 34 | edge_types = [j for i in edge_types for j in i] 35 | self.edge_type_to_id = {} 36 | for i, edge_type in enumerate(edge_types): 37 | self.edge_type_to_id[edge_type] = i 38 | 39 | def forward(self, global_features, speaker): 40 | edges = torch.tensor(self.edges).T.to(global_features.device) 41 | edge_type = [] 42 | for i in range(len(speaker)): 43 | for j in range(len(speaker)): 44 | direction = 0 if i < j else 1 45 | edge_type.append(self.edge_type_to_id[f'{speaker[i]}{speaker[j]}{direction}']) 46 | edge_type = torch.tensor(edge_type).to(global_features.device) 47 | 48 | global_attention_keys = torch.stack([global_features for i in range(len(speaker))]) 49 | _, global_attention_weights = self.global_attention(global_features, global_attention_keys, global_attention_keys) 50 | 
edge_weight = torch.flatten(global_attention_weights) 51 | 52 | x = self.rgcn(global_features, edges, edge_type, edge_weight=edge_weight) 53 | x = self.gcn(x, edges) 54 | return torch.cat([x, global_features], dim=-1) 55 | 56 | class Baseline(nn.Module): 57 | 58 | def __init__(self, hparams): 59 | super().__init__() 60 | self.global_encoder = GlobalEncoder(hparams.global_encoder) 61 | self.gcn = DialogueGCN(hparams.dialogue_gcn) 62 | self.global_attention = BahdanauAttention(hparams.global_attention.query_dim, hparams.global_attention.key_dim, hparams.global_attention.key_dim, hparams.global_attention.dim) 63 | self.global_linear = nn.Linear(hparams.global_linear.input_dim, hparams.global_linear.output_dim) 64 | self.mse = nn.MSELoss() 65 | 66 | def forward(self, length, speaker, bert, history_gst): 67 | batch_size = len(bert) 68 | 69 | global_features = [] 70 | for i in range(batch_size): 71 | length[i] = length[i].cpu() 72 | global_features.append(self.global_encoder(bert[i], length[i])) 73 | global_features[-1] = global_features[-1][range(global_features[-1].shape[0]), (length[i] - 1).long(), :] 74 | current_global_features = [i[-1] for i in global_features] 75 | history_global_features = [torch.cat([i[:-1], j], dim=-1) for i, j in zip(global_features[:], history_gst)] 76 | 77 | for i in range(batch_size): 78 | history_global_features[i] = self.gcn(history_global_features[i], speaker[i][:-1]) 79 | 80 | current_speaker = torch.stack([i[-1] for i in speaker]) 81 | current_speaker = nn.functional.one_hot(current_speaker, num_classes=len(speaker[0])) 82 | history_global_features = torch.stack(history_global_features) 83 | current_global_features = torch.stack(current_global_features) 84 | current_global_features = torch.cat([current_global_features, current_speaker], dim=-1) 85 | context_vector, _ = self.global_attention(current_global_features, history_global_features, history_global_features) 86 | context_vector = torch.cat([current_global_features, context_vector], dim=-1) 87 | current_gst = self.global_linear(context_vector) 88 | return current_gst 89 | 90 | def gst_loss(self, p_gst, gst): 91 | return self.mse(p_gst, gst) 92 | 93 | class FakeMST(nn.Module): 94 | 95 | def __init__(self, hparams): 96 | super().__init__() 97 | 98 | self.global_linear = nn.Linear(hparams.global_linear.input_dim, hparams.global_linear.output_dim) 99 | self.local_linear = nn.Linear(hparams.local_linear.input_dim, hparams.local_linear.output_dim) 100 | self.mse = nn.MSELoss() 101 | 102 | def forward(self, length, bert, gst): 103 | predicted_gst = self.global_linear(gst) 104 | bert = pad_sequence(bert) 105 | gst = torch.tile(torch.unsqueeze(gst, dim=1), (1, bert.shape[1], 1)) 106 | predicted_wst = self.local_linear(torch.cat([bert, gst], dim=-1)) 107 | predicted_wst = [i[:l] for i, l in zip(predicted_wst, length)] 108 | return predicted_gst, predicted_wst 109 | 110 | def gst_loss(self, p_gst, gst): 111 | return self.mse(p_gst, gst) 112 | 113 | def wst_loss(self, p_wst, wst): 114 | p_wst = torch.cat(p_wst, dim=0) 115 | wst = torch.cat(wst, dim=0) 116 | return self.mse(p_wst, wst) 117 | 118 | if __name__ == '__main__': 119 | from data.ecc import ECC 120 | from data.common import Collate 121 | from hparams import baseline 122 | 123 | device = 'cpu' 124 | data_loader = torch.utils.data.DataLoader(ECC('segmented-train'), batch_size=2, shuffle=True, collate_fn=Collate(device)) 125 | 126 | model = Baseline(baseline) 127 | fake = FakeMST(baseline.fake_mst) 128 | model.to(device) 129 | 130 | for batch in data_loader: 
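        # each field is a per-sample list (lengths, speakers, BERT features and style
        # tokens), as produced by ECC.__getitem__ and Collate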
131 | length, speaker, bert, gst, wst, gst_only, sbert = batch 132 | history_gst = [i[:-1] for i in gst_only] 133 | predicted_gst = model(length, speaker, bert, history_gst) 134 | print(predicted_gst.shape) 135 | 136 | current_length = [i[-1] for i in length] 137 | current_wst = [i[-1, :l] for i, l in zip(wst, current_length)] 138 | current_bert = [i[-1] for i in bert] 139 | predicted_gst, predicted_wst = fake(current_length, current_bert, predicted_gst.detach()) 140 | print([i.shape for i in predicted_wst]) 141 | break 142 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | from tqdm import tqdm 6 | from data.ecc import ECC 7 | from data.common import Collate 8 | from save import Save 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu', default=0, type=int) 12 | parser.add_argument('--name', default=None) 13 | parser.add_argument('--load_model', default=None) 14 | parser.add_argument('--train_path', default='segmented-train') 15 | parser.add_argument('--test_path', default='segmented-test') 16 | parser.add_argument('--model', default='proposed', choices=['baseline_gru', 'baseline', 'proposed']) 17 | args = parser.parse_args() 18 | 19 | device = "cuda:%d" % args.gpu 20 | 21 | if args.model == 'baseline': 22 | from hparams import baseline as hparams 23 | from model.baseline import Baseline, FakeMST 24 | model = Baseline(hparams) 25 | fake = FakeMST(hparams.fake_mst) 26 | fake.to(device) 27 | elif args.model == 'baseline_gru': 28 | from hparams import baseline_gru as hparams 29 | from model.baseline_gru import BaselineGRU 30 | from model.baseline import FakeMST 31 | model = BaselineGRU(hparams) 32 | fake = FakeMST(hparams.fake_mst) 33 | fake.to(device) 34 | elif args.model == 'proposed': 35 | from hparams import proposed as hparams 36 | from model.proposed import Proposed 37 | model = Proposed(hparams) 38 | 39 | train_dataset = ECC(args.train_path, chunk_size=hparams.input.chunk_size) 40 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hparams.batch_size, shuffle=True, collate_fn=Collate(device), drop_last=True) 41 | test_dataset = ECC(args.test_path, chunk_size=hparams.input.chunk_size) 42 | test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=hparams.batch_size, shuffle=True, collate_fn=Collate(device)) 43 | 44 | if args.load_model: 45 | #model_dict = model.state_dict() 46 | #state_dict = torch.load(args.load_model) 47 | #state_dict = {k: v for k, v in state_dict.items() if not k.startswith('aligner.')} 48 | #model_dict.update(state_dict) 49 | #model.load_state_dict(model_dict) 50 | model.load_state_dict(torch.load(args.load_model, map_location='cpu')) 51 | 52 | model.to(device) 53 | optimizer = torch.optim.Adam(model.parameters(), lr=hparams.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) 54 | 55 | if args.name is None: 56 | args.name = args.model 57 | else: 58 | args.name = args.model + '_' + args.name 59 | 60 | save = Save(args.name) 61 | save.save_parameters(hparams) 62 | 63 | step = 1 64 | for epoch in range(hparams.max_epochs): 65 | save.logger.info('Epoch %d', epoch) 66 | 67 | batch = 1 68 | for data in train_dataloader: 69 | length, speaker, bert, gst, wst, gst_only, sbert = data 70 | current_gst = [i[-1] for i in gst] 71 | current_gst = torch.stack(current_gst) 72 | current_length = [i[-1] for i 
in length] 73 | current_wst = [i[-1, :l] for i, l in zip(wst, current_length)] 74 | 75 | if args.model == 'baseline': 76 | current_gst_only = [i[-1] for i in gst_only] 77 | current_gst_only = torch.stack(current_gst_only) 78 | current_bert = [i[-1] for i in bert] 79 | history_gst = [i[:-1] for i in gst_only] 80 | 81 | predicted_gst = model(length, speaker, bert, history_gst) 82 | gst_only_loss = model.gst_loss(predicted_gst, current_gst_only) 83 | save.writer.add_scalar(f'training/gst_only_loss', gst_only_loss, step) 84 | predicted_gst, predicted_wst = fake(current_length, current_bert, predicted_gst.detach()) 85 | wst_loss = fake.wst_loss(predicted_wst, current_wst) 86 | if args.model == 'baseline_gru': 87 | current_gst_only = [i[-1] for i in gst_only] 88 | current_gst_only = torch.stack(current_gst_only) 89 | current_bert = [i[-1] for i in bert] 90 | history_gst = [i[:-1] for i in gst_only] 91 | 92 | predicted_gst = model(length, speaker, bert, history_gst, sbert) 93 | gst_only_loss = model.gst_loss(predicted_gst, current_gst_only) 94 | save.writer.add_scalar(f'training/gst_only_loss', gst_only_loss, step) 95 | predicted_gst, predicted_wst = fake(current_length, current_bert, predicted_gst.detach()) 96 | wst_loss = fake.wst_loss(predicted_wst, current_wst) 97 | if args.model == 'proposed': 98 | history_gst = [i[:-1] for i in gst] 99 | history_wst = [i[:-1] for i in wst] 100 | 101 | predicted_gst, predicted_wst = model(length, speaker, bert, history_gst, history_wst) 102 | wst_loss = model.wst_loss(predicted_wst, current_wst) 103 | 104 | gst_loss = model.gst_loss(predicted_gst, current_gst) 105 | save.writer.add_scalar(f'training/gst_loss', gst_loss, step) 106 | save.writer.add_scalar(f'training/wst_loss', wst_loss, step) 107 | loss = gst_loss + wst_loss 108 | save.writer.add_scalar(f'training/loss', loss, step) 109 | 110 | save.save_log('training', epoch, batch, step, loss) 111 | 112 | optimizer.zero_grad() 113 | if args.model == 'baseline': 114 | gst_only_loss.backward() 115 | loss.backward() 116 | optimizer.step() 117 | 118 | step += 1 119 | batch += 1 120 | 121 | save.save_model(model, f'epoch{epoch}') 122 | if args.model in ['baseline', 'baseline_gru']: 123 | save.save_model(fake, f'fake_epoch{epoch}') 124 | 125 | with torch.no_grad(): 126 | predicted_gst = [] 127 | predicted_wst = [] 128 | predicted_gst_only = [] 129 | current_gst = [] 130 | current_wst = [] 131 | current_gst_only = [] 132 | for data in tqdm(test_dataloader): 133 | length, speaker, bert, gst, wst, gst_only, sbert = data 134 | current_gst += [i[-1] for i in gst] 135 | current_length = [i[-1] for i in length] 136 | current_wst += [i[-1, :l] for i, l in zip(wst, current_length)] 137 | 138 | if args.model == 'baseline': 139 | current_gst_only += [i[-1] for i in gst_only] 140 | history_gst_only = [i[:-1] for i in gst_only] 141 | current_bert = [i[-1] for i in bert] 142 | 143 | predicted_gst_only.append(model(length, speaker, bert, history_gst_only)) 144 | _predicted_gst, _predicted_wst = fake(current_length, current_bert, predicted_gst_only[-1].detach()) 145 | if args.model == 'baseline_gru': 146 | current_gst_only += [i[-1] for i in gst_only] 147 | history_gst_only = [i[:-1] for i in gst_only] 148 | current_bert = [i[-1] for i in bert] 149 | 150 | predicted_gst_only.append(model(length, speaker, bert, history_gst_only, sbert)) 151 | _predicted_gst, _predicted_wst = fake(current_length, current_bert, predicted_gst_only[-1].detach()) 152 | if args.model == 'proposed': 153 | history_gst = [i[:-1] for i in gst] 154 | 
history_wst = [i[:-1] for i in wst] 155 | 156 | _predicted_gst, _predicted_wst = model(length, speaker, bert, history_gst, history_wst) 157 | 158 | predicted_gst.append(_predicted_gst) 159 | predicted_wst += _predicted_wst 160 | 161 | if args.model in ['baseline', 'baseline_gru']: 162 | current_gst_only = torch.stack(current_gst_only) 163 | predicted_gst_only = torch.cat(predicted_gst_only, dim=0) 164 | gst_only_loss = model.gst_loss(predicted_gst_only, current_gst_only) 165 | save.writer.add_scalar(f'test/gst_only_loss', gst_only_loss, epoch) 166 | wst_loss = fake.wst_loss(predicted_wst, current_wst) 167 | if args.model == 'proposed': 168 | wst_loss = model.wst_loss(predicted_wst, current_wst) 169 | 170 | current_gst = torch.stack(current_gst) 171 | predicted_gst = torch.cat(predicted_gst, dim=0) 172 | gst_loss = model.gst_loss(predicted_gst, current_gst) 173 | save.writer.add_scalar(f'test/gst_loss', gst_loss, epoch) 174 | save.writer.add_scalar(f'test/wst_loss', wst_loss, epoch) 175 | 176 | loss = gst_loss + wst_loss 177 | save.save_log('test', epoch, batch, epoch, loss) 178 | -------------------------------------------------------------------------------- /model/proposed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from .tacotron import Prenet, CBHG 5 | from .graph import RGCNConv_FG 6 | from .attention import BahdanauAttention, BidirectionalAttention 7 | from .baseline import pad_sequence, GlobalEncoder 8 | 9 | def pad_attention_weights(attention_weights, length): 10 | result = [[torch.cat([j, torch.zeros((length - j.shape[0], j.shape[1]), device=j.device)], dim=0) for j in i] for i in attention_weights] 11 | result = [[torch.cat([j, torch.zeros((j.shape[0], length - j.shape[1]), device=j.device)], dim=1) for j in i] for i in result] 12 | result = torch.cat([torch.stack(i) for i in result]) 13 | return result 14 | 15 | class LocalAttention(BidirectionalAttention): 16 | 17 | def __init__(self, k1_dim, k2_dim, v1_dim, v2_dim, attention_dim): 18 | super().__init__(k1_dim, k2_dim, v1_dim, v2_dim, attention_dim) 19 | del(self.k2_layer) 20 | self.k2_layer = self.k1_layer 21 | 22 | class DialogueGCN_FG(nn.Module): 23 | 24 | def __init__(self, hparams): 25 | super().__init__() 26 | self.global_attention = BahdanauAttention(hparams.global_attention.input_dim, hparams.global_attention.input_dim, hparams.global_attention.input_dim, hparams.global_attention.dim) 27 | self.local_attention = LocalAttention(hparams.local_attention.k1_dim, hparams.local_attention.k2_dim, hparams.local_attention.v1_dim, hparams.local_attention.v2_dim, hparams.local_attention.dim) 28 | self.rgcn = RGCNConv_FG(hparams.local_feature_dim, hparams.rgcn.dim, 2 * hparams.length ** 2) 29 | self.gcn = RGCNConv_FG(hparams.rgcn.dim, hparams.gcn.dim, 1) 30 | 31 | self.edges = [(i, j) for i in range(hparams.length) for j in range(hparams.length)] 32 | edge_types = [[f'{i}{j}0', f'{i}{j}1'] for i in range(hparams.length) for j in range(hparams.length)] 33 | edge_types = [j for i in edge_types for j in i] 34 | self.edge_type_to_id = {} 35 | for i, edge_type in enumerate(edge_types): 36 | self.edge_type_to_id[edge_type] = i 37 | 38 | def forward(self, global_features, local_features, speaker, length): 39 | edges = torch.tensor(self.edges).T.to(global_features.device) 40 | edge_type = [] 41 | for i in range(len(speaker)): 42 | for j in range(len(speaker)): 43 | direction = 0 if i < j else 1 44 | 
edge_type.append(self.edge_type_to_id[f'{speaker[i]}{speaker[j]}{direction}']) 45 | edge_type = torch.tensor(edge_type).to(global_features.device) 46 | 47 | global_attention_keys = torch.stack([global_features for i in range(len(speaker))]) 48 | _, global_attention_weights = self.global_attention(global_features, global_attention_keys, global_attention_keys) 49 | global_attention_weights = torch.flatten(global_attention_weights) 50 | 51 | local_attention_weights = [] 52 | for i in range(len(speaker)): 53 | local_attention_k1 = torch.stack([local_features[i] for j in range(len(speaker))]) 54 | local_attention_k2 = torch.stack([local_features[j] for j in range(len(speaker))]) 55 | local_attention_k1_length = torch.stack([length[i] for j in range(len(speaker))]) 56 | local_attention_k2_length = torch.stack([length[j] for j in range(len(speaker))]) 57 | _, _, w1, w2, _ = self.local_attention(local_attention_k1, local_attention_k2, local_attention_k1, local_attention_k2, k1_lengths=local_attention_k1_length, k2_lengths=local_attention_k2_length) 58 | local_attention_weights.append(w1) 59 | #local_attention_weights = torch.cat(local_attention_weights) 60 | local_attention_weights = pad_attention_weights(local_attention_weights, local_features.shape[1]) 61 | 62 | edge_weight = torch.stack([global_attention_weights[i] * local_attention_weights[i] for i in range(len(self.edges))]) 63 | 64 | x = self.rgcn(local_features, edges, edge_type, edge_weight=edge_weight) 65 | 66 | edge_type = torch.zeros(edge_type.shape, device=edge_type.device) 67 | edge_weight = local_attention_weights 68 | x = self.gcn(x, edges, edge_type, edge_weight=edge_weight) 69 | 70 | return x 71 | 72 | class Proposed(nn.Module): 73 | 74 | def __init__(self, hparams): 75 | super().__init__() 76 | self.global_encoder = GlobalEncoder(hparams.global_encoder) 77 | self.local_encoder = GlobalEncoder(hparams.local_encoder) 78 | self.gcn = DialogueGCN_FG(hparams.dialogue_gcn) 79 | self.post_global_encoder = GlobalEncoder(hparams.post_global_encoder) 80 | self.global_attention = BahdanauAttention(hparams.global_attention.query_dim, hparams.global_attention.key_dim, hparams.global_attention.key_dim, hparams.global_attention.dim) 81 | self.local_attention = BidirectionalAttention(hparams.local_attention.k1_dim, hparams.local_attention.k2_dim, hparams.local_attention.v1_dim, hparams.local_attention.v2_dim, hparams.local_attention.dim) 82 | self.global_linear = nn.Linear(hparams.global_linear.input_dim, hparams.global_linear.output_dim) 83 | self.local_linear = nn.Linear(hparams.local_linear.input_dim, hparams.local_linear.output_dim) 84 | self.mse = nn.MSELoss() 85 | 86 | def forward(self, length, speaker, bert, history_gst, history_wst): 87 | local_features = [] 88 | global_features = [] 89 | batch_size = len(bert) 90 | for i in range(batch_size): 91 | length[i] = length[i].cpu() 92 | features = torch.cat([history_wst[i], torch.zeros((1, ) + history_wst[i].shape[1:], device=history_wst[i].device)]) 93 | features = torch.cat([features, bert[i]], dim=-1) 94 | local_features.append(self.local_encoder(features, length[i])) 95 | global_features.append(self.global_encoder(features, length[i])) 96 | global_features[-1] = global_features[-1][range(global_features[-1].shape[0]), (length[i] - 1).long(), :] 97 | 98 | current_global_features = [i[-1] for i in global_features] 99 | history_global_features = [torch.cat([i[:-1], j], dim=-1) for i, j in zip(global_features[:], history_gst)] 100 | current_local_features = [i[-1] for i in local_features] 101 
| history_local_features = [i[:-1] for i in local_features] 102 | 103 | for i in range(batch_size): 104 | tmp = self.gcn(history_global_features[i], history_local_features[i], speaker[i][:-1], length[i]) 105 | history_local_features[i] = torch.cat([history_local_features[i], tmp], dim=-1) 106 | 107 | for i in range(batch_size): 108 | tmp = self.post_global_encoder(history_local_features[i], length[i][:-1]) 109 | tmp = tmp[range(tmp.shape[0]), (length[i][:-1] - 1).long(), :] 110 | history_global_features[i] = torch.cat([history_global_features[i], tmp], dim=-1) 111 | 112 | current_speaker = torch.stack([i[-1] for i in speaker]) 113 | current_speaker = nn.functional.one_hot(current_speaker, num_classes=len(speaker[0])) 114 | history_global_features = torch.stack(history_global_features) 115 | current_global_features = torch.stack(current_global_features) 116 | current_global_features = torch.cat([current_global_features, current_speaker], dim=-1) 117 | global_context_vector, global_attention_weights = self.global_attention(current_global_features, history_global_features, history_global_features) 118 | global_context_vector = torch.cat([current_global_features, global_context_vector], dim=-1) 119 | current_gst = self.global_linear(global_context_vector) 120 | #print(global_context_vector.shape, global_attention_weights.shape, current_gst.shape) 121 | 122 | local_attention_weights = [] 123 | for i in range(batch_size): 124 | local_attention_k1 = torch.stack([current_local_features[i] for j in range(len(history_local_features[i]))]) 125 | local_attention_k2 = history_local_features[i] 126 | local_attention_k1_length = torch.stack([length[i][-1] for j in range(len(history_local_features[i]))]) 127 | local_attention_k2_length = length[i][:-1] 128 | _, _, w1, w2, _ = self.local_attention(local_attention_k1, local_attention_k2, local_attention_k1, local_attention_k2, k1_lengths=local_attention_k1_length, k2_lengths=local_attention_k2_length) 129 | local_attention_weights.append(pad_attention_weights([w1], history_local_features[i].shape[1])) 130 | #print([i.shape for i in local_attention_weights]) 131 | #local_attention_weights = [pad_attention_weights(i, ) 132 | 133 | attention_weights = [] 134 | for i in range(batch_size): 135 | attention_weights.append(torch.stack([global_attention_weights[i][j] * local_attention_weights[i][j] for j in range(len(global_attention_weights[i]))])) 136 | #print([i.shape for i in attention_weights]) 137 | 138 | local_context_vector = [torch.sum(torch.bmm(attention_weights[i], history_local_features[i]), dim=0) for i in range(batch_size)] 139 | #print([i.shape for i in local_context_vector]) 140 | 141 | local_context_vector = [torch.cat([local_context_vector[i], current_local_features[i]], dim=-1) for i in range(batch_size)] 142 | #print([i.shape for i in local_context_vector]) 143 | 144 | local_context_vector = pad_sequence(local_context_vector) 145 | current_wst = self.local_linear(local_context_vector) 146 | 147 | current_wst = [current_wst[i, :length[i][-1]] for i in range(batch_size)] 148 | #print([i.shape for i in current_wst]) 149 | 150 | return current_gst, current_wst 151 | 152 | def gst_loss(self, p_gst, gst): 153 | return self.mse(p_gst, gst) 154 | 155 | def wst_loss(self, p_wst, wst): 156 | p_wst = torch.cat(p_wst, dim=0) 157 | wst = torch.cat(wst, dim=0) 158 | return self.mse(p_wst, wst) 159 | 160 | if __name__ == '__main__': 161 | from data.ecc import ECC 162 | from data.common import Collate 163 | from hparams import proposed 164 | 165 | device = 'cpu' 
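    # smoke test: push a single 'segmented-train' batch through the model and print both losses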
166 | data_loader = torch.utils.data.DataLoader(ECC('segmented-train'), batch_size=2, shuffle=True, collate_fn=Collate(device)) 167 | 168 | model = Proposed(proposed) 169 | model.to(device) 170 | 171 | for batch in data_loader: 172 | length, speaker, bert, gst, wst, _, _ = batch 173 | history_gst = [i[:-1] for i in gst] 174 | history_wst = [i[:-1] for i in wst] 175 | current_gst = [i[-1] for i in gst] 176 | current_gst = torch.stack(current_gst) 177 | current_length = [i[-1] for i in length] 178 | current_wst = [i[-1, :l] for i, l in zip(wst, current_length)] 179 | 180 | predicted_gst, predicted_wst = model(length, speaker, bert, history_gst, history_wst) 181 | print(model.gst_loss(predicted_gst, current_gst)) 182 | print(model.wst_loss(predicted_wst, current_wst)) 183 | break 184 | --------------------------------------------------------------------------------