├── cover
    ├── __init__.py
    ├── utils
    │   ├── __init__.py
    │   ├── audio.py
    │   ├── noise.py
    │   ├── utils.py
    │   └── spec.py
    ├── models
    │   ├── __init__.py
    │   ├── unet.py
    │   ├── pitch.py
    │   └── generate.py
    ├── load.py
    └── run.py
├── image
    ├── logo.png
    ├── net1.png
    ├── net2.png
    ├── tnn.png
    ├── model.png
    └── website.png
├── python
    ├── models
    │   ├── __init__.py
    │   ├── load.py
    │   ├── chord_net.py
    │   ├── beat_net.py
    │   ├── segment_net.py
    │   ├── unet.py
    │   ├── pitch_net.py
    │   └── transformers.py
    ├── audio.py
    ├── utils.py
    ├── spec.py
    ├── common.py
    ├── aitabs.py
    └── modulation.py
└── README.md

/cover/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cover/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/image/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/logo.png
--------------------------------------------------------------------------------
/image/net1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/net1.png
--------------------------------------------------------------------------------
/image/net2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/net2.png
--------------------------------------------------------------------------------
/image/tnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/tnn.png
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/model.png
--------------------------------------------------------------------------------
/image/website.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/website.png
--------------------------------------------------------------------------------
/cover/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .pitch import PitchNet
2 | from .generate import CombineNet
3 | from .unet import UNets
4 | 
--------------------------------------------------------------------------------
/python/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .beat_net import BeatNet
2 | from .unet import UNets
3 | from .chord_net import ChordNet
4 | from .pitch_net import PitchNet
5 | from .segment_net import SegmentNet
6 | from .load import get_model
7 | 
--------------------------------------------------------------------------------
/python/audio.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | import torch
3 | 
4 | 
5 | def read_wav(wav_fp, sample_rate=44100, n_channel=2, device='cpu'):
6 |     waveform, sr = torchaudio.load(wav_fp)
7 |     assert waveform.ndim == 2
8 |     assert waveform.shape[0] == n_channel
9 |     if sr != sample_rate:
10 |         waveform = torchaudio.transforms.Resample(sr, 
sample_rate)(waveform) 11 | waveform = waveform.to(device) 12 | return waveform, sample_rate 13 | 14 | 15 | def gen_wav(sample_rate=44100, n_channel=2, duration=120, device='cpu'): 16 | waveform = torch.randn(n_channel, sample_rate * duration) 17 | waveform = waveform.to(device) 18 | return waveform, sample_rate 19 | 20 | 21 | def write_wav(path, wav, sample_rate=44100): 22 | torchaudio.save(path, wav, sample_rate) 23 | -------------------------------------------------------------------------------- /cover/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchaudio.pipelines import (HUBERT_BASE, HUBERT_LARGE) 3 | 4 | 5 | def get_model(model, config, model_path=None, is_train=True, device='cpu'): 6 | net = model(**config) 7 | if model_path: 8 | net.load_state_dict(torch.load(model_path, map_location=device)) 9 | net.to(device) 10 | 11 | if is_train: 12 | net.train() 13 | else: 14 | net.eval() 15 | 16 | return net 17 | 18 | 19 | def get_hubert_model(name='base', dl_kwargs=None, device='cpu'): 20 | if name == 'base': 21 | bundle = HUBERT_BASE 22 | elif name == 'large': 23 | bundle = HUBERT_LARGE 24 | else: 25 | raise ValueError(f'Invalid model name: {name}') 26 | 27 | model = bundle.get_model(dl_kwargs=dl_kwargs) 28 | model.to(device) 29 | model.eval() 30 | return model 31 | -------------------------------------------------------------------------------- /python/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from common import CHORD_LABELS, SEGMENT_LABELS 3 | 4 | 5 | def build_masked_stft(masks, stft_feature, n_fft=4096): 6 | out = [] 7 | for i in range(len(masks)): 8 | mask = masks[i, :, :, :] 9 | pad_num = n_fft // 2 + 1 - mask.size(-1) 10 | mask = F.pad(mask, (0, pad_num, 0, 0, 0, 0)) 11 | inst_stft = mask.type(stft_feature.dtype) * stft_feature 12 | out.append(inst_stft) 13 | return out 14 | 15 | 16 | def get_chord_name(chord_idx_list): 17 | chords = [CHORD_LABELS[idx] for idx in chord_idx_list] 18 | return chords 19 | 20 | 21 | def get_segment_name(segments): 22 | segments = [SEGMENT_LABELS[idx] for idx in segments] 23 | return segments 24 | 25 | 26 | def get_lyrics(waveform, sr, cfg): 27 | # asr and wav2vec2 28 | raise NotImplementedError() 29 | -------------------------------------------------------------------------------- /python/models/load.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models.unet import UNets 4 | from models.beat_net import BeatNet 5 | from models.chord_net import ChordNet 6 | from models.pitch_net import PitchNet 7 | from models.segment_net import SegmentNet 8 | 9 | 10 | def get_model_cls(s): 11 | if s == 'unet': 12 | return UNets 13 | elif s == 'beat': 14 | return BeatNet 15 | elif s == 'chord': 16 | return ChordNet 17 | elif s == 'pitch': 18 | return PitchNet 19 | elif s == 'segment': 20 | return SegmentNet 21 | else: 22 | raise ValueError(f'Invalid model name: {s}') 23 | 24 | 25 | def get_model(model, config, model_path=None, is_train=True, device='cpu'): 26 | if isinstance(model, str): 27 | model = get_model_cls(model) 28 | 29 | net = model(**config) 30 | if model_path: 31 | net.load_state_dict(torch.load(model_path, map_location=device)) 32 | net.to(device) 33 | 34 | if is_train: 35 | net.train() 36 | else: 37 | net.eval() 38 | 39 | return net 40 | -------------------------------------------------------------------------------- 
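(Illustrative usage sketch, not a file in the repository: how `get_model` above might be used to build one of the networks for inference. The config keys mirror the `ChordNet` defaults from `python/models/chord_net.py`; a real checkpoint path would be passed as `model_path`.)

    import torch
    from models import get_model  # with python/ as the working directory, as in aitabs.py

    # Build ChordNet by name; pass model_path='...' to load trained weights.
    chord_cfg = {'n_freq': 2048, 'n_classes': 122, 'n_group': 32, 'dropout': 0.5}
    chord_net = get_model('chord', chord_cfg, model_path=None,
                          is_train=False, device='cpu')

    # Dummy magnitude-spectrogram batch: (batch, channel, time, freq)
    x = torch.randn(1, 2, 256, 2048)
    with torch.no_grad():
        chords, weight_logits = chord_net(x)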
/cover/utils/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | 4 | 5 | def load_waveform(wav_fp, samplerate=16000, n_channel=2, device='cpu'): 6 | waveform, sr = torchaudio.load(wav_fp) 7 | assert waveform.ndim == 2 8 | if waveform.shape[0] != n_channel: 9 | if n_channel == 1: 10 | waveform = torch.mean(waveform, dim=0, keepdim=True) 11 | elif n_channel <= waveform.shape[0]: 12 | waveform = waveform[:n_channel] 13 | else: 14 | raise ValueError(f'Invalid number of channels: {waveform.shape[0]}') 15 | if sr != samplerate: 16 | waveform = torchaudio.transforms.Resample(sr, samplerate)(waveform) 17 | waveform = waveform.to(device) 18 | return waveform, samplerate 19 | 20 | 21 | def gen_waveform(samplerate=16000, n_channel=2, duration=120, device='cpu'): 22 | waveform = torch.randn(n_channel, samplerate * duration) 23 | waveform = waveform.to(device) 24 | return waveform, samplerate 25 | 26 | 27 | def save_waveform(path, wav, samplerate=16000): 28 | torchaudio.save(path, wav, samplerate) 29 | -------------------------------------------------------------------------------- /cover/utils/noise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def gaussian_noise(x, mean=0, std=0.1): 5 | noise = torch.normal(mean, std, x.size()).to(x.device) 6 | noisy_x = x + noise 7 | return noisy_x 8 | 9 | 10 | def markov_noise(x, transition_matrix): 11 | num_bins, num_frames = x.shape 12 | noisy_x = x.clone() 13 | state = torch.randint(0, num_bins, (1,)).item() 14 | 15 | for t in range(num_frames): 16 | noisy_x[:, t] = x[:, t] * (1 + transition_matrix[state]) 17 | state = torch.multinomial(transition_matrix[state], 1).item() 18 | 19 | return noisy_x 20 | 21 | 22 | def random_walk_noise(x, step_size=0.1): 23 | noisy_x = x.clone() 24 | random_walk = torch.FloatTensor(x.size()).uniform_(-step_size, step_size).to(x.device) 25 | noisy_x += random_walk 26 | return noisy_x 27 | 28 | 29 | def spectral_folding(x, fold_frequency=0.5): 30 | num_bins, num_frames = x.shape 31 | fold_bin = int(fold_frequency * num_bins) 32 | 33 | folded_x = x.clone() 34 | folded_x[:fold_bin, :] += x[-fold_bin:, :] 35 | folded_x[-fold_bin:, :] = 0 36 | 37 | return folded_x 38 | -------------------------------------------------------------------------------- /cover/utils/utils.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | 4 | def generate_uuid(version=4, namespace=None, name=None): 5 | """ 6 | Generate a UUID string based on the specified version. 7 | 8 | Parameters: 9 | - version (int): The UUID version (1, 3, 4, or 5). Default is 4. 10 | - namespace (str): The namespace for UUID3 and UUID5. Must be a valid UUID string. 11 | - name (str): The name for UUID3 and UUID5. 12 | 13 | Returns: 14 | - str: The generated UUID string. 
15 | """ 16 | if version == 1: 17 | return str(uuid.uuid1()) 18 | elif version == 3: 19 | if namespace is None or name is None: 20 | raise ValueError("Namespace and name must be provided for UUID3.") 21 | namespace_uuid = uuid.UUID(namespace) 22 | return str(uuid.uuid3(namespace_uuid, name)) 23 | elif version == 4: 24 | return str(uuid.uuid4()) 25 | elif version == 5: 26 | if namespace is None or name is None: 27 | raise ValueError("Namespace and name must be provided for UUID5.") 28 | namespace_uuid = uuid.UUID(namespace) 29 | return str(uuid.uuid5(namespace_uuid, name)) 30 | else: 31 | raise ValueError("Unsupported UUID version. Supported versions are 1, 3, 4, and 5.") 32 | -------------------------------------------------------------------------------- /python/models/chord_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.transformers import BaseTransformer 4 | 5 | 6 | class ChordNet(nn.Module): 7 | def __init__(self, 8 | n_freq=2048, 9 | n_classes=122, 10 | n_group=32, 11 | f_layers=5, 12 | f_nhead=8, 13 | t_layers=5, 14 | t_nhead=8, 15 | d_layers=5, 16 | d_nhead=8, 17 | dropout=0.5, 18 | *args, **kwargs): 19 | super().__init__() 20 | 21 | self.transformers = BaseTransformer(n_freq=n_freq, n_group=n_group, 22 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout, 23 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout, 24 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout) 25 | self.dropout = nn.Dropout(dropout) 26 | self.fc = nn.Linear(n_freq, n_classes) 27 | 28 | def forward(self, x, weight=None): 29 | # x shape: (batch, channel, time, freq) 30 | output, weight_logits = self.transformers(x, weight) 31 | output = self.dropout(output) 32 | output = self.fc(output) 33 | output = output.argmax(dim=-1) 34 | return output, weight_logits 35 | 36 | 37 | if __name__ == '__main__': 38 | model = ChordNet() 39 | print(model) 40 | x = torch.randn(6, 2, 256, 2048) 41 | y, weight = model(x) 42 | print(y.shape, weight.shape) 43 | -------------------------------------------------------------------------------- /python/spec.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def stft(x, n_fft=4096, hop_length=1024, pad=True): 5 | z = th.stft(x, 6 | n_fft, 7 | hop_length or n_fft // 4, 8 | window=th.hann_window(n_fft).to(device=x.device), 9 | win_length=n_fft, 10 | normalized=True, 11 | center=pad, 12 | return_complex=True, 13 | pad_mode='reflect') 14 | z = th.transpose(z, 1, 2) 15 | return z 16 | 17 | 18 | def istft(stft_feature, n_fft=4096, hop_length=1024, pad=True): 19 | stft_feature = th.transpose(stft_feature, 1, 2) 20 | waveform = th.istft(stft_feature, 21 | n_fft, 22 | hop_length, 23 | window=th.hann_window(n_fft).to(device=stft_feature.device), 24 | win_length=n_fft, 25 | normalized=True, 26 | center=pad) 27 | return waveform 28 | 29 | 30 | def get_spec(waveform, cfg): 31 | spec = stft(waveform, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad']) 32 | return spec # channel, freq, time 33 | 34 | 35 | def get_specs(waveforms, cfg): 36 | # waveforms shape: sources, channel, time 37 | S, C, T = waveforms.shape 38 | _waveforms = waveforms.view(S * C, T) 39 | specs = stft(_waveforms, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad']) 40 | return specs.view(S, C, specs.shape[-2], specs.shape[-1]) # sources, channel, freq, time 41 | 42 | 43 | def get_mixed_spec(waveforms, cfg): 44 | mixed_waveform = 
th.sum(waveforms, dim=0) 45 | mixed_spec = stft(mixed_waveform, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad']) 46 | return mixed_spec # channel, freq, time 47 | -------------------------------------------------------------------------------- /python/models/beat_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.transformers import BaseTransformer 4 | 5 | 6 | class BeatNet(nn.Module): 7 | def __init__(self, 8 | source=3, 9 | n_classes=3, 10 | weights=(0.4, 0.3, 0.3), 11 | n_freq=2048, 12 | n_group=32, 13 | f_layers=2, 14 | f_nhead=4, 15 | t_layers=2, 16 | t_nhead=4, 17 | d_layers=2, 18 | d_nhead=8, 19 | dropout=0.5, 20 | *args, **kwargs 21 | ): 22 | super().__init__() 23 | self.weights = weights 24 | self.transformer_layers = nn.ModuleList() 25 | for _ in range(source): 26 | _layer = BaseTransformer(n_freq=n_freq, n_group=n_group, 27 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout, 28 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout, 29 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout) 30 | self.transformer_layers.append(_layer) 31 | self.dropout = nn.Dropout(dropout) 32 | self.beat_fc = nn.Linear(n_freq, n_classes) 33 | 34 | self.reset_parameters(0.05) 35 | 36 | def reset_parameters(self, confidence): 37 | self.beat_fc.bias.data.fill_(-torch.log(torch.tensor(1 / confidence - 1))) 38 | 39 | def forward(self, inp): 40 | # shape: (batch, source, channel, time, freq) 41 | 42 | y_list = [] 43 | logits_weight_list = [] 44 | for i, layer in enumerate(self.transformer_layers): 45 | x = inp[:, i, :, :, :] 46 | x, _f = layer(x) 47 | w = self.weights[i] 48 | x = x * w 49 | y_list.append(x) 50 | logits_weight_list.append(_f * w) 51 | y = torch.sum(torch.stack(y_list, dim=0), dim=0) 52 | logits_weight = torch.sum(torch.stack(logits_weight_list, dim=0), dim=0) 53 | 54 | y = self.dropout(y) 55 | beats = self.beat_fc(y) 56 | beats = torch.argmax(beats, dim=-1) 57 | return beats, logits_weight 58 | 59 | 60 | if __name__ == '__main__': 61 | model = BeatNet() 62 | print(model) 63 | x = torch.randn(6, 3, 2, 256, 2048) 64 | b, weight = model(x) 65 | print(x.shape, b.shape, weight.shape) 66 | -------------------------------------------------------------------------------- /python/common.py: -------------------------------------------------------------------------------- 1 | CHORD_LABELS = [ 2 | 'N', 'X', 3 | # maj 4 | 'C:maj', 'C#:maj', 'D:maj', 'D#:maj', 'E:maj', 'F:maj', 'F#:maj', 'G:maj', 'G#:maj', 'A:maj', 'A#:maj', 'B:maj', 5 | # min 6 | 'C:min', 'C#:min', 'D:min', 'D#:min', 'E:min', 'F:min', 'F#:min', 'G:min', 'G#:min', 'A:min', 'A#:min', 'B:min', 7 | # 7 8 | 'C:7', 'C#:7', 'D:7', 'D#:7', 'E:7', 'F:7', 'F#:7', 'G:7', 'G#:7', 'A:7', 'A#:7', 'B:7', 9 | # maj7 10 | 'C:maj7', 'C#:maj7', 'D:maj7', 'D#:maj7', 'E:maj7', 'F:maj7', 'F#:maj7', 'G:maj7', 'G#:maj7', 'A:maj7', 'A#:maj7', 11 | 'B:maj7', 12 | # min7 13 | 'C:min7', 'C#:min7', 'D:min7', 'D#:min7', 'E:min7', 'F:min7', 'F#:min7', 'G:min7', 'G#:min7', 'A:min7', 'A#:min7', 14 | 'B:min7', 15 | # 6 16 | 'C:6', 'C#:6', 'D:6', 'D#:6', 'E:6', 'F:6', 'F#:6', 'G:6', 'G#:6', 'A:6', 'A#:6', 'B:6', 17 | # m6 18 | 'C:m6', 'C#:m6', 'D:m6', 'D#:m6', 'E:m6', 'F:m6', 'F#:m6', 'G:m6', 'G#:m6', 'A:m6', 'A#:m6', 'B:m6', 19 | # sus2 20 | 'C:sus2', 'C#:sus2', 'D:sus2', 'D#:sus2', 'E:sus2', 'F:sus2', 'F#:sus2', 'G:sus2', 'G#:sus2', 'A:sus2', 'A#:sus2', 21 | 'B:sus2', 22 | # sus4 23 | 'C:sus4', 'C#:sus4', 'D:sus4', 'D#:sus4', 'E:sus4', 'F:sus4', 
'F#:sus4', 'G:sus4', 'G#:sus4', 'A:sus4', 'A#:sus4', 24 | 'B:sus4', 25 | # 5 26 | 'C:5', 'C#:5', 'D:5', 'D#:5', 'E:5', 'F:5', 'F#:5', 'G:5', 'G#:5', 'A:5', 'A#:5', 'B:5', 27 | ] 28 | 29 | SEGMENT_LABELS = [ 30 | 'start', 31 | 'end', 32 | 'intro', 33 | 'outro', 34 | 'verse', 35 | 'chorus', 36 | 'solo', 37 | 'break', 38 | 'bridge', 39 | 'inst', 40 | ] 41 | 42 | MAJOR_CHORDS = [ 43 | ["C", "Dm", "Em", "F", "G", "Am", "G7"], 44 | ["G", "Am", "Bm", "C", "D", "Em", "D7"], 45 | ["D", "Em", "F#m", "G", "A", "Bm", "A7"], 46 | ["A", "Bm", "C#m", "D", "E", "F#m", "E7"], 47 | ["E", "F#m", "G#m", "A", "B", "C#m", "B7"], 48 | ["F", "Gm", "Am", "Bb", "C", "Dm", "C7"], 49 | ["B", "C#m", "D#m", "E", "F#", "G#m", "F#7"], 50 | ["Db", "Ebm", "Fm", "Gb", "Ab", "Bbm", "Ab7"], 51 | ["Eb", "Fm", "Gm", "Ab", "Bb", "Cm", "Bb7"], 52 | ["Gb", "Abm", "Bbm", "Cb", "Db", "Ebm", "Db7"], 53 | ["Ab", "Bbm", "Cm", "Db", "Eb", "Fm", "Eb7"], 54 | ["Bb", "Cm", "Dm", "Eb", "F", "Gm", "F7"], 55 | ["C#", "D#m", "E#m", "F#", "G#", "A#m", "G#7"], 56 | ["Cb", "Dbm", "Ebm", "Fb", "Gb", "Abm", "Gb7"], 57 | ["F#", "G#m", "A#m", "B", "C#", "D#m", "C#7", ], 58 | ] 59 | 60 | MINOR_CHORDS = [ 61 | ["Am", "Bdim", "C", "Dm", "Em", "F", "G"], 62 | ["Em", "F#dim", "G", "Am", "Bm", "C", "D"], 63 | ["Bm", "C#dim", "D", "Em", "F#m", "G", "A"], 64 | ["F#m", "G#dim", "A", "Bm", "C#m", "D", "E"], 65 | ["C#m", "D#dim", "E", "F#m", "G#m", "A", "B"], 66 | ["Dm", "Edim", "F", "Gm", "Am", "Bb", "C"], 67 | ["G#m", "A#dim", "B", "C#m", "D#m", "E", "F#"], 68 | ["Bbm", "Cdim", "Db", "Ebm", "Fm", "Gb", "Ab"], 69 | ["Cm", "Ddim", "Eb", "Fm", "Gm", "Ab", "Bb"], 70 | ["Ebm", "Fdim", "Gb", "Abm", "Bbm", "Cb", "Db"], 71 | ["Fm", "Gdim", "Ab", "Bbm", "Cm", "Db", "Eb"], 72 | ["Gm", "Adim", "Bb", "Cm", "Dm", "Eb", "F"], 73 | ["A#m", "B#dim", "C#", "D#m", "E#m", "F#", "G#"], 74 | ["Abm", "Bbdim", "Cb", "Dbm", "Ebm", "Fb", "Gb"], 75 | ["D#m", "E#dim", "F#", "G#m", "A#m", "B", "C#"], 76 | ] 77 | 78 | MAJ2MIN_MAP = { 79 | "G": "Em", 80 | "C": "Am", 81 | "D": "Bm", 82 | "E": "C#m", 83 | "F": "Dm", 84 | "A": "F#m", 85 | "B": "G#m", 86 | "Db": "Bbm", 87 | "Eb": "Cm", 88 | "Gb": "Ebm", 89 | "Ab": "Fm", 90 | "Bb": "Gm", 91 | "C#": "A#m", 92 | "Cb": "Abm", 93 | "F#": "D#m" 94 | } 95 | MIN2MAJ_MAP = {v: k for k, v in MAJ2MIN_MAP.items()} 96 | -------------------------------------------------------------------------------- /python/models/segment_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.transformers import BaseTransformer 4 | 5 | 6 | class SegmentEmbeddings(nn.Module): 7 | def __init__(self, n_channel=2, n_hidden=128, d_model=2048, dropout=0.5): 8 | super().__init__() 9 | self.n_channel = n_channel 10 | self.n_hidden = n_hidden 11 | 12 | self.act_fn = nn.ReLU() 13 | 14 | self.conv1 = nn.Conv2d(n_channel, n_hidden // 2, kernel_size=(1, 1)) 15 | self.pool0 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 0)) 16 | self.drop1 = nn.Dropout(dropout) 17 | 18 | self.conv2 = nn.Conv2d(n_hidden // 2, n_hidden, kernel_size=(1, 1)) 19 | self.drop2 = nn.Dropout(dropout) 20 | 21 | self.conv3 = nn.Conv2d(n_hidden, n_hidden // 2, kernel_size=(1, 1)) 22 | self.pool3 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 0)) 23 | self.drop3 = nn.Dropout(dropout) 24 | 25 | self.conv4 = nn.Conv2d(n_hidden // 2, n_channel, kernel_size=(1, 1)) 26 | self.drop4 = nn.Dropout(dropout) 27 | 28 | self.norm = nn.LayerNorm(d_model // 4) 29 | self.drop = nn.Dropout(dropout) 30 | 31 | 
def forward(self, x): 32 | # x: batch, n_channel, n_time, n_freq 33 | x = self.conv1(x) 34 | x = self.pool0(x) 35 | x = self.act_fn(x) 36 | x = self.drop1(x) 37 | 38 | x = self.conv2(x) 39 | x = self.act_fn(x) 40 | x = self.drop2(x) 41 | 42 | x = self.conv3(x) 43 | x = self.pool3(x) 44 | x = self.act_fn(x) 45 | x = self.drop3(x) 46 | 47 | x = self.conv4(x) 48 | x = self.act_fn(x) 49 | x = self.drop4(x) 50 | 51 | x = self.norm(x) 52 | x = self.drop(x) 53 | 54 | return x 55 | 56 | 57 | class SegmentNet(nn.Module): 58 | def __init__(self, 59 | n_freq=2048, 60 | n_channel=2, 61 | n_classes=10, 62 | emb_hidden=128, 63 | n_group=32, 64 | f_layers=2, 65 | f_nhead=8, 66 | t_layers=2, 67 | t_nhead=8, 68 | d_layers=2, 69 | d_nhead=8, 70 | dropout=0.5, 71 | *args, **kwargs): 72 | super().__init__() 73 | 74 | divisor = 4 75 | d_model = n_freq // divisor 76 | self.embeddings = SegmentEmbeddings(n_channel=n_channel, n_hidden=emb_hidden, d_model=n_freq, dropout=dropout) 77 | self.wfc = nn.Linear(n_freq, n_freq // divisor) 78 | 79 | self.transformer = BaseTransformer(n_channel=n_channel, n_freq=d_model, n_group=n_group, 80 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout, 81 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout, 82 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout, 83 | ) 84 | 85 | self.norm = nn.LayerNorm(d_model) 86 | 87 | self.segment_classifier = nn.Linear(d_model, n_classes) 88 | 89 | def forward(self, x, weight=None): 90 | # x: batch, n_channel, n_time, n_freq 91 | if weight is None: 92 | B, C, T, F = x.shape 93 | weight = torch.ones(B, T, F, device=x.device) 94 | x = self.embeddings(x) 95 | weight = self.wfc(weight) 96 | 97 | x, _ = self.transformer(x, weight=weight) 98 | x = self.norm(x) 99 | x = self.segment_classifier(x) 100 | x = torch.argmax(x, dim=-1) 101 | return x 102 | 103 | 104 | if __name__ == '__main__': 105 | net = SegmentNet(n_freq=2048) 106 | print(net) 107 | x = torch.randn(2, 2, 1024, 2048) 108 | y = net(x) 109 | print(y.shape) 110 | -------------------------------------------------------------------------------- /cover/utils/spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from librosa import feature 3 | 4 | 5 | def stft(signal, fft_size=512, hop_length=None, pad=True): 6 | """ 7 | Perform Short-Time Fourier Transform (STFT) on the input signal. 8 | 9 | Parameters: 10 | - signal (torch.Tensor): The input signal tensor. 11 | - fft_size (int): The size of the FFT window. Default is 512. 12 | - hop_length (int): The number of samples between successive frames. Default is fft_size // 4. 13 | - pad (bool, optional): whether to pad :attr:`input` on both sides 14 | so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. 15 | Default: ``True`` 16 | 17 | Returns: 18 | - torch.Tensor: The STFT of the input signal. 19 | """ 20 | *ot_shape, n_time = signal.shape 21 | signal = signal.reshape(-1, n_time) 22 | result = torch.stft(signal, 23 | fft_size, 24 | hop_length or fft_size // 4, 25 | window=torch.hann_window(fft_size).to(signal), 26 | win_length=fft_size, 27 | normalized=True, 28 | center=pad, 29 | return_complex=True, 30 | pad_mode='reflect') 31 | _, freqs, frames = result.shape 32 | return result.view(*ot_shape, freqs, frames) 33 | 34 | 35 | def istft(stft_matrix, hop_length=None, signal_length=None, pad=True): 36 | """ 37 | Perform Inverse Short-Time Fourier Transform (ISTFT) on the input STFT matrix. 
38 | 39 | Parameters: 40 | - stft_matrix (torch.Tensor): The input STFT matrix tensor. 41 | - hop_length (int): The number of samples between successive frames. Default is None. 42 | - signal_length (int): The length of the original signal. Default is None. 43 | - pad (bool, optional): whether to pad :attr:`input` on both sides 44 | so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. 45 | Default: ``True`` 46 | 47 | Returns: 48 | - torch.Tensor: The reconstructed time-domain signal. 49 | """ 50 | *ot_shape, n_freqs, n_frames = stft_matrix.shape 51 | fft_size = 2 * n_freqs - 2 52 | stft_matrix = stft_matrix.view(-1, n_freqs, n_frames) 53 | win_length = fft_size 54 | 55 | result = torch.istft(stft_matrix, 56 | fft_size, 57 | hop_length, 58 | window=torch.hann_window(win_length).to(stft_matrix.real), 59 | win_length=win_length, 60 | normalized=True, 61 | length=signal_length, 62 | center=pad) 63 | _, length = result.shape 64 | return result.view(*ot_shape, length) 65 | 66 | 67 | def get_spectrogram(waveform, config): 68 | """ 69 | Get the spectrogram of the input waveform based on the provided configuration. 70 | 71 | Parameters: 72 | - waveform (torch.Tensor): The input waveform tensor. 73 | - config (dict): The configuration dictionary containing 'n_fft', 'hop_length', and 'pad' keys. 74 | 75 | Returns: 76 | - torch.Tensor: The spectrogram of the input waveform. 77 | """ 78 | spectrogram = stft(waveform, fft_size=config.get('n_fft', 4096), 79 | hop_length=config.get('hop_length', 1024), pad=config.get('pad', True)) 80 | spectrogram = spectrogram.transpose(-1, -2) 81 | return spectrogram # channel, freq, time 82 | 83 | 84 | def chroma(waveform, n_chroma=12, sample_rate=44100, hop_length=512, bins_per_octave=24): 85 | dtype = waveform.dtype 86 | if isinstance(waveform, torch.Tensor): 87 | waveform = waveform.cpu().numpy() 88 | y = feature.chroma_cqt(y=waveform, sr=sample_rate, n_chroma=n_chroma, hop_length=hop_length, n_octaves=12, 89 | bins_per_octave=bins_per_octave, cqt_mode='hybrid') 90 | y = torch.from_numpy(y).to(dtype) 91 | return y 92 | -------------------------------------------------------------------------------- /cover/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import nn 5 | 6 | 7 | class UPad(nn.Module): 8 | def __init__(self, padding_setting=(1, 2, 1, 2)): 9 | super().__init__() 10 | self.padding_setting = padding_setting 11 | 12 | def forward(self, x): 13 | return F.pad(x, self.padding_setting, "constant", 0) 14 | 15 | 16 | class UTransposedPad(nn.Module): 17 | def __init__(self, padding_setting=(1, 2, 1, 2)): 18 | super().__init__() 19 | self.padding_setting = padding_setting 20 | 21 | def forward(self, x): 22 | l, r, t, b = self.padding_setting 23 | return x[:, :, l:-r, t:-b] 24 | 25 | 26 | def get_activation(activation): 27 | if activation == "ReLU": 28 | activation_fn = nn.ReLU() 29 | elif activation == "ELU": 30 | activation_fn = nn.ELU() 31 | else: 32 | activation_fn = nn.LeakyReLU(0.2) 33 | return activation_fn 34 | 35 | 36 | class UNet(nn.Module): 37 | def __init__(self, 38 | n_channel=2, 39 | conv_n_filters=(16, 32, 64, 128, 256, 512), 40 | down_activation="ELU", 41 | up_activation="ELU", 42 | down_dropouts=None, 43 | up_dropouts=None): 44 | super().__init__() 45 | 46 | conv_num = len(conv_n_filters) 47 | 48 | down_activation_fn = get_activation(down_activation) 49 | up_activation_fn = 
get_activation(up_activation) 50 | 51 | down_dropouts = [0] * conv_num if down_dropouts is None else down_dropouts 52 | up_dropouts = [0] * conv_num if up_dropouts is None else up_dropouts 53 | 54 | self.down_layers = nn.ModuleList() 55 | for i in range(conv_num): 56 | in_ch = n_channel if i == 0 else conv_n_filters[i - 1] 57 | out_ch = conv_n_filters[i] 58 | dropout = down_dropouts[i] 59 | 60 | _down_layers = [ 61 | UPad(), 62 | nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0), 63 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01), 64 | down_activation_fn 65 | ] 66 | if dropout > 0: 67 | _down_layers.append(nn.Dropout(dropout)) 68 | self.down_layers.append(nn.Sequential(*_down_layers)) 69 | 70 | self.up_layers = nn.ModuleList() 71 | for i in range(conv_num - 1, -1, -1): 72 | in_ch = conv_n_filters[conv_num - 1] if i == conv_num - 1 else conv_n_filters[i + 1] 73 | out_ch = 1 if i == 0 else conv_n_filters[i - 1] 74 | dropout = up_dropouts[i] 75 | 76 | _up_layer = [ 77 | nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0), 78 | UTransposedPad(), 79 | up_activation_fn, 80 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01) 81 | ] 82 | if dropout > 0: 83 | _up_layer.append(nn.Dropout(dropout)) 84 | self.up_layers.append(nn.Sequential(*_up_layer)) 85 | 86 | self.last_layer = nn.Conv2d(1, 2, kernel_size=1, stride=1, padding=0) 87 | 88 | def forward(self, x): 89 | 90 | d_convs = [] 91 | for layer in self.down_layers: 92 | x = layer(x) 93 | d_convs.append(x) 94 | 95 | n = len(self.up_layers) 96 | for i, layer in enumerate(self.up_layers): 97 | if i == 0: 98 | x = layer(x) 99 | else: 100 | x1 = d_convs[n - i - 1] 101 | x = torch.cat([x, x1], axis=1) 102 | x = layer(x) 103 | x = self.last_layer(x) 104 | return x 105 | 106 | 107 | class UNets(nn.Module): 108 | def __init__(self, 109 | sources, 110 | n_channel=2, 111 | conv_n_filters=(16, 32, 64, 128, 256, 512), 112 | down_activation="ELU", 113 | up_activation="ELU", 114 | down_dropouts=None, 115 | up_dropouts=None, 116 | *args, **kwargs): 117 | super().__init__() 118 | 119 | self.unet_layers = nn.ModuleList() 120 | for i in range(sources): 121 | layer = UNet(n_channel=n_channel, 122 | conv_n_filters=conv_n_filters, 123 | down_activation=down_activation, 124 | up_activation=up_activation, 125 | down_dropouts=down_dropouts, 126 | up_dropouts=up_dropouts) 127 | self.unet_layers.append(layer) 128 | 129 | def forward(self, x): 130 | # x shape: (batch, channel, time, freq) 131 | _layers = [] 132 | for layer in self.unet_layers: 133 | y = layer(x) 134 | _layers.append(y) 135 | 136 | y = torch.stack(_layers, dim=1) 137 | y = F.softmax(y, dim=1) # shape: (batch, sources, channel, time, freq) 138 | return y 139 | 140 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fish 2 | 3 | YouTube video to chords, lyrics, beat and melody. 4 | 5 | A transformer-based hybrid multimodal model, various transformer models address different problems in the field of music information retrieval, these models generate corresponding information dependencies that mutually influence each other. 6 | 7 | An AI-powered multimodal project focused on music, generate chords, beats, lyrics, melody, and tabs for any song. 
8 | 
9 | > For the online experience, [see the site here](https://lamucal.com)
10 | 
11 | 
12 | 
13 | 
14 | 
15 | A `U-Net` model is used for audio source separation; `Pitch-Net`, `Beat-Net`, `Chord-Net` and `Segment-Net` are based on the transformer model. Apart from establishing the correlation between frequency and time, the most important aspect is establishing the mutual influence between the different networks.
16 | 
17 | The entire AI-powered process is implemented in `aitabs.py`, while the various network structure models can be found in the `models` folder.
18 | > **Note**: `U-Net` and `Segment-Net` use the STFT spectrum of the audio as input. `Beat-Net` uses three spectrograms (drums, bass, and other instruments) as input, and `Chord-Net` uses one spectrogram of the background music.
19 | 
20 | 
21 | ## Features
22 | - **Chord**, music chord detection, including major, minor, 7, maj7, min7, 6, m6, sus2, sus4, 5, and inverted chords, as well as determining the **key** of a song.
23 | 
24 | - **Beat**, music beat and downbeat detection, and **tempo** (BPM) tracking.
25 | 
26 | - **Pitch**, tracking the pitch of the melody in the vocal track.
27 | 
28 | - **Music Structure**, music segment boundaries and labels, including intro, verse, chorus, bridge, etc.
29 | 
30 | - **Lyrics**, lyrics recognition and automatic lyrics-to-audio alignment, using ASR (Whisper) to recognize the lyrics of the vocal track. The alignment of lyrics and audio is achieved by fine-tuning the wav2vec2 pre-trained model. Dozens of languages are currently supported, including English, Spanish, Portuguese, Russian, Japanese, Korean, Arabic, Chinese, and more.
31 | 
32 | - **AI Tabs**, generating playable sheet music, including chord charts and six-line staves, from chords, beats, music structure information, lyrics, rhythm, etc. Editing of chords, rhythm, and lyrics is supported.
33 | 
34 | - **Other**, audio source separation, speed adjustment, pitch shifting, etc.
35 | 
36 | For more AI-powered features, see the [website](https://lamucal.com).
37 | 
38 | ## Cover
39 | Combining audio STFT, MFCC, and chroma features with a Transformer model for timbre feature
40 | modeling and high-level abstraction avoids overfitting and underfitting better than using any
41 | single feature, and generalizes better. With a small amount of data and minimal training, it can
42 | achieve good results.
43 | 
44 | > For the online experience, [see the site here](https://lamucal.com/ai-cover)
45 | 
46 | 
47 | 
48 | 
49 | 
50 | The model begins by processing the audio signal through a `U-Net`, which isolates the vocal track.
51 | The vocal track is then fed simultaneously into `PitchNet` and `HuBERT` (Wav2Vec2): `PitchNet` is
52 | responsible for extracting pitch features, while `HuBERT` captures detailed features of the vocals.
53 | 
54 | The core of the model is `CombineNet`, which receives features from the `Features` module. This
55 | module consists of three spectrogram-style features (STFT, MFCC, and Chroma), each capturing a
56 | different aspect of the audio. These features are enhanced by the TimbreBlock before being passed
57 | to the Encoder. During this process, noise is introduced via an STFT transformation and combined
58 | with the features before entering the Encoder. The processed features are then passed to the
59 | Decoder, where they are combined with the earlier features to generate the final audio output.
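As a rough illustration of the `Features` module described above, the sketch below shows how the three feature types could be computed for a vocal waveform. This is a minimal sketch under stated assumptions, not the repository's implementation: the function name `cover_features` and the parameter values are made up for illustration, and the TimbreBlock, noise injection, and `CombineNet` wiring are not shown. It reuses the same tools as `cover/utils/spec.py` (torch STFT and librosa chroma).

```python
import torch
import torchaudio
from librosa import feature


def cover_features(vocal, sample_rate=44100, n_fft=4096, hop_length=1024):
    # vocal: (channel, time) waveform of the separated vocal track

    # 1) STFT magnitude, shape (channel, freq, time)
    window = torch.hann_window(n_fft)
    stft = torch.stft(vocal, n_fft, hop_length, win_length=n_fft,
                      window=window, normalized=True, center=True,
                      pad_mode='reflect', return_complex=True)
    stft_mag = stft.abs()

    # 2) MFCC, shape (channel, n_mfcc, time)
    mfcc = torchaudio.transforms.MFCC(
        sample_rate=sample_rate, n_mfcc=40,
        melkwargs={'n_fft': n_fft, 'hop_length': hop_length})(vocal)

    # 3) Chroma via librosa, as in cover/utils/spec.py, shape (channel, 12, time)
    chroma = feature.chroma_cqt(y=vocal.numpy(), sr=sample_rate,
                                hop_length=hop_length)
    chroma = torch.from_numpy(chroma).to(vocal.dtype)

    return stft_mag, mfcc, chroma
```

In the full model these features would be passed through the TimbreBlock, mixed with the injected noise, and consumed by the `CombineNet` encoder-decoder; the sketch stops at raw feature extraction.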
60 | 
61 | `CombineNet` is based on an encoder-decoder architecture and is trained to generate a mask that
62 | is used to extract and replace the timbre, ultimately producing the final output audio.
63 | 
64 | The entire AI-powered process is implemented in `run.py`, while the various network structure
65 | models can be found in the `models` folder.
66 | 
67 | ## Demo
68 | The results of training on a 1-minute speech of Donald Trump are as follows:
69 | 
70 | <table>
71 | <tr>
72 | <td>
73 | 
74 | **Train 10 epoch (Hozier's Too Sweet)**
75 | 
76 | </td>
77 | <td>
78 | 
79 | **Train 100 epoch (Hozier's Too Sweet)**
80 | 
81 | </td>
82 | </tr>
83 | <tr>
84 | <td>
85 | 
86 | [Train 10 epoch.webm](https://github.com/user-attachments/assets/992747d6-3e47-442c-ab63-0742c83933ee)
87 | 
88 | </td>
89 | <td>
90 | 
91 | [Train 100 epoch.webm](https://github.com/user-attachments/assets/877d2cae-d7b7-4355-807f-424ada7df3a1)
92 | 
93 | </td>
94 | </tr>
95 | </table>
96 | 97 | 98 | You can experience creating your own voice online, [See the site here](https://lamucal.com/ai-cover) 99 | -------------------------------------------------------------------------------- /python/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from torch import nn 5 | 6 | 7 | class UPad(nn.Module): 8 | def __init__(self, padding_setting=(1, 2, 1, 2)): 9 | super().__init__() 10 | self.padding_setting = padding_setting 11 | 12 | def forward(self, x): 13 | return F.pad(x, self.padding_setting, "constant", 0) 14 | 15 | 16 | class UTransposedPad(nn.Module): 17 | def __init__(self, padding_setting=(1, 2, 1, 2)): 18 | super().__init__() 19 | self.padding_setting = padding_setting 20 | 21 | def forward(self, x): 22 | l, r, t, b = self.padding_setting 23 | return x[:, :, l:-r, t:-b] 24 | 25 | 26 | def get_activation(activation): 27 | if activation == "ReLU": 28 | activation_fn = nn.ReLU() 29 | elif activation == "ELU": 30 | activation_fn = nn.ELU() 31 | else: 32 | activation_fn = nn.LeakyReLU(0.2) 33 | return activation_fn 34 | 35 | 36 | class UNet(nn.Module): 37 | def __init__(self, 38 | n_channel=2, 39 | conv_n_filters=(16, 32, 64, 128, 256, 512), 40 | down_activation="ELU", 41 | up_activation="ELU", 42 | down_dropouts=None, 43 | up_dropouts=None): 44 | super().__init__() 45 | 46 | conv_num = len(conv_n_filters) 47 | 48 | down_activation_fn = get_activation(down_activation) 49 | up_activation_fn = get_activation(up_activation) 50 | 51 | down_dropouts = [0] * conv_num if down_dropouts is None else down_dropouts 52 | up_dropouts = [0] * conv_num if up_dropouts is None else up_dropouts 53 | 54 | self.down_layers = nn.ModuleList() 55 | for i in range(conv_num): 56 | in_ch = n_channel if i == 0 else conv_n_filters[i - 1] 57 | out_ch = conv_n_filters[i] 58 | dropout = down_dropouts[i] 59 | 60 | _down_layers = [ 61 | UPad(), 62 | nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0), 63 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01), 64 | down_activation_fn 65 | ] 66 | if dropout > 0: 67 | _down_layers.append(nn.Dropout(dropout)) 68 | self.down_layers.append(nn.Sequential(*_down_layers)) 69 | 70 | self.up_layers = nn.ModuleList() 71 | for i in range(conv_num - 1, -1, -1): 72 | in_ch = conv_n_filters[conv_num - 1] if i == conv_num - 1 else conv_n_filters[i + 1] 73 | out_ch = 1 if i == 0 else conv_n_filters[i - 1] 74 | dropout = up_dropouts[i] 75 | 76 | _up_layer = [ 77 | nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0), 78 | UTransposedPad(), 79 | up_activation_fn, 80 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01) 81 | ] 82 | if dropout > 0: 83 | _up_layer.append(nn.Dropout(dropout)) 84 | self.up_layers.append(nn.Sequential(*_up_layer)) 85 | 86 | self.last_layer = nn.Conv2d(1, 2, kernel_size=1, stride=1, padding=0) 87 | 88 | def forward(self, x): 89 | 90 | d_convs = [] 91 | for layer in self.down_layers: 92 | x = layer(x) 93 | d_convs.append(x) 94 | 95 | n = len(self.up_layers) 96 | for i, layer in enumerate(self.up_layers): 97 | if i == 0: 98 | x = layer(x) 99 | else: 100 | x1 = d_convs[n - i - 1] 101 | x = torch.cat([x, x1], axis=1) 102 | x = layer(x) 103 | x = self.last_layer(x) 104 | return x 105 | 106 | 107 | class UNets(nn.Module): 108 | def __init__(self, 109 | sources, 110 | n_channel=2, 111 | conv_n_filters=(16, 32, 64, 128, 256, 512), 112 | down_activation="ELU", 113 | 
up_activation="ELU", 114 | down_dropouts=None, 115 | up_dropouts=None, 116 | *args, **kwargs): 117 | super().__init__() 118 | 119 | self.unet_layers = nn.ModuleList() 120 | for i in range(sources): 121 | layer = UNet(n_channel=n_channel, 122 | conv_n_filters=conv_n_filters, 123 | down_activation=down_activation, 124 | up_activation=up_activation, 125 | down_dropouts=down_dropouts, 126 | up_dropouts=up_dropouts) 127 | self.unet_layers.append(layer) 128 | 129 | def forward(self, x): 130 | # x shape: (batch, channel, time, freq) 131 | _layers = [] 132 | for layer in self.unet_layers: 133 | y = layer(x) 134 | _layers.append(y) 135 | 136 | y = torch.stack(_layers, axis=1) 137 | y = F.softmax(y, dim=1) # shape: (batch, sources, channel, time, freq) 138 | return y 139 | 140 | 141 | if __name__ == '__main__': 142 | model_params = { 143 | 'sources': 4, 144 | 'n_channel': 2, 145 | 'conv_n_filters': [8, 16, 32, 64, 128, 256, 512, 1024], 146 | 'down_activation': "ELU", 147 | 'up_activation': "ELU", 148 | # 'down_dropouts': [0, 0, 0, 0, 0, 0, 0, 0], 149 | # 'up_dropouts': [0, 0, 0, 0, 0, 0, 0, 0], 150 | } 151 | net = UNets(**model_params) 152 | print(net) 153 | print(net(torch.rand(1, 2, 256, 2049)).shape) 154 | -------------------------------------------------------------------------------- /cover/run.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | import torch as th 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from load import get_model, get_hubert_model 8 | from utils.audio import load_waveform, save_waveform 9 | from utils.spec import get_spectrogram, istft 10 | from utils.utils import generate_uuid 11 | from models import UNets, PitchNet, CombineNet 12 | 13 | 14 | def build_masked_stft(masks, stft_feature, n_fft=4096): 15 | out = [] 16 | for i in range(len(masks)): 17 | mask = masks[i, :, :, :] 18 | pad_num = n_fft // 2 + 1 - mask.size(-1) 19 | mask = F.pad(mask, (0, pad_num, 0, 0, 0, 0)) 20 | inst_stft = mask.type(stft_feature.dtype) * stft_feature 21 | out.append(inst_stft) 22 | return out 23 | 24 | 25 | def merge_wav(stem_list, volume_list=None): 26 | stem_num = len(stem_list) 27 | length = min([stem.shape[0] for stem in stem_list]) 28 | 29 | for i in range(stem_num): 30 | stem = stem_list[i][:length] # shape: time, channel 31 | if stem.ndim == 1: 32 | stem = np.tile(np.expand_dims(stem, axis=1), (1, 2)) 33 | stem = stem / np.abs(stem).max() 34 | stem_list[i] = stem 35 | if volume_list: 36 | stem_list[i] = stem_list[i] * volume_list[i] 37 | 38 | mix = sum(stem_list) / stem_num 39 | mix = mix / np.abs(mix).max() 40 | return mix 41 | 42 | 43 | class AudioGenerate(object): 44 | def __init__(self, config, device='cpu'): 45 | self.config = config 46 | self.device = device 47 | self.n_channel = self.config['n_channel'] 48 | self.sources = self.config['sources'] 49 | self.samplerate = self.config['samplerate'] 50 | self.separate_config = self.config['separate'] 51 | self.f0_config = self.config['f0'] 52 | self.hubert_config = self.config['hubert'] 53 | self.generate_config = self.config['generate'] 54 | 55 | self.separate_model_cfg = self.separate_config['model'] 56 | self.separate_model_cfg['sources'] = self.sources 57 | self.separate_model_cfg['n_channel'] = self.n_channel 58 | self.unet = get_model(UNets, self.separate_model_cfg, 59 | model_path=self.separate_config['model_path'], 60 | is_train=False, device=device) 61 | 62 | self.f0_model_cfg = self.f0_config['model'] 63 | self.f0_extractor = get_model(PitchNet, 
self.f0_model_cfg, 64 | model_path=self.f0_config['model_path'], 65 | is_train=False, device=device) 66 | self.hubert_model = get_hubert_model(self.hubert_config['name'], 67 | dl_kwargs=self.hubert_config['download'], 68 | device=device) 69 | 70 | self.generate_model_cfg = self.generate_config['model'] 71 | self.ai_cover_model = get_model(CombineNet, self.generate_model_cfg, 72 | model_path=self.generate_config['model_path'], 73 | device=device) 74 | 75 | def separate(self, waveform, samplerate): 76 | assert samplerate == self.samplerate 77 | wav_len = waveform.shape[-1] 78 | 79 | spec_config = self.separate_config['spec'] 80 | n_fft = spec_config['n_fft'] 81 | hop_length = spec_config['hop_length'] 82 | n_time = spec_config['n_time'] 83 | 84 | split_len = (n_time - 5) * hop_length + n_fft 85 | 86 | output_waveforms = [[] for _ in range(self.sources)] 87 | for i in range(0, wav_len, split_len): 88 | with th.no_grad(): 89 | x = waveform[:, i:i + split_len] 90 | pad_num = 0 91 | if x.shape[-1] < split_len: 92 | pad_num = split_len - (wav_len - i) 93 | x = F.pad(x, (0, pad_num)) 94 | 95 | # separator 96 | z = get_spectrogram(x, spec_config) 97 | mag_z = th.abs(z).unsqueeze(0) 98 | masks = self.unet(mag_z) 99 | masks = masks.squeeze(0) 100 | _masked_stfts = build_masked_stft(masks, z, n_fft=n_fft) 101 | # build waveform 102 | for j, _masked_stft in enumerate(_masked_stfts): 103 | _masked_stft = _masked_stft.transpose(-1, -2) 104 | _waveform = istft(_masked_stft, hop_length=hop_length) 105 | if pad_num > 0: 106 | _waveform = _waveform[:, :-pad_num] 107 | output_waveforms[j].append(_waveform) 108 | 109 | inst_waveforms = [] 110 | for waveform_list in output_waveforms: 111 | inst_waveforms.append(th.cat(waveform_list, dim=-1)) 112 | return th.stack(inst_waveforms, dim=0) 113 | 114 | def get_feat(self, audio): 115 | feats, _ = self.hubert_model(audio) 116 | feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1) 117 | return feats 118 | 119 | def generate(self, audio_fp, save_path): 120 | waveform, samplerate = load_waveform(audio_fp, samplerate=self.samplerate, 121 | n_channel=self.n_channel, device=self.device) 122 | 123 | waveforms = self.separate(waveform, samplerate) 124 | 125 | vocal_waveform = waveforms[0] 126 | other_waveform = waveforms[1] 127 | 128 | with th.no_grad(): 129 | feat = self.get_feat(vocal_waveform) 130 | f0 = self.f0_extractor(vocal_waveform) 131 | out_waveform, _, _ = self.ai_cover_model(vocal_waveform, feat, f0) 132 | 133 | final_waveform = merge_wav([vocal_waveform, other_waveform], 134 | volume_list=self.generate_config['volume_list']) 135 | 136 | fp = os.path.join(save_path, f'{generate_uuid()}.wav') 137 | save_waveform(fp, final_waveform, self.samplerate) 138 | 139 | -------------------------------------------------------------------------------- /python/models/pitch_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.transformers import EncoderFre, EncoderTime, FeedForward, positional_encoding 4 | 5 | 6 | class PitchEmbedding(nn.Module): 7 | def __init__(self, n_channel, d_model, n_hidden=32, dropout=0.5): 8 | super().__init__() 9 | self.n_channel = n_channel 10 | self.n_hidden = n_hidden 11 | 12 | self.act_fn = nn.ReLU() 13 | 14 | self.conv1 = nn.Conv2d(n_channel, n_hidden // 4, kernel_size=(1, 1)) 15 | self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)) 16 | self.drop1 = nn.Dropout(dropout) 17 | 18 | self.conv2 = nn.Conv2d(n_hidden // 4, 
n_hidden // 2, kernel_size=(1, 1)) 19 | self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)) 20 | self.drop2 = nn.Dropout(dropout) 21 | 22 | self.conv3 = nn.Conv2d(n_hidden // 2, n_hidden, kernel_size=(1, 1)) 23 | self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)) 24 | self.drop3 = nn.Dropout(dropout) 25 | 26 | self.conv4 = nn.Conv2d(n_hidden, n_hidden, kernel_size=(1, 1)) 27 | self.pool4 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2)) 28 | self.drop4 = nn.Dropout(dropout) 29 | 30 | self.norm = nn.LayerNorm(d_model // 16) 31 | self.drop = nn.Dropout(dropout) 32 | 33 | def forward(self, x): 34 | # x: batch, n_channel, n_time, n_freq 35 | x = self.conv1(x) 36 | x = self.pool1(x) 37 | x = self.act_fn(x) 38 | x = self.drop1(x) 39 | 40 | x = self.conv2(x) 41 | x = self.pool2(x) 42 | x = self.act_fn(x) 43 | x = self.drop2(x) 44 | 45 | x = self.conv3(x) 46 | x = self.pool3(x) 47 | x = self.act_fn(x) 48 | x = self.drop3(x) 49 | 50 | x = self.conv4(x) 51 | x = self.pool4(x) 52 | x = self.act_fn(x) 53 | x = self.drop4(x) 54 | 55 | x = self.norm(x) 56 | x = self.drop(x) 57 | 58 | return x 59 | 60 | 61 | class PitchEncoder(nn.Module): 62 | def __init__(self, 63 | n_freq=2048, 64 | n_group=32, 65 | weights=(0.6, 0.4), 66 | f_layers=2, 67 | f_nhead=8, 68 | f_pr=0.01, 69 | t_layers=2, 70 | t_nhead=8, 71 | t_pr=0.01, 72 | dropout=0.5): 73 | super().__init__() 74 | self.weights = weights 75 | 76 | self.encoder_fre = EncoderFre(n_freq=n_freq, n_group=n_group, nhead=f_nhead, n_layers=f_layers, dropout=dropout, 77 | pr=f_pr) 78 | self.encoder_time = EncoderTime(n_freq=n_freq, nhead=t_nhead, n_layers=t_layers, dropout=dropout, pr=t_pr) 79 | 80 | self.dropout = nn.Dropout(dropout) 81 | self.norm = nn.LayerNorm(n_freq) 82 | 83 | def forward(self, x): 84 | # x: batch, channel, n_time, n_freq 85 | y1 = self.encoder_fre(x) 86 | y2 = self.encoder_time(x) 87 | y = y1 * self.weights[0] + y2 * self.weights[1] 88 | y = self.dropout(y) 89 | y = self.norm(y) 90 | return y 91 | 92 | 93 | class PitchNet(nn.Module): 94 | def __init__(self, 95 | n_freq=2048, 96 | n_channel=2, 97 | n_classes=850, # 85 * 10 98 | emb_hidden=16, 99 | wr=0.3, 100 | pr=0.02, 101 | n_group=32, 102 | f_layers=2, 103 | f_nhead=8, 104 | t_layers=2, 105 | t_nhead=8, 106 | enc_weights=(0.6, 0.4), 107 | d_layers=2, 108 | d_nhead=8, 109 | dropout=0.5, 110 | *args, **kwargs): 111 | super().__init__() 112 | self.wr = wr 113 | self.pr = pr 114 | self.embedding = PitchEmbedding(n_channel, n_freq, emb_hidden, dropout) 115 | d_model = n_freq // 16 * emb_hidden 116 | self.encoder = PitchEncoder(n_freq=d_model, n_group=n_group, 117 | weights=enc_weights, 118 | f_layers=f_layers, f_nhead=f_nhead, 119 | t_layers=t_layers, t_nhead=t_nhead, 120 | dropout=dropout) 121 | 122 | self.attn1_layer = nn.ModuleList() 123 | self.attn2_layer = nn.ModuleList() 124 | self.ff_layer = nn.ModuleList() 125 | for _ in range(d_layers): 126 | _layer1 = nn.MultiheadAttention(d_model, d_nhead, batch_first=True) 127 | _layer2 = nn.MultiheadAttention(d_model, d_nhead, batch_first=True) 128 | _layer3 = FeedForward(d_model, dropout=dropout) 129 | self.attn1_layer.append(_layer1) 130 | self.attn2_layer.append(_layer2) 131 | self.ff_layer.append(_layer3) 132 | 133 | self.dropout = nn.Dropout(dropout) 134 | self.classifier = nn.Linear(d_model, n_classes) 135 | 136 | def forward(self, x, weight=None): 137 | # x: batch, channel, n_time, n_freq 138 | y = self.embedding(x) 139 | 140 | B, C, T, F = y.shape 141 | y = y.permute(0, 2, 1, 3) # B, C, T, F => B, T, C, F 142 | y = 
y.reshape(B, T, C * F) # B, T, C * F 143 | y = self.encoder(y) 144 | 145 | if weight is None: 146 | weight = torch.zeros_like(y) 147 | y_w = y + weight * self.wr 148 | B, T, F = y_w.shape 149 | y_w += positional_encoding(B, T, F) * self.pr 150 | 151 | for _attn1_layer, _attn2_layer, _ff_layer in zip(self.attn1_layer, self.attn2_layer, self.ff_layer): 152 | y, _ = _attn1_layer(y, y, y, need_weights=False) 153 | y, _ = _attn1_layer(y, y_w, y_w, need_weights=False) 154 | y = _ff_layer(y) 155 | 156 | output = self.dropout(y) 157 | output = self.classifier(output) 158 | output = torch.argmax(output, dim=-1) 159 | return output, y 160 | 161 | 162 | if __name__ == '__main__': 163 | model = PitchNet() 164 | print(model) 165 | x = torch.randn(6, 2, 256, 2048) 166 | y1, y2 = model(x, None) 167 | print(y1.shape, y2.shape) 168 | -------------------------------------------------------------------------------- /python/aitabs.py: -------------------------------------------------------------------------------- 1 | import torchaudio as ta 2 | import torch as th 3 | import torch.nn.functional as F 4 | from librosa.feature import tempo 5 | 6 | from audio import read_wav, write_wav, gen_wav 7 | from utils import build_masked_stft, get_chord_name, get_segment_name, get_lyrics 8 | from spec import istft, get_spec, get_specs, get_mixed_spec 9 | from modulation import search_key 10 | from models import get_model 11 | 12 | 13 | class AITabTranscription(object): 14 | def __init__(self, config): 15 | self.config = config 16 | self.n_channel = self.config['n_channel'] 17 | self.sources = self.config['sources'] 18 | self.sample_rate = self.config['sample_rate'] 19 | self.sep_config = self.config['separate'] 20 | self.lyrics_cfg = self.config['lyrics'] 21 | self.beat_cfg = self.config['beat'] 22 | self.chord_cfg = self.config['chord'] 23 | self.segment_cfg = self.config['segment'] 24 | self.pitch_cfg = self.config['pitch'] 25 | self.spec_cfg = self.config['spec'] 26 | self.tempo_cfg = self.config['tempo'] 27 | 28 | def separate(self, waveform, sample_rate, device='cpu'): 29 | assert sample_rate == self.sample_rate 30 | wav_len = waveform.shape[-1] 31 | 32 | model_config = self.sep_config['model'] 33 | spec_config = self.sep_config['spec'] 34 | n_fft = self.sep_config['spec']['n_fft'] 35 | hop_length = self.sep_config['spec']['hop_length'] 36 | n_time = self.sep_config['spec']['n_time'] 37 | 38 | _model_cfg = { 39 | 'sources': self.sources, 40 | 'n_channel': self.n_channel, 41 | } 42 | _model_cfg.update(model_config) 43 | unet = get_model(self.sep_config['model_name'], _model_cfg, model_path=self.sep_config['model_path'], 44 | is_train=False, device=device) 45 | 46 | split_len = (n_time - 5) * hop_length + n_fft 47 | 48 | output_waveforms = [[] for _ in range(self.sources)] 49 | for i in range(0, wav_len, split_len): 50 | with th.no_grad(): 51 | x = waveform[:, i:i + split_len] 52 | pad_num = 0 53 | if x.shape[-1] < split_len: 54 | pad_num = split_len - (wav_len - i) 55 | x = F.pad(x, (0, pad_num)) 56 | 57 | # separator 58 | z = get_spec(x, spec_config) 59 | mag_z = th.abs(z).unsqueeze(0) 60 | masks = unet(mag_z) 61 | masks = masks.squeeze(0) 62 | _masked_stfts = build_masked_stft(masks, z, n_fft=n_fft) 63 | # build waveform 64 | for j, _masked_stft in enumerate(_masked_stfts): 65 | _waveform = istft(_masked_stft, n_fft=n_fft, hop_length=hop_length, pad=True) 66 | if pad_num > 0: 67 | _waveform = _waveform[:, :-pad_num] 68 | output_waveforms[j].append(_waveform) 69 | 70 | inst_waveforms = [] 71 | for waveform_list in 
output_waveforms: 72 | inst_waveforms.append(th.cat(waveform_list, dim=-1)) 73 | return th.stack(inst_waveforms, dim=0) 74 | 75 | def transcribe(self, wav_fp, device='cpu'): 76 | 77 | waveform, sample_rate = read_wav(wav_fp, sample_rate=self.sample_rate, n_channel=self.n_channel, device=device) 78 | # print(waveform.shape, sample_rate) 79 | 80 | inst_waveforms = self.separate(waveform, sample_rate) 81 | # print(inst_waveforms.shape) 82 | 83 | # laod model 84 | beat_net = get_model(self.beat_cfg['model_name'], self.beat_cfg['model'], 85 | model_path=self.beat_cfg['model_path'], is_train=False, device=device) 86 | chord_net = get_model(self.chord_cfg['model_name'], self.chord_cfg['model'], 87 | model_path=self.chord_cfg['model_path'], is_train=False, device=device) 88 | segment_net = get_model(self.segment_cfg['model_name'], self.segment_cfg['model'], 89 | model_path=self.segment_cfg['model_path'], is_train=False, device=device) 90 | pitch_net = get_model(self.pitch_cfg['model_name'], self.pitch_cfg['model'], 91 | model_path=self.pitch_cfg['model_path'], is_train=False, device=device) 92 | 93 | vocal_waveform = inst_waveforms[0].numpy() 94 | orig_spec = get_spec(waveform, self.spec_cfg) 95 | inst_specs = get_specs(inst_waveforms, self.spec_cfg) # vocal, bass, drum, other 96 | vocal_spec = get_spec(inst_waveforms[0], self.spec_cfg) # vocal 97 | other_spec = get_mixed_spec(inst_waveforms[1:], self.spec_cfg) # bass + drum + other 98 | 99 | # pred lyrics 100 | lyrics, lyrics_matrix = get_lyrics(vocal_waveform, sample_rate, self.lyrics_cfg) 101 | 102 | with th.no_grad(): 103 | # pred beat 104 | beat_features = inst_specs[:, :, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0) # B, S, C, T, F 105 | beat_features_mag = th.abs(beat_features) 106 | 107 | beat_pred, beat_logist = beat_net(beat_features_mag) 108 | print('beat info', beat_pred.shape, beat_logist.shape) 109 | 110 | # pred chord 111 | chord_features = other_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0) 112 | chord_features_mag = th.abs(chord_features) 113 | 114 | chord_pred, chord_logist = chord_net(chord_features_mag, beat_logist) 115 | print('chord info', chord_pred.shape, chord_logist.shape) 116 | 117 | # pred segment 118 | segment_features = orig_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0) 119 | segment_features_mag = th.abs(segment_features) 120 | segment_pred = segment_net(segment_features_mag, chord_logist) 121 | print('segment info', segment_pred.shape) 122 | 123 | # pred pitch 124 | pitch_features = vocal_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0) 125 | pitch_features_mag = th.abs(pitch_features) 126 | pitch_pred, pitch_logist = pitch_net(pitch_features_mag, lyrics_matrix) 127 | print('pitch info', pitch_pred.shape, pitch_logist.shape) 128 | 129 | beats = beat_pred.squeeze(0).numpy() 130 | bpm = tempo(onset_envelope=beats, hop_length=self.tempo_cfg['hop_length']).tolist() 131 | chord_pred = chord_pred.squeeze(0) 132 | chords = get_chord_name(chord_pred) 133 | song_key = search_key(chords) 134 | segment_pred = segment_pred.squeeze(0) 135 | segment = get_segment_name(segment_pred) 136 | beats = beats.tolist() 137 | pitch_list = pitch_pred.squeeze(0).tolist() 138 | 139 | ret = { 140 | 'bpm': bpm, 141 | 'key': song_key, 142 | 'chords': chords, 143 | 'beat': beats, 144 | 'segment': segment, 145 | 'pitch': pitch_list, 146 | 'lyrics': lyrics, 147 | } 148 | return ret, inst_waveforms 149 | 150 | -------------------------------------------------------------------------------- /python/models/transformers.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def positional_encoding(batch_size, n_time, n_feature, zero_pad=False, scale=False, dtype=torch.float32): 7 | pos_indices = torch.tile(torch.unsqueeze(torch.arange(n_time), 0), [batch_size, 1]) 8 | 9 | pos = torch.arange(n_time, dtype=dtype).reshape(-1, 1) 10 | pos_enc = pos / torch.pow(10000, 2 * torch.arange(0, n_feature, dtype=dtype) / n_feature) 11 | pos_enc[:, 0::2] = torch.sin(pos_enc[:, 0::2]) 12 | pos_enc[:, 1::2] = torch.cos(pos_enc[:, 1::2]) 13 | 14 | if zero_pad: 15 | pos_enc = torch.cat([torch.zeros(size=[1, n_feature]), pos_enc[1:, :]], 0) 16 | 17 | outputs = F.embedding(pos_indices, pos_enc) 18 | 19 | if scale: 20 | outputs = outputs * (n_feature ** 0.5) 21 | 22 | return outputs 23 | 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, n_feature=2048, n_hidden=512, dropout=0.5): 27 | super().__init__() 28 | self.linear1 = nn.Linear(n_feature, n_hidden) 29 | self.linear2 = nn.Linear(n_hidden, n_feature) 30 | 31 | self.dropout1 = nn.Dropout(dropout) 32 | self.dropout2 = nn.Dropout(dropout) 33 | self.norm_layer = nn.LayerNorm(n_feature) 34 | 35 | def forward(self, x): 36 | y = self.linear1(x) 37 | y = F.relu(y) 38 | y = self.dropout1(y) 39 | y = self.linear2(y) 40 | y = self.dropout2(y) 41 | y = self.norm_layer(y) 42 | return y 43 | 44 | 45 | class EncoderFre(nn.Module): 46 | def __init__(self, n_freq, n_group, nhead=8, n_layers=5, dropout=0.5, pr=0.01): 47 | super().__init__() 48 | assert n_freq % n_group == 0 49 | self.d_model = d_model = n_freq // n_group 50 | self.n_freq = n_freq 51 | self.n_group = n_group 52 | self.pr = pr 53 | self.n_layers = n_layers 54 | 55 | self.attn_layer = nn.ModuleList() 56 | self.ff_layer = nn.ModuleList() 57 | 58 | for _ in range(self.n_layers): 59 | _attn_layer = nn.MultiheadAttention(d_model, nhead, batch_first=True) 60 | _ff_layer = FeedForward(d_model, dropout=dropout) 61 | 62 | self.attn_layer.append(_attn_layer) 63 | self.ff_layer.append(_ff_layer) 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | self.fc = nn.Linear(n_freq, n_freq) 67 | self.norm_layer = nn.LayerNorm(n_freq) 68 | 69 | def forward(self, x): 70 | # x: batch, n_time, n_freq 71 | B, T, Fr = x.shape 72 | x_reshape = x.reshape(B * T, self.n_group, self.d_model) 73 | x_reshape += positional_encoding( 74 | batch_size=x_reshape.shape[0], n_time=x_reshape.shape[1], n_feature=x_reshape.shape[2] 75 | ) * self.pr 76 | 77 | for _attn_layer, _ff_layer in zip(self.attn_layer, self.ff_layer): 78 | x_reshape, _ = _attn_layer(x_reshape, x_reshape, x_reshape, need_weights=False) 79 | x_reshape = _ff_layer(x_reshape) 80 | 81 | y = x_reshape.reshape(B, T, self.n_freq) 82 | 83 | y = self.dropout(y) 84 | y = self.fc(y) 85 | y = self.norm_layer(y) 86 | return y 87 | 88 | 89 | class EncoderTime(nn.Module): 90 | def __init__(self, n_freq, nhead=8, n_layers=5, dropout=0.5, pr=0.02): 91 | super().__init__() 92 | self.n_freq = n_freq 93 | self.n_layers = n_layers 94 | self.pr = pr 95 | 96 | self.attn_layer = nn.ModuleList() 97 | self.ff_layer = nn.ModuleList() 98 | for _ in range(self.n_layers): 99 | _attn_layer = nn.MultiheadAttention(n_freq, nhead, batch_first=True) 100 | _ff_layer = FeedForward(n_freq, dropout=dropout) 101 | 102 | self.attn_layer.append(_attn_layer) 103 | self.ff_layer.append(_ff_layer) 104 | 105 | self.norm_layer = nn.LayerNorm(n_freq) 106 | self.dropout = nn.Dropout(dropout) 107 | self.fc = 
nn.Linear(n_freq, n_freq) 108 | 109 | def forward(self, x): 110 | # x: batch, n_time, n_freq 111 | B, T, Fr = x.shape 112 | x += positional_encoding(B, T, Fr) * self.pr 113 | 114 | for _attn_layer, _ff_layer in zip(self.attn_layer, self.ff_layer): 115 | x, _ = _attn_layer(x, x, x, need_weights=False) 116 | x = _ff_layer(x) 117 | 118 | x = self.dropout(x) 119 | x = self.fc(x) 120 | x = self.norm_layer(x) 121 | return x 122 | 123 | 124 | class Decoder(nn.Module): 125 | 126 | def __init__(self, d_model=512, nhead=8, n_layers=5, dropout=0.5, r1=1.0, r2=1.0, wr=1.0, pr=0.01): 127 | super().__init__() 128 | 129 | self.r1 = r1 130 | self.r2 = r2 131 | self.wr = wr 132 | self.n_layers = n_layers 133 | self.pr = pr 134 | 135 | self.attn1_layer = nn.ModuleList() 136 | self.attn2_layer = nn.ModuleList() 137 | self.ff_layer = nn.ModuleList() 138 | for _ in range(n_layers): 139 | _layer1 = nn.MultiheadAttention(d_model, nhead, batch_first=True) 140 | _layer2 = nn.MultiheadAttention(d_model, nhead, batch_first=True) 141 | _layer3 = FeedForward(d_model, dropout=dropout) 142 | self.attn1_layer.append(_layer1) 143 | self.attn2_layer.append(_layer2) 144 | self.ff_layer.append(_layer3) 145 | self.dropout = nn.Dropout(dropout) 146 | self.fc = nn.Linear(d_model, d_model) 147 | 148 | def forward(self, x1, x2, weight=None): 149 | y = x1 * self.r1 + x2 * self.r2 150 | if weight is not None: 151 | y += weight * self.wr 152 | 153 | y += positional_encoding(y.shape[0], y.shape[1], y.shape[2]) * self.pr + self.pr 154 | 155 | for i in range(self.n_layers): 156 | _attn1_layer = self.attn1_layer[i] 157 | _attn2_layer = self.attn2_layer[i] 158 | _ff_layer = self.ff_layer[i] 159 | y, _ = _attn1_layer(y, y, y, need_weights=False) 160 | y, _ = _attn2_layer(y, x2, x2, need_weights=False) 161 | y = _ff_layer(y) 162 | output = self.dropout(y) 163 | output = self.fc(output) 164 | return output, y 165 | 166 | 167 | class BaseTransformer(nn.Module): 168 | def __init__(self, 169 | n_channel=2, 170 | n_freq=2048, 171 | # EncoderFre 172 | n_group=32, 173 | f_layers=2, 174 | f_nhead=4, 175 | f_dropout=0.5, 176 | f_pr=0.01, 177 | # EncoderTime 178 | t_layers=2, 179 | t_nhead=4, 180 | t_dropout=0.5, 181 | t_pr=0.01, 182 | # Decoder 183 | d_layers=2, 184 | d_nhead=4, 185 | d_dropout=0.5, 186 | d_pr=0.02, 187 | r1=1.0, 188 | r2=1.0, 189 | wr=0.2, 190 | ): 191 | super().__init__() 192 | self.n_channel = n_channel 193 | self.encoder_fre_layers = nn.ModuleList() 194 | self.encoder_time_layers = nn.ModuleList() 195 | for _ in range(n_channel): 196 | _encoder_fre_layer = EncoderFre(n_freq=n_freq, n_group=n_group, nhead=f_nhead, n_layers=f_layers, 197 | dropout=f_dropout, pr=f_pr) 198 | _encoder_time_layer = EncoderTime(n_freq=n_freq, nhead=t_nhead, n_layers=t_layers, dropout=t_dropout, 199 | pr=t_pr) 200 | self.encoder_fre_layers.append(_encoder_fre_layer) 201 | self.encoder_time_layers.append(_encoder_time_layer) 202 | 203 | self.decoder = Decoder(d_model=n_freq, nhead=d_nhead, n_layers=d_layers, dropout=d_dropout, 204 | r1=r1, r2=r2, wr=wr, pr=d_pr) 205 | 206 | def forward(self, x, weight=None): 207 | # x: batch, channel, n_time, n_freq 208 | ff_list = [] 209 | tf_list = [] 210 | for i in range(self.n_channel): 211 | x1 = self.encoder_fre_layers[i](x[:, i, :, :]) 212 | x2 = self.encoder_time_layers[i](x[:, i, :, :]) 213 | ff_list.append(x1) 214 | tf_list.append(x2) 215 | 216 | y1 = torch.sum(torch.stack(ff_list, dim=0), dim=0) 217 | y2 = torch.sum(torch.stack(tf_list, dim=0), dim=0) 218 | y, w = self.decoder(y1, y2, weight=weight) 219 | 
return y, w 220 | 221 | 222 | if __name__ == '__main__': 223 | net = BaseTransformer(n_freq=2048) 224 | # net = EncoderTime(2048) 225 | # net = EncoderFre(2048, 32) 226 | 227 | # net = EncoderFre(2048, 8) 228 | x = torch.randn(1, 2, 1024, 2048) 229 | 230 | y, logits = net(x) 231 | y1, logits = net(x) 232 | 233 | # print(np.allclose(x, y)) 234 | print(y.shape, logits.shape) 235 | print(y[0, 0, 0], y[0, 1, 0]) 236 | print(y1[0, 0, 0], y1[0, 1, 0]) 237 | -------------------------------------------------------------------------------- /python/modulation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from itertools import product 3 | 4 | from common import MAJOR_CHORDS, MINOR_CHORDS, MAJ2MIN_MAP 5 | 6 | 7 | def get_chord_tonic(chord): 8 | chord = chord.replace('Dim', 'dim').replace('Aug', 'aug') 9 | if len(chord) >= 3: 10 | if chord[1] in ["#", 'b']: 11 | if 'maj7' in chord: 12 | return chord[:2] 13 | # elif 'dim' in chord: 14 | # return chord[:2] + 'm' 15 | elif chord[2] in ['m']: 16 | return chord[:3] 17 | elif 'sus2' in chord or 'sus4' in chord or 'dim' in chord or 'aug' in chord: 18 | return chord 19 | else: 20 | return chord[:2] 21 | elif chord[1] == "5": 22 | return chord[:2] 23 | # elif 'dim' in chord: 24 | # return chord[:1] + 'm' 25 | elif 'sus2' in chord or 'sus4' in chord or 'dim' in chord or 'aug' in chord: 26 | return chord 27 | elif 'maj7' in chord: 28 | return chord.replace('maj7', '') 29 | elif '7M' in chord: 30 | return chord.replace('7M', '') 31 | elif chord[1] in ['m']: 32 | return chord[:2] 33 | else: 34 | return chord[:1] 35 | elif len(chord) == 2: 36 | if chord[1] in ['#', 'b', 'm', '5']: 37 | return chord[:2] 38 | else: 39 | return chord[:1] 40 | else: 41 | return chord[:1] 42 | 43 | 44 | def chord_name_filter(chords): 45 | search_name_list = [] 46 | for i in chords: 47 | if not i: 48 | continue 49 | search_name_list.append(get_chord_tonic(i)) 50 | return search_name_list 51 | 52 | 53 | def chord_to_tone(chord): 54 | tone = chord[0] 55 | if len(chord) > 2: 56 | if chord[1] in ['b', '#']: 57 | tone += chord[1] 58 | return tone 59 | 60 | 61 | def get_chords_repeat_num3(chords): 62 | chords_score_map = {} 63 | for ex_chords in MAJOR_CHORDS: 64 | chords_score_map[ex_chords[0]] = len([i for i in chords if i in ex_chords]) 65 | return chords_score_map 66 | 67 | 68 | def process_sus_chords(filter_total_chords, sus_chords): 69 | tone_count_dic = defaultdict(int) 70 | for chord in sus_chords: 71 | tone_count_dic[chord] += 1 72 | keys = list(tone_count_dic) 73 | product_list = list(product(*([[0, 1]] * len(keys)))) 74 | chords_score_list = [] 75 | for item in product_list: 76 | sus_conv_chord = [] 77 | for i, x in enumerate(item): 78 | key = keys[i] 79 | tone = chord_to_tone(key) 80 | if x == 0: 81 | sus_conv_chord.extend([tone] * tone_count_dic[key]) 82 | else: 83 | sus_conv_chord.extend([tone + 'm'] * tone_count_dic[key]) 84 | chords_score_map = get_chords_repeat_num3(filter_total_chords + sus_conv_chord) 85 | chords_score_list.extend([(k, i, item) for k, i in chords_score_map.items() if i != 0]) 86 | chords_score_list = sorted(chords_score_list, key=lambda x: x[1], reverse=True) 87 | return chords_score_list, keys 88 | 89 | 90 | def get_ret_key(best_key, chord, keys, maj_min_dic, is_sus, ret_key, sus_dic): 91 | # print(best_key, "best_key", chord) 92 | for k, i, item in best_key: 93 | if 'sus' in chord or 'dim' in chord or 'aug' in chord or '5' in chord: 94 | tone = chord_to_tone(chord) 95 | if 
item[keys.index(chord)] == 1: 96 | tone += 'm' 97 | chord = tone 98 | if chord == maj_min_dic[k][0] or chord == maj_min_dic[k][5]: 99 | ret_key = k 100 | if is_sus: 101 | sus_dic = {keys[i]: chord_to_tone(keys[i]) if v == 0 else chord_to_tone(keys[i]) + 'm' for i, v in 102 | enumerate(item)} 103 | break 104 | return ret_key, sus_dic 105 | 106 | 107 | def get_ret_key2(best_key, head_chord, tail_chord, maj_min_dic, ret_key): 108 | for k, i in best_key: 109 | if head_chord == maj_min_dic[k][0] or head_chord == maj_min_dic[k][5]: 110 | ret_key = k 111 | break 112 | if not ret_key: 113 | for k, i in best_key: 114 | if tail_chord == maj_min_dic[k][0] or tail_chord == maj_min_dic[k][5]: 115 | ret_key = k 116 | break 117 | return ret_key 118 | 119 | 120 | def process_best_key(chords_score_list, head_chord, tail_chord, keys, maj_min_dic, ret_key, sus_dic, is_sus=True): 121 | # determine the key 122 | if not chords_score_list: 123 | return None, None 124 | if is_sus: 125 | best_key = [x for x in chords_score_list if x[1] == chords_score_list[0][1]] 126 | else: 127 | best_key = [x for x in chords_score_list if x[-1] == chords_score_list[0][-1]] 128 | if len(best_key) > 1 and is_sus: 129 | ret_key, sus_dic = get_ret_key(best_key, head_chord, keys, maj_min_dic, is_sus, ret_key, sus_dic) 130 | if not ret_key: 131 | ret_key, sus_dic = get_ret_key(best_key, tail_chord, keys, maj_min_dic, is_sus, ret_key, sus_dic) 132 | elif len(best_key) > 1 and not is_sus: 133 | ret_key = get_ret_key2(best_key, head_chord, tail_chord, maj_min_dic, ret_key) 134 | # else: 135 | # return None, None 136 | if not ret_key: 137 | if len(best_key[0]) == 2: 138 | ret_key = best_key[0][0] 139 | else: 140 | ret_key = best_key[0][0] 141 | sus_dic = {keys[i]: chord_to_tone(keys[i]) if v == 0 else chord_to_tone(keys[i]) + 'm' for i, v in 142 | enumerate(best_key[0][2])} 143 | return ret_key, sus_dic 144 | 145 | 146 | def get_maj_or_min_key(ret_key, maj_min_dic, head_chord, tail_chord, sus_dic): 147 | # check maj or min 148 | maj_key = ret_key 149 | min_key = MAJ2MIN_MAP[ret_key] 150 | 151 | maj_chords = maj_min_dic[maj_key] 152 | 153 | if 'sus' in head_chord or 'dim' in head_chord or 'aug' in head_chord or '5' in head_chord: 154 | head_chord = sus_dic[head_chord] 155 | 156 | if head_chord == maj_chords[0]: 157 | ret_key = maj_key 158 | elif head_chord == maj_chords[5]: 159 | ret_key = min_key 160 | else: 161 | if 'sus' in tail_chord or 'dim' in tail_chord or 'aug' in tail_chord or '5' in tail_chord: 162 | tail_chord = sus_dic[tail_chord] 163 | 164 | if tail_chord == maj_chords[0]: 165 | ret_key = maj_key 166 | elif tail_chord == maj_chords[5]: 167 | ret_key = min_key 168 | else: 169 | ret_key = maj_key 170 | return ret_key 171 | 172 | 173 | def search_key(total_chords): 174 | if not total_chords: 175 | return "C", 0 176 | maj_min_dic = {x[0]: x for x in MAJOR_CHORDS + MINOR_CHORDS} 177 | ret_key = None 178 | keys = [] 179 | sus_dic = {} 180 | total_chords = chord_name_filter(total_chords) 181 | head_chord = total_chords[0] 182 | tail_chord = total_chords[-1] 183 | 184 | sus_chords = [x for x in total_chords if 'sus' in x or 'dim' in x or 'aug' in x or '5' in x] 185 | filter_total_chords = [x for x in total_chords if 186 | 'sus' not in x and 'dim' not in x and 'aug' not in x and '5' not in x] 187 | if not filter_total_chords and not sus_chords: 188 | return "C", 0 189 | 190 | if sus_chords: 191 | chords_score_list, keys = process_sus_chords(filter_total_chords, sus_chords) 192 | 193 | ret_key, sus_dic = process_best_key(chords_score_list,
head_chord, tail_chord, keys, maj_min_dic, ret_key, 194 | sus_dic) 195 | else: 196 | chords_score_map = get_chords_repeat_num3(filter_total_chords) 197 | chords_score_list = [(k, i) for k, i in chords_score_map.items() if i != 0] 198 | chords_score_list = sorted(chords_score_list, key=lambda x: x[1], reverse=True) 199 | print(chords_score_list) 200 | ret_key, sus_dic = process_best_key(chords_score_list, head_chord, tail_chord, keys, maj_min_dic, ret_key, 201 | sus_dic, is_sus=False) 202 | key_score = 0 203 | for i in chords_score_list: 204 | if i[0] == ret_key: 205 | key_score = i[1] 206 | if not ret_key: 207 | return "C", 0 208 | ret_key = get_maj_or_min_key(ret_key, maj_min_dic, head_chord, tail_chord, sus_dic) 209 | 210 | return ret_key, key_score 211 | 212 | 213 | def chord_convert(chord): 214 | chord_convert_map = { 215 | "C": "Db", 216 | "D": "Eb", 217 | "E": "F", 218 | "F": "Gb", 219 | "G": "Ab", 220 | "A": "Bb", 221 | "B": "C" 222 | } 223 | if "#" in chord: 224 | a, b = chord.split('#') 225 | new_chord = chord_convert_map[a] + b 226 | else: 227 | new_chord = chord 228 | return new_chord 229 | 230 | 231 | if __name__ == '__main__': 232 | demo_chords = ['D', 'D', 'Bm7', 'Bm7', 'Asus4', 'A', 'G', 'G', 'D', 'D', 'D', 'Bm7', 'A', 'Asus4', 'A', 'G', 'D', 233 | 'D', 'D', 'D', 'Bm7', 'A', 'Asus4', 'A', 'G', 'D', 'E', 'D', 'D', 'Bm', 'Bm7', 'A', 'A', 'G', 'G', 234 | 'D', 'D', 'Bm7', 'Bm', 'Bm7', 'A', 'A', 'D', 'G', 'G', 'Bm7', 'Bm', 'A', 'G', 'G', 'A', 'D', 'D', 235 | 'D', 'D', 'A', 'Asus4', 'A', 'Asus4', 'Em7', 'G', 'D', 'D', 'D', 'D', 'D', 'A', 'A7', 'A7', 'G', 'G', 236 | 'D', 'D', 'Bm7', 'Bm7', 'A', 'A', 'G', 'G', 'D', 'D', 'Bm', 'Bm7', 'A', 'Asus4', 'A', 'D', 'G', 'G', 237 | 'Bm', 'A', 'G', 'G', 'A', 'G', 'G', 'Abm', 'Dbm', 'B', 'A', 'Abm7', 'Ab7', 'A', 'Gbm7', 'E', 'Gbm7', 238 | 'Ab7', 'Dbm7', 'B', 'A', 'E', 'Ab7', 'A', 'E', 'Gbm7', 'Ab7', 'Dbm', 'B', 'A', 'A', 'B', 'E', 'B', 239 | 'Abm', 'A', 'E', 'B', 'Abm7', 'Abm', 'A', 'F', 'C', 'A', 'Bb', 'F', 'C', 'A', 'F', 'F', 'Bb', 'Bb', 240 | 'F', 'F', 'Dm7', 'D', 'Dm7', 'C', 'C', 'Bb', 'Bb', 'F', 'F', 'C', 'Dm7', 'D', 'Dm7', 'F', 'Gsus4', 241 | 'C', 'Bb', 'Bb', 'Bb', 'Bb', 'F', 'F'] 242 | demo_chords1 = [] 243 | for i in demo_chords: 244 | demo_chords1.append(chord_convert(i)) 245 | print(search_key(demo_chords)) 246 | print(search_key(demo_chords1)) 247 | -------------------------------------------------------------------------------- /cover/models/pitch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from utils.spec import stft 7 | 8 | 9 | class ResConvBlock(nn.Module): 10 | def __init__(self, 11 | in_channels, 12 | out_channels, 13 | kernel_size=(3, 3), 14 | stride=(1, 1), 15 | padding=(1, 1), 16 | momentum=0.01, 17 | bias=True): 18 | super().__init__() 19 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias) 20 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum) 21 | self.relu1 = nn.ReLU() 22 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=bias) 23 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum) 24 | self.relu2 = nn.ReLU() 25 | 26 | if in_channels != out_channels or stride != (1, 1): 27 | self.residual_connection = nn.Conv2d(in_channels, 28 | out_channels, 29 | kernel_size=(1, 1), 30 | stride=stride, 31 | padding=(0, 0), 32 | bias=bias) 33 | else: 34 | self.residual_connection = nn.Identity() 35 
| 36 | def forward(self, x): 37 | identity = self.residual_connection(x) 38 | x = self.conv1(x) 39 | x = self.bn1(x) 40 | x = self.relu1(x) 41 | x = self.conv2(x) 42 | x = self.bn2(x) 43 | x += identity 44 | x = self.relu2(x) 45 | return x 46 | 47 | 48 | class ResEncoderBlock(nn.Module): 49 | def __init__(self, in_channels, out_channels, n_blocks=1, momentum=0.01): 50 | super().__init__() 51 | self.n_blocks = n_blocks 52 | self.res_conv_layers = nn.ModuleList() 53 | for i in range(self.n_blocks): 54 | if i == 0: 55 | self.res_conv_layers.append(ResConvBlock(in_channels, out_channels, momentum=momentum)) 56 | else: 57 | self.res_conv_layers.append(ResConvBlock(out_channels, out_channels, momentum=momentum)) 58 | 59 | def forward(self, x): 60 | for i in range(self.n_blocks): 61 | x = self.res_conv_layers[i](x) 62 | return x 63 | 64 | 65 | class Encoder(nn.Module): 66 | def __init__(self, 67 | in_channels, 68 | out_channels, 69 | n_layers, 70 | pool_size, 71 | n_blocks=1, 72 | momentum=0.01): 73 | super().__init__() 74 | self.n_layers = n_layers 75 | 76 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) 77 | self.layers = nn.ModuleList() 78 | self.pool_layers = nn.ModuleList() 79 | for i in range(self.n_layers): 80 | self.layers.append(ResEncoderBlock(in_channels, out_channels, n_blocks=n_blocks, momentum=momentum)) 81 | self.pool_layers.append(nn.AvgPool2d(kernel_size=pool_size)) 82 | in_channels = out_channels 83 | out_channels *= 2 84 | self.out_channels = out_channels 85 | 86 | def forward(self, x): 87 | x = self.bn(x) 88 | h = [] 89 | for i in range(self.n_layers): 90 | t = self.layers[i](x) 91 | x = self.pool_layers[i](t) 92 | h.append(t) 93 | return x, h 94 | 95 | 96 | class Enhancer(nn.Module): 97 | def __init__(self, in_channels, out_channels, n_layers, n_blocks, momentum=0.01): 98 | super().__init__() 99 | 100 | self.n_layers = n_layers 101 | self.layers = nn.ModuleList() 102 | for i in range(self.n_layers): 103 | if i == 0: 104 | self.layers.append(ResEncoderBlock(in_channels, out_channels, n_blocks=n_blocks, momentum=momentum)) 105 | else: 106 | self.layers.append(ResEncoderBlock(out_channels, out_channels, n_blocks=n_blocks, momentum=momentum)) 107 | 108 | def forward(self, x): 109 | for i in range(self.n_layers): 110 | x = self.layers[i](x) 111 | return x 112 | 113 | 114 | class ResDecoderBlock(nn.Module): 115 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): 116 | super().__init__() 117 | self.n_blocks = n_blocks 118 | self.conv = nn.Sequential( 119 | nn.ConvTranspose2d( 120 | in_channels=in_channels, 121 | out_channels=out_channels, 122 | kernel_size=(3, 3), 123 | stride=stride, 124 | padding=(1, 1), 125 | output_padding=(0, 1) if stride == (1, 2) else (1, 1), 126 | bias=False, 127 | ), 128 | nn.BatchNorm2d(out_channels, momentum=momentum), 129 | nn.ReLU() 130 | ) 131 | self.res_conv_layers = nn.ModuleList() 132 | for i in range(self.n_blocks): 133 | if i == 0: 134 | self.res_conv_layers.append(ResConvBlock(out_channels * 2, out_channels, momentum=momentum)) 135 | else: 136 | self.res_conv_layers.append(ResConvBlock(out_channels, out_channels, momentum=momentum)) 137 | 138 | def forward(self, x, h): 139 | x = self.conv(x) 140 | x = torch.cat((x, h), dim=1) 141 | for i in range(self.n_blocks): 142 | x = self.res_conv_layers[i](x) 143 | return x 144 | 145 | 146 | class Decoder(nn.Module): 147 | def __init__(self, 148 | in_channels, 149 | n_layers, 150 | stride, 151 | n_blocks, 152 | momentum=0.01): 153 | super().__init__() 154 | 
self.n_layers = n_layers 155 | self.layers = nn.ModuleList() 156 | for i in range(self.n_layers): 157 | self.layers.append( 158 | ResDecoderBlock(in_channels, in_channels // 2, stride, n_blocks=n_blocks, momentum=momentum)) 159 | in_channels //= 2 160 | 161 | def forward(self, x, h): 162 | for i in range(self.n_layers): 163 | x = self.layers[i](x, h[-i - 1]) 164 | return x 165 | 166 | 167 | class BiGRU(nn.Module): 168 | def __init__(self, input_size, hidden_size, num_layers, dropout=0.25): 169 | super().__init__() 170 | self.gru = nn.GRU(input_size, 171 | hidden_size, 172 | num_layers=num_layers, 173 | batch_first=True, 174 | bidirectional=True, 175 | dropout=dropout) 176 | 177 | def forward(self, x): 178 | x, _ = self.gru(x) 179 | return x 180 | 181 | 182 | class MelSpectrogram(nn.Module): 183 | def __init__(self, 184 | nfft, 185 | n_mels, 186 | hop_length = None, 187 | mel_f_min=30.0, 188 | mel_f_max=8000.0, 189 | samplerate=44100): 190 | super().__init__() 191 | self.samplerate = samplerate 192 | self.nfft = nfft 193 | self.n_stft = nfft // 2 + 1 194 | self.hop_length = nfft // 4 if hop_length is None else hop_length 195 | self.n_mels = n_mels 196 | self.mel_scale = torchaudio.transforms.MelScale( 197 | n_mels=n_mels, sample_rate=samplerate, n_stft=self.n_stft, f_min=mel_f_min, f_max=mel_f_max) 198 | 199 | def _stft(self, x): 200 | hl = self.hop_length 201 | nfft = self.nfft 202 | return stft(x, fft_size=nfft, hop_length=hl) 203 | 204 | def _mel(self, stft): 205 | magnitude = stft.abs().pow(2) 206 | mel = self.mel_scale(magnitude) 207 | return mel 208 | 209 | def forward(self, x): 210 | stft = self._stft(x) 211 | mel = self._mel(stft) 212 | return mel 213 | 214 | 215 | class PitchNet(nn.Module): 216 | 217 | def __init__(self, 218 | nfft, 219 | mel_f_min=30.0, 220 | mel_f_max=8000.0, 221 | hop_length=None, 222 | samplerate=16000, 223 | en_layers=5, 224 | re_layers=4, 225 | de_layers=5, 226 | n_blocks=5, 227 | in_channels=1, 228 | en_out_channels=16, 229 | pool_size=(2, 2), 230 | n_gru=1, 231 | n_mels=128, 232 | n_classes=360): 233 | super().__init__() 234 | 235 | self.mel_spectrogram = MelSpectrogram(n_mels=n_mels, 236 | samplerate=samplerate, 237 | hop_length=hop_length, 238 | nfft=nfft, 239 | mel_f_min=mel_f_min, 240 | mel_f_max=mel_f_max) 241 | self.encoder = Encoder(in_channels, en_out_channels, en_layers, n_blocks=n_blocks, pool_size=pool_size) 242 | self.enhancer = Enhancer(self.encoder.out_channels // 2, self.encoder.out_channels, re_layers, 243 | n_blocks=n_blocks) 244 | self.decoder = Decoder(self.encoder.out_channels, de_layers, stride=pool_size, n_blocks=n_blocks) 245 | 246 | self.conv = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 247 | if n_gru > 0: 248 | self.fc = nn.Sequential( 249 | BiGRU(3 * n_mels, 256, num_layers=n_gru), 250 | nn.Linear(256 * 2, n_classes), 251 | nn.Dropout(0.25), 252 | nn.Sigmoid() 253 | ) 254 | else: 255 | self.fc = nn.Sequential( 256 | nn.Linear(3 * n_mels, n_classes), 257 | nn.Dropout(0.25), 258 | nn.Sigmoid() 259 | ) 260 | 261 | def forward(self, x): 262 | x = self.mel_spectrogram(x) 263 | n_frames = x.size(-1) 264 | n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames 265 | if n_pad > 0: 266 | x = F.pad(x, (0, n_pad), mode="constant") 267 | x = x.transpose(-1, -2).unsqueeze(1) 268 | x, h = self.encoder(x) 269 | x = self.enhancer(x) 270 | x = self.decoder(x, h) 271 | x = self.conv(x) 272 | x = x.transpose(1, 2).flatten(-2) 273 | x = self.fc(x) 274 | x = x[:, :n_frames] 275 | return x 276 | 277 | 
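# --- Editor's note: hedged usage sketch, not part of the original module. ---
# It assumes `utils.spec.stft` returns a complex spectrogram with nfft // 2 + 1
# frequency bins (compatible with torchaudio's MelScale); the nfft, hop_length
# and dummy input length below are illustrative values only.
if __name__ == '__main__':
    net = PitchNet(nfft=1024, hop_length=256, samplerate=16000)
    net.eval()
    dummy_wav = torch.randn(1, 16000 * 2)  # ~2 seconds of mono audio at 16 kHz
    with torch.no_grad():
        salience = net(dummy_wav)  # expected shape: (1, n_frames, n_classes=360)
    print(salience.shape)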
-------------------------------------------------------------------------------- /cover/models/generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import torchaudio.transforms as transforms 5 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 6 | 7 | from utils.spec import stft, istft, chroma 8 | from utils.noise import gaussian_noise, markov_noise, random_walk_noise, spectral_folding 9 | 10 | 11 | def positional_encoding(batch_size, n_time, n_feature, zero_pad=False, scale=False, dtype=torch.float32): 12 | pos_indices = torch.tile(torch.unsqueeze(torch.arange(n_time), 0), [batch_size, 1]) 13 | 14 | pos = torch.arange(n_time, dtype=dtype).reshape(-1, 1) 15 | pos_enc = pos / torch.pow(10000, 2 * torch.arange(0, n_feature, dtype=dtype) / n_feature) 16 | pos_enc[:, 0::2] = torch.sin(pos_enc[:, 0::2]) 17 | pos_enc[:, 1::2] = torch.cos(pos_enc[:, 1::2]) 18 | 19 | if zero_pad: 20 | pos_enc = torch.cat([torch.zeros(size=[1, n_feature]), pos_enc[1:, :]], 0) 21 | 22 | outputs = F.embedding(pos_indices, pos_enc) 23 | 24 | if scale: 25 | outputs = outputs * (n_feature ** 0.5) 26 | 27 | return outputs 28 | 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, n_feature=2048, n_hidden=512, dropout=0.5): 32 | super().__init__() 33 | self.linear1 = nn.Linear(n_feature, n_hidden) 34 | self.linear2 = nn.Linear(n_hidden, n_feature) 35 | 36 | self.dropout1 = nn.Dropout(dropout) 37 | self.dropout2 = nn.Dropout(dropout) 38 | self.norm_layer = nn.LayerNorm(n_feature) 39 | 40 | def forward(self, x): 41 | y = self.linear1(x) 42 | y = F.relu(y) 43 | y = self.dropout1(y) 44 | y = self.linear2(y) 45 | y = self.dropout2(y) 46 | y = self.norm_layer(y) 47 | return y 48 | 49 | 50 | class MeanStdNormalization(nn.Module): 51 | def __init__(self, dim=0, eps=1e-5): 52 | super(MeanStdNormalization, self).__init__() 53 | self.dim = dim 54 | self.eps = eps 55 | 56 | def forward(self, x): 57 | mean = x.mean(dim=self.dim, keepdim=True) 58 | std = x.std(dim=self.dim, keepdim=True) + self.eps 59 | x_normalized = (x - mean) / std 60 | return x_normalized 61 | 62 | 63 | class FLEXLayer(nn.Module): 64 | def __init__(self, input_dim, output_dim, num_layers=3, dropout_rate=0.3): 65 | super().__init__() 66 | 67 | self.layers = nn.ModuleList() 68 | self.dropout = nn.Dropout(dropout_rate) 69 | 70 | if input_dim != output_dim: 71 | self.projection = nn.Conv1d(input_dim, output_dim, kernel_size=1) 72 | else: 73 | self.projection = None 74 | 75 | for _ in range(num_layers): 76 | self.layers.append(nn.Conv1d(output_dim, output_dim, kernel_size=3, padding=1)) 77 | self.layers.append(nn.ReLU()) 78 | self.layers.append(nn.BatchNorm1d(output_dim)) 79 | 80 | def forward(self, x): 81 | 82 | if self.projection is not None: 83 | x = self.projection(x) 84 | 85 | x1 = x 86 | for layer in self.layers: 87 | x1 = layer(x1) 88 | 89 | x += x1 90 | x = self.dropout(x) 91 | 92 | return x 93 | 94 | 95 | class AFALayer(nn.Module): 96 | def __init__(self, d_model=2048, hidden_size=2048, n_conv_layers=3): 97 | super().__init__() 98 | 99 | self.lstm = nn.LSTM(input_size=d_model, hidden_size=hidden_size, 100 | batch_first=True, bidirectional=False) 101 | 102 | conv_layers = [] 103 | for _ in range(n_conv_layers): 104 | conv_layers.append(nn.Conv1d(in_channels=hidden_size, out_channels=hidden_size, 105 | kernel_size=3, padding=1)) 106 | conv_layers.append(nn.ReLU()) 107 | 108 | self.conv = nn.Sequential(*conv_layers) 
109 | 110 | self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=8, batch_first=True) 111 | 112 | def forward(self, x1, x2, x3): 113 | x1 = x1.permute(0, 2, 1) 114 | x2 = x2.permute(0, 2, 1) 115 | x3 = x3.permute(0, 2, 1) 116 | 117 | combined_features = torch.cat((x1, x2, x3), dim=-1) 118 | 119 | lstm_out, _ = self.lstm(combined_features) 120 | 121 | lstm_out = lstm_out.transpose(1, 2) 122 | conv_out = self.conv(lstm_out) 123 | conv_out = conv_out.transpose(1, 2) 124 | 125 | attn_output, _ = self.attention(conv_out, conv_out, conv_out) 126 | 127 | return attn_output 128 | 129 | 130 | class TimbreBlock(nn.Module): 131 | def __init__(self, 132 | n_fs=2049, 133 | n_fm=13, 134 | n_fc=12, 135 | n_out=768, 136 | nhead=8, 137 | num_layers=4, 138 | d_model=2048): 139 | super().__init__() 140 | 141 | self.auto_feature_x1 = FLEXLayer(input_dim=n_fs, output_dim=d_model) 142 | self.auto_feature_x2 = FLEXLayer(input_dim=n_fm, output_dim=d_model) 143 | self.auto_feature_x3 = FLEXLayer(input_dim=n_fc, output_dim=d_model) 144 | 145 | self.attention_fusion = AFALayer(d_model=d_model * 3) 146 | 147 | encoder_layers = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True) 148 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=num_layers) 149 | 150 | self.output_layer = nn.Linear(d_model, n_out) 151 | 152 | self.norm = MeanStdNormalization(dim=-1) 153 | 154 | def forward(self, x1, x2, x3): 155 | x1_transformed = self.auto_feature_x1(x1) 156 | x2_transformed = self.auto_feature_x2(self.norm(x2)) 157 | x3_transformed = self.auto_feature_x3(self.norm(x3)) 158 | 159 | fused_features = self.attention_fusion(x1_transformed, x2_transformed, 160 | x3_transformed) 161 | 162 | encoded_features = self.transformer_encoder(fused_features) 163 | output = self.output_layer(encoded_features) 164 | return output 165 | 166 | 167 | class EncoderBlock(nn.Module): 168 | def __init__(self, n_feat, nhead=8, n_layers=5, dropout=0.5, pr=0.02): 169 | super().__init__() 170 | self.n_feat = n_feat 171 | self.n_layers = n_layers 172 | self.pr = pr 173 | 174 | self.attn_layer1 = nn.ModuleList() 175 | self.attn_layer2 = nn.ModuleList() 176 | self.ff_layer1 = nn.ModuleList() 177 | self.ff_layer2 = nn.ModuleList() 178 | for _ in range(self.n_layers): 179 | _attn_layer1 = nn.MultiheadAttention(n_feat, nhead, batch_first=True) 180 | _attn_layer2 = nn.MultiheadAttention(n_feat, nhead, batch_first=True) 181 | _ff_layer1 = FeedForward(n_feat, dropout=dropout) 182 | _ff_layer2 = FeedForward(n_feat, dropout=dropout) 183 | self.attn_layer1.append(_attn_layer1) 184 | self.attn_layer2.append(_attn_layer2) 185 | self.ff_layer1.append(_ff_layer1) 186 | self.ff_layer2.append(_ff_layer2) 187 | 188 | self.norm_layer = nn.LayerNorm(n_feat) 189 | self.dropout = nn.Dropout(dropout) 190 | self.fc = nn.Linear(n_feat, n_feat) 191 | 192 | def forward(self, x, x1): 193 | x = x.permute(0, 2, 1) 194 | x1 = x1.permute(0, 2, 1) 195 | 196 | B, T, Ft = x.shape 197 | x += positional_encoding(B, T, Ft) * self.pr 198 | B1, T1, Ft1 = x1.shape 199 | x1 += positional_encoding(B1, T1, Ft1) * self.pr 200 | 201 | for _attn_layer1, _attn_layer2, _ff_layer1, _ff_layer2 in zip(self.attn_layer1, 202 | self.attn_layer2, 203 | self.ff_layer1, 204 | self.ff_layer2): 205 | x, _ = _attn_layer1(x, x, x, need_weights=False) 206 | x1, _ = _attn_layer2(x, x1, x1, need_weights=False) 207 | x = _ff_layer1(x) 208 | x1 = _ff_layer2(x1) 209 | 210 | x = torch.stack([x, x1], dim=1) 211 | x = self.dropout(x) 212 | x = self.fc(x) 213 | x = 
self.norm_layer(x) 214 | x1 = x[:, 1, ].permute(0, 2, 1) 215 | x = x[:, 0, ].permute(0, 2, 1) 216 | return x, x1 217 | 218 | 219 | class Encoder(nn.Module): 220 | def __init__(self, 221 | n_feat, 222 | n_stft, 223 | out_feat, 224 | n_block=5, 225 | nhead=8, 226 | n_layers=5, 227 | dropout=0.25, 228 | pr=0.02): 229 | super().__init__() 230 | 231 | self.n_block = n_block 232 | 233 | self.conv_pre1 = nn.Sequential( 234 | nn.ConvTranspose1d(n_stft, n_feat, kernel_size=1), 235 | nn.BatchNorm1d(n_feat), 236 | nn.LeakyReLU(0.01), 237 | ) 238 | self.conv_pre2 = nn.Sequential( 239 | nn.ConvTranspose1d(n_feat, n_feat, kernel_size=1), 240 | nn.BatchNorm1d(n_feat), 241 | nn.LeakyReLU(0.01), 242 | ) 243 | self.layers = nn.ModuleList() 244 | for i in range(n_block): 245 | self.layers.append(EncoderBlock(n_feat, nhead=nhead, n_layers=n_layers, dropout=dropout, pr=pr)) 246 | 247 | self.dropout = nn.Dropout(dropout) 248 | self.conv = nn.Sequential( 249 | nn.Conv1d(n_feat * 2, n_feat, kernel_size=1), 250 | nn.BatchNorm1d(n_feat), 251 | nn.ReLU(), 252 | nn.Conv1d(n_feat, n_feat, kernel_size=1), 253 | nn.BatchNorm1d(n_feat), 254 | nn.ReLU(), 255 | ) 256 | self.linear = nn.Linear(n_feat, out_feat) 257 | 258 | def forward(self, x, x1, x2): 259 | x = self.conv_pre1(x) 260 | x1 = x1 + x2 261 | x1 = x1.transpose(-1, -2) 262 | x1 = self.conv_pre2(x1) 263 | 264 | for layer in self.layers: 265 | x, x1 = layer(x, x1) 266 | 267 | z = torch.hstack([x, x1]) 268 | z = self.conv(z) 269 | z = z.transpose(-2, -1) 270 | z = self.dropout(z) 271 | logits = self.linear(z) 272 | return logits, z 273 | 274 | 275 | class DecoderBlock(nn.Module): 276 | 277 | def __init__(self, d_model, out_feat, nhead, n_layers, dropout=0.5): 278 | super().__init__() 279 | self.d_model = d_model 280 | self.nhead = nhead 281 | self.n_layers = n_layers 282 | 283 | self.attn1_layer = nn.ModuleList() 284 | self.attn2_layer = nn.ModuleList() 285 | self.ff_layer = nn.ModuleList() 286 | for _ in range(n_layers): 287 | _layer1 = nn.MultiheadAttention(d_model, nhead, batch_first=True) 288 | _layer2 = nn.MultiheadAttention(d_model, nhead, batch_first=True) 289 | _layer3 = FeedForward(d_model, dropout=dropout) 290 | self.attn1_layer.append(_layer1) 291 | self.attn2_layer.append(_layer2) 292 | self.ff_layer.append(_layer3) 293 | 294 | self.dropout = nn.Dropout(dropout) 295 | self.fc = nn.Sequential( 296 | nn.Linear(d_model, d_model), 297 | nn.Dropout(dropout), 298 | nn.LeakyReLU(0.01), 299 | nn.Linear(d_model, out_feat) 300 | ) 301 | 302 | def forward(self, x, x1): 303 | for i in range(self.n_layers): 304 | _attn1_layer = self.attn1_layer[i] 305 | _attn2_layer = self.attn2_layer[i] 306 | _ff_layer = self.ff_layer[i] 307 | x, _ = _attn1_layer(x, x, x, need_weights=False) 308 | x, _ = _attn2_layer(x, x1, x1, need_weights=False) 309 | x = _ff_layer(x) 310 | 311 | output = self.dropout(x) 312 | output = self.fc(output) 313 | return output 314 | 315 | 316 | class Decoder(nn.Module): 317 | 318 | def __init__(self, 319 | n_feat, 320 | n_timbre, 321 | n_pitch, 322 | nfft=1024, 323 | d_model=512, 324 | nhead=8, 325 | n_layers=5, 326 | dropout=0.5, 327 | pr=0.01): 328 | super().__init__() 329 | 330 | self.n_layers = n_layers 331 | self.d_model = d_model 332 | self.pr = pr 333 | 334 | self.emb_feat = nn.Linear(n_feat, d_model) 335 | self.emb_timbre = nn.Linear(n_timbre, d_model) 336 | self.emb_pitch = nn.Linear(n_pitch, d_model) 337 | 338 | self.lrelu = nn.LeakyReLU(0.01, inplace=True) 339 | 340 | n_out = nfft // 2 + 1 341 | self.dec = DecoderBlock(d_model, n_out, nhead, 
n_layers, dropout=dropout) 342 | 343 | self.exp_layer = nn.Sequential( 344 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True), 345 | nn.BatchNorm1d(d_model), 346 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True), 347 | nn.LeakyReLU(0.01), 348 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True), 349 | nn.BatchNorm1d(d_model), 350 | nn.LeakyReLU(0.01), 351 | nn.Conv1d(d_model, n_out, kernel_size=1, bias=True), 352 | nn.BatchNorm1d(n_out), 353 | nn.LeakyReLU(0.01), 354 | ) 355 | self.mask_layer = nn.Sequential( 356 | nn.Linear(n_out, n_out), 357 | nn.LeakyReLU(0.01), 358 | nn.Linear(n_out, n_out), 359 | nn.LeakyReLU(0.01), 360 | nn.Sigmoid() 361 | ) 362 | 363 | def forward(self, x1, x2, pitch): 364 | x = self.emb_timbre(x1) + self.emb_feat(x2) 365 | x = x * self.d_model ** (1 / 2) 366 | x = self.lrelu(x) 367 | y = x.transpose(-2, -1) 368 | y = self.exp_layer(y) 369 | 370 | x2 = self.emb_pitch(pitch) 371 | x += positional_encoding(x.shape[0], x.shape[1], x.shape[2]) * self.pr + self.pr 372 | x = self.dec(x, x2) 373 | 374 | y = y.transpose(-2, -1) 375 | z = y * x 376 | 377 | m = self.mask_layer(z) 378 | m = m.transpose(-2, -1) 379 | return m, z, x 380 | 381 | 382 | class CombineNet(nn.Module): 383 | 384 | def __init__(self, 385 | n_hidden, 386 | n_timbre, 387 | nfft=1024, 388 | hop_length=160, 389 | samplerate=16000, 390 | n_feat=768, 391 | n_pitch=360, 392 | n_mfcc=40, 393 | n_chroma=12, 394 | n_cqt=84, 395 | n_tb_head=8, 396 | n_tb_layers=5, 397 | tb_dropout=0.5, 398 | pr=0.02, 399 | noise_type=None, 400 | noise_kw=None, 401 | ): 402 | super().__init__() 403 | 404 | self.hop_length = hop_length 405 | self.nfft = nfft 406 | self.n_cqt = n_cqt 407 | self.n_chroma = n_chroma 408 | self.samplerate = samplerate 409 | self.noise_type = noise_type 410 | self.noise_kw = noise_kw 411 | 412 | self.mfcc_layer = transforms.MFCC(sample_rate=samplerate, n_mfcc=n_mfcc, 413 | melkwargs={"n_fft": nfft, "hop_length": hop_length}) 414 | 415 | self.timber = TimbreBlock(n_fs=(nfft // 2) + 1, n_fc=n_chroma, n_fm=n_mfcc, n_out=n_feat) 416 | 417 | self.encoder = Encoder( 418 | n_feat, 419 | (nfft // 2) + 1, 420 | n_timbre, 421 | nhead=n_tb_head, 422 | n_layers=n_tb_layers, 423 | dropout=tb_dropout, 424 | pr=pr) 425 | 426 | self.dec_p = Decoder( 427 | n_feat=n_feat, 428 | n_timbre=n_timbre, 429 | n_pitch=n_pitch, 430 | d_model=n_hidden, 431 | nhead=8, 432 | n_layers=5, 433 | dropout=0.5, 434 | ) 435 | 436 | def _stft(self, x): 437 | hl = self.hop_length 438 | nfft = self.nfft 439 | return stft(x, fft_size=nfft, hop_length=hl) 440 | 441 | def _chroma(self, x): 442 | hl = self.hop_length 443 | return chroma(x, n_chroma=self.n_chroma, sample_rate=self.samplerate, hop_length=hl) 444 | 445 | def _mfcc(self, x): 446 | return self.mfcc_layer(x) 447 | 448 | def _istft(self, x, n_time): 449 | hl = self.hop_length 450 | return istft(x, hop_length=hl, signal_length=n_time) 451 | 452 | def _magnitude(self, z): 453 | # return the magnitude of the spectrogram, except when cac is True, 454 | # in which case we just move the complex dimension to the channel one. 
455 | if getattr(self, 'cac', False): # note: 'cac' (complex-as-channels) is never set in __init__, so default to False 456 | B, C, Fr, T = z.shape 457 | m = torch.view_as_real(z).permute(0, 1, 4, 2, 3) 458 | m = m.reshape(B, C * 2, Fr, T) 459 | else: 460 | m = z.abs() 461 | return m 462 | 463 | def _noise(self, x): 464 | if self.noise_type is None: 465 | x = x 466 | else: 467 | noise_kw = self.noise_kw or {} 468 | if self.noise_type == 'gaussian': 469 | x = gaussian_noise(x, **noise_kw) 470 | elif self.noise_type == 'markov': 471 | x = markov_noise(x, **noise_kw) 472 | elif self.noise_type == 'random_walk': 473 | x = random_walk_noise(x, **noise_kw) 474 | elif self.noise_type == 'spectral_folding': 475 | x = spectral_folding(x, **noise_kw) 476 | return x 477 | 478 | def forward(self, audio, feat, f0, is_noise=False): 479 | n_len = audio.shape[-1] 480 | x = self._stft(audio) 481 | x_mfcc = self._mfcc(audio) 482 | x_chroma = self._chroma(audio) 483 | 484 | _mag = x.abs() 485 | if is_noise: 486 | _mag = self._noise(_mag) 487 | 488 | _, f0_t, f0_f = f0.shape 489 | _, ft_t, ft_f = feat.shape 490 | if f0_t < ft_t: 491 | feat = feat[:, :f0_t, :] 492 | elif f0_t > ft_t: 493 | feat = F.pad(feat, (0, 0, 0, f0_t - ft_t)) 494 | 495 | timber_feat = self.timber(_mag, x_mfcc, x_chroma) 496 | logits, h = self.encoder(_mag, feat, timber_feat) 497 | m, z, y = self.dec_p(logits, feat, f0) 498 | 499 | z = x * m 500 | o = self._istft(z, n_len) 501 | o = o[..., :n_len] 502 | return o, z, (y, m, logits, h) 503 | --------------------------------------------------------------------------------
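The timbre branch used by CombineNet can be exercised on its own. The snippet below is a minimal, hedged sketch and an editorial addition, not part of the repository: it assumes the `cover/` directory is on `PYTHONPATH` so that `models.generate` and its `utils.spec` / `utils.noise` imports resolve, and the feature sizes are illustrative values chosen only to show the tensor shapes expected around TimbreBlock.

# Hypothetical smoke test for TimbreBlock (assumed import path; shapes are illustrative).
import torch
from models.generate import TimbreBlock  # assumes `cover/` is on PYTHONPATH

net = TimbreBlock(n_fs=513, n_fm=13, n_fc=12, n_out=768)  # n_fs = nfft // 2 + 1 for nfft = 1024
net.eval()  # note: the default d_model of 2048 makes this a fairly large module

B, T = 1, 16                             # batch size and number of frames
spec_mag = torch.randn(B, 513, T).abs()  # magnitude spectrogram, shape (B, n_fs, T)
mfcc = torch.randn(B, 13, T)             # MFCC features, shape (B, n_fm, T)
chrom = torch.randn(B, 12, T)            # chroma features, shape (B, n_fc, T)

with torch.no_grad():
    timbre = net(spec_mag, mfcc, chrom)
print(timbre.shape)  # expected: (B, T, n_out) == (1, 16, 768)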