"""LearnableUpsamplingLayer-PyTorch.

PyTorch implementation of the learnable upsampling layer from NaturalSpeech
(Tan et al., 2022).

Usage::

    # y             : phoneme hidden sequence             [N, C, T]
    # duration_pred : phoneme durations                   [N, T]
    # src_mask      : mask of the phoneme hidden sequence [N, 1, T]
    #                 (non-zero / True marks a valid phoneme)

    from model import LearnableUpsamplingLayer

    lu = LearnableUpsamplingLayer(in_channels, out_channels)
    y, mel_mask = lu(y, duration_pred, src_mask)
    # mel_mask is a frame-level padding mask [N, T_mel] (True = padded frame)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from collections import OrderedDict

# Kept for code that imports it from this module. The layers below no longer
# use it: they create tensors on the device of their inputs instead.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class LearnableUpsamplingLayer(torch.nn.Module):
    """Upsample a phoneme-level hidden sequence to frame level.

    For every output frame, attention weights ``W`` over all phonemes are
    predicted from the frame's offset to each token's start (``S``) and end
    (``E``), concatenated with a convolutional summary of the phonemes.
    ``W`` aggregates both the raw hidden sequence ``H`` and auxiliary content
    features ``C``; the two results are projected to ``out_channels`` and
    summed.

    Args:
        in_channels: channels of the phoneme hidden sequence ``H``.
        out_channels: channels of the upsampled output.
        attn_dims: number of attention maps in ``W``.
        content_dims: channels of the content features ``C``.
    """

    def __init__(self, in_channels=192, out_channels=192, attn_dims=4, content_dims=2):
        super().__init__()

        conv_dims = 8
        grid_dims = 2 + conv_dims  # S, E and the conv features, concatenated in forward()

        # NOTE: submodule names (including the "linaer_1" misspelling) are
        # state_dict keys, and the registration order fixes parameter init
        # order; changing either would break existing checkpoints/seeds.
        self.proj_in = nn.Sequential(OrderedDict([
            ("linear", nn.Linear(in_channels, in_channels)),
            ("swish", nn.SiLU()),
        ]))

        self.conv_norm_swish = nn.Sequential(OrderedDict([
            ("conv_ln", ConvLN(in_channels, conv_dims, kernel_size=3, padding=1)),
            ("swish_2", nn.SiLU()),
        ]))

        # Softmax over dim 2, the phoneme axis of the [b, t_mel, t_text, attn_dims]
        # logits. forward() applies the linear part and the softmax separately
        # so padded phonemes can be masked in between.
        self.W_proj = nn.Sequential(OrderedDict([
            ("linaer_1", nn.Linear(grid_dims, grid_dims)),
            ("swish", nn.SiLU()),
            ("linear_2", nn.Linear(grid_dims, attn_dims)),
            ("softmax", nn.Softmax(dim=2)),
        ]))

        self.C_proj = nn.Sequential(OrderedDict([
            ("linaer_1", nn.Linear(grid_dims, grid_dims)),
            ("swish_1", nn.SiLU()),
            ("linear_2", nn.Linear(grid_dims, content_dims)),
            ("swish_2", nn.SiLU()),
        ]))

        self.proj_WH = nn.Linear(attn_dims * in_channels, out_channels)
        self.proj_WC = nn.Linear(attn_dims * content_dims, out_channels)

        self.proj_out = nn.Conv1d(out_channels, out_channels, 1)

    def forward(self, H, d, src_mask):
        """Upsample ``H`` according to durations ``d``.

        Args:
            H: phoneme hidden sequence, ``[b, in_channels, t_text]``.
            d: phoneme durations in frames, ``[b, t_text]``.
            src_mask: phoneme mask, ``[b, 1, t_text]`` or ``[b, t_text]``;
                non-zero / True marks a valid phoneme.

        Returns:
            ``(O, mel_pad_mask)`` where ``O`` is ``[b, out_channels, t_mel]``
            (zeroed on padded frames) and ``mel_pad_mask`` is ``[b, t_mel]``
            with True on padded frames. ``t_mel`` is the largest per-item
            ``sum(d)`` (truncated to an integer).
        """
        if src_mask.dim() == 2:
            src_mask = src_mask.unsqueeze(1)

        # S, E: [b, t_mel, t_text, 1]; mel_valid: [b, 1, t_mel], True = real frame.
        S, E, mel_valid = self.token_boundary_grid(d, src_mask)
        b, t_mel = src_mask.shape[0], mel_valid.shape[-1]

        x = self.proj_in(H.transpose(1, 2)).transpose(1, 2)  # [b, in_channels, t_text]
        # proj_in maps padded columns to silu(bias) != 0; zero them so the
        # width-3 conv cannot leak padding into the last valid phoneme.
        x = x * src_mask
        x = self.conv_norm_swish(x)  # [b, 8, t_text]

        # Broadcast the per-phoneme features to every frame (no copy until cat).
        x = x.transpose(1, 2).unsqueeze(1).expand(-1, t_mel, -1, -1)  # [b, t_mel, t_text, 8]
        x = torch.cat((S.to(x.dtype), E.to(x.dtype), x), dim=3)  # [b, t_mel, t_text, 10]

        # Masked softmax over phonemes: padded phonemes get (effectively) zero
        # weight. finfo.min instead of -inf keeps an all-padded row NaN-free.
        phone_valid = src_mask.unsqueeze(-1).bool()  # [b, 1, t_text, 1]
        logits = self.W_proj[:-1](x)  # [b, t_mel, t_text, attn_dims]
        logits = logits.masked_fill(~phone_valid, torch.finfo(logits.dtype).min)
        W = self.W_proj.softmax(logits).permute(0, 3, 1, 2)  # [b, attn_dims, t_mel, t_text]

        C = self.C_proj(x)  # [b, t_mel, t_text, content_dims]

        # einsum output may be non-contiguous, so reshape (not view).
        WC = torch.einsum('bqmn,bmnp->bmqp', W, C).reshape(b, t_mel, -1)  # [b, t_mel, attn*content]
        WC = self.proj_WC(WC)  # [b, t_mel, out_channels]

        WH = torch.einsum('bqmn,bhn->bmqh', W, H).reshape(b, t_mel, -1)  # [b, t_mel, attn*in]
        WH = self.proj_WH(WH)  # [b, t_mel, out_channels]

        O = (WC + WH).transpose(1, 2)  # [b, out_channels, t_mel]
        O = self.proj_out(O) * mel_valid

        # squeeze(1) (not squeeze()) so the batch axis survives when b == 1.
        return O, ~mel_valid.squeeze(1)

    def token_boundary_grid(self, dur, src_mask):
        """Build the start/end offset grids between frames and tokens.

        Args:
            dur: durations ``[b, t_text]``.
            src_mask: phoneme mask ``[b, 1, t_text]`` (non-zero = valid).

        Returns:
            ``(S, E, mel_valid)``: ``S[b, f, j, 0] = (f + 1) - start_j`` and
            ``E[b, f, j, 0] = end_j - (f + 1)`` for valid frame/phoneme pairs
            (zero elsewhere), each ``[b, t_mel, t_text, 1]``; ``mel_valid`` is
            ``[b, 1, t_mel]`` with True on real frames.
        """
        mel_len = torch.sum(dur, 1).long()  # [b]
        mel_valid = ~self.get_mask_from_lengths(mel_len).unsqueeze(1)  # [b, 1, t_mel]
        t_mel = mel_valid.shape[-1]

        # [b, 1, 1, t_text] * [b, 1, t_mel, 1] -> [b, 1, t_mel, t_text]; only
        # drop dim 1 so size-1 batch/text/mel axes are preserved.
        token_boundary_mask = (src_mask.unsqueeze(2) * mel_valid.unsqueeze(-1)).squeeze(1)

        i = torch.arange(1, t_mel + 1, device=dur.device).view(1, t_mel, 1)  # 1-based frame index

        E_d = torch.cumsum(dur, dim=1)  # end of each token, [b, t_text]
        S_d = E_d - dur  # exclusive cumsum = start of each token

        S = (i - S_d.unsqueeze(1)) * token_boundary_mask  # [b, t_mel, t_text]
        E = (E_d.unsqueeze(1) - i) * token_boundary_mask

        return S.unsqueeze(-1), E.unsqueeze(-1), mel_valid

    def get_mask_from_lengths(self, lengths, max_len=None):
        """Return a ``[b, max_len]`` padding mask (True where index >= length).

        The mask is created on ``lengths.device``.
        """
        if max_len is None:
            max_len = torch.max(lengths).item()

        ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0)  # [1, max_len]
        return ids >= lengths.unsqueeze(1)


class ConvLN(nn.Module):
    """Conv1d followed by LayerNorm over the channel axis.

    Args:
        in_channels: input channels.
        out_channels: output channels of the convolution.
        kernel_size: convolution kernel size.
        ln_channels: normalized channel count; defaults to ``out_channels``
            (the layer norm runs on the conv output, so they must match).
        eps: LayerNorm epsilon.
        padding: convolution padding (an int; ``F.conv1d`` rejects floats).
    """

    def __init__(self, in_channels, out_channels, kernel_size, ln_channels=None, eps=1e-5, padding=1):
        super().__init__()
        if ln_channels is None:
            ln_channels = out_channels
        self.channels = ln_channels
        self.eps = eps
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)

        self.gamma = nn.Parameter(torch.ones(ln_channels))
        self.beta = nn.Parameter(torch.zeros(ln_channels))

    def forward(self, x):
        """``x``: ``[b, in_channels, t]`` -> ``[b, out_channels, t']``."""
        x = self.conv(x)
        x = x.transpose(1, -1)  # move channels last for F.layer_norm
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)