├── cover
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── audio.py
│   │   ├── noise.py
│   │   ├── utils.py
│   │   └── spec.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── unet.py
│   │   ├── pitch.py
│   │   └── generate.py
│   ├── load.py
│   └── run.py
├── image
│   ├── logo.png
│   ├── net1.png
│   ├── net2.png
│   ├── tnn.png
│   ├── model.png
│   └── website.png
├── python
│   ├── models
│   │   ├── __init__.py
│   │   ├── load.py
│   │   ├── chord_net.py
│   │   ├── beat_net.py
│   │   ├── segment_net.py
│   │   ├── unet.py
│   │   ├── pitch_net.py
│   │   └── transformers.py
│   ├── audio.py
│   ├── utils.py
│   ├── spec.py
│   ├── common.py
│   ├── aitabs.py
│   └── modulation.py
└── README.md
/cover/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cover/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/image/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/logo.png
--------------------------------------------------------------------------------
/image/net1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/net1.png
--------------------------------------------------------------------------------
/image/net2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/net2.png
--------------------------------------------------------------------------------
/image/tnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/tnn.png
--------------------------------------------------------------------------------
/image/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/model.png
--------------------------------------------------------------------------------
/image/website.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JoinMusic/fish/HEAD/image/website.png
--------------------------------------------------------------------------------
/cover/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .pitch import PitchNet
2 | from .generate import CombineNet
3 | from .unet import UNets
4 |
--------------------------------------------------------------------------------
/python/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .beat_net import BeatNet
2 | from .unet import UNets
3 | from .chord_net import ChordNet
4 | from .pitch_net import PitchNet
5 | from .segment_net import SegmentNet
6 | from .load import get_model
7 |
--------------------------------------------------------------------------------
/python/audio.py:
--------------------------------------------------------------------------------
1 | import torchaudio
2 | import torch
3 |
4 |
5 | def read_wav(wav_fp, sample_rate=44100, n_channel=2, device='cpu'):
6 | waveform, sr = torchaudio.load(wav_fp)
7 | assert waveform.ndim == 2
8 | assert waveform.shape[0] == n_channel
9 | if sr != sample_rate:
10 | waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
11 | waveform = waveform.to(device)
12 | return waveform, sample_rate
13 |
14 |
15 | def gen_wav(sample_rate=44100, n_channel=2, duration=120, device='cpu'):
16 | waveform = torch.randn(n_channel, sample_rate * duration)
17 | waveform = waveform.to(device)
18 | return waveform, sample_rate
19 |
20 |
21 | def write_wav(path, wav, sample_rate=44100):
22 | torchaudio.save(path, wav, sample_rate)
23 |
--------------------------------------------------------------------------------
/cover/load.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchaudio.pipelines import (HUBERT_BASE, HUBERT_LARGE)
3 |
4 |
5 | def get_model(model, config, model_path=None, is_train=True, device='cpu'):
6 | net = model(**config)
7 | if model_path:
8 | net.load_state_dict(torch.load(model_path, map_location=device))
9 | net.to(device)
10 |
11 | if is_train:
12 | net.train()
13 | else:
14 | net.eval()
15 |
16 | return net
17 |
18 |
19 | def get_hubert_model(name='base', dl_kwargs=None, device='cpu'):
20 | if name == 'base':
21 | bundle = HUBERT_BASE
22 | elif name == 'large':
23 | bundle = HUBERT_LARGE
24 | else:
25 | raise ValueError(f'Invalid model name: {name}')
26 |
27 | model = bundle.get_model(dl_kwargs=dl_kwargs)
28 | model.to(device)
29 | model.eval()
30 | return model
31 |
--------------------------------------------------------------------------------
/python/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from common import CHORD_LABELS, SEGMENT_LABELS
3 |
4 |
5 | def build_masked_stft(masks, stft_feature, n_fft=4096):
6 | out = []
7 | for i in range(len(masks)):
8 | mask = masks[i, :, :, :]
9 | pad_num = n_fft // 2 + 1 - mask.size(-1)
10 | mask = F.pad(mask, (0, pad_num, 0, 0, 0, 0))
11 | inst_stft = mask.type(stft_feature.dtype) * stft_feature
12 | out.append(inst_stft)
13 | return out
14 |
15 |
16 | def get_chord_name(chord_idx_list):
17 | chords = [CHORD_LABELS[idx] for idx in chord_idx_list]
18 | return chords
19 |
20 |
21 | def get_segment_name(segments):
22 | segments = [SEGMENT_LABELS[idx] for idx in segments]
23 | return segments
24 |
25 |
26 | def get_lyrics(waveform, sr, cfg):
27 | # asr and wav2vec2
28 | raise NotImplementedError()
29 |
--------------------------------------------------------------------------------
/python/models/load.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from models.unet import UNets
4 | from models.beat_net import BeatNet
5 | from models.chord_net import ChordNet
6 | from models.pitch_net import PitchNet
7 | from models.segment_net import SegmentNet
8 |
9 |
10 | def get_model_cls(s):
11 | if s == 'unet':
12 | return UNets
13 | elif s == 'beat':
14 | return BeatNet
15 | elif s == 'chord':
16 | return ChordNet
17 | elif s == 'pitch':
18 | return PitchNet
19 | elif s == 'segment':
20 | return SegmentNet
21 | else:
22 | raise ValueError(f'Invalid model name: {s}')
23 |
24 |
25 | def get_model(model, config, model_path=None, is_train=True, device='cpu'):
26 | if isinstance(model, str):
27 | model = get_model_cls(model)
28 |
29 | net = model(**config)
30 | if model_path:
31 | net.load_state_dict(torch.load(model_path, map_location=device))
32 | net.to(device)
33 |
34 | if is_train:
35 | net.train()
36 | else:
37 | net.eval()
38 |
39 | return net
40 |
--------------------------------------------------------------------------------
/cover/utils/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 |
4 |
5 | def load_waveform(wav_fp, samplerate=16000, n_channel=2, device='cpu'):
6 | waveform, sr = torchaudio.load(wav_fp)
7 | assert waveform.ndim == 2
8 | if waveform.shape[0] != n_channel:
9 | if n_channel == 1:
10 | waveform = torch.mean(waveform, dim=0, keepdim=True)
11 | elif n_channel <= waveform.shape[0]:
12 | waveform = waveform[:n_channel]
13 | else:
14 | raise ValueError(f'Invalid number of channels: {waveform.shape[0]}')
15 | if sr != samplerate:
16 | waveform = torchaudio.transforms.Resample(sr, samplerate)(waveform)
17 | waveform = waveform.to(device)
18 | return waveform, samplerate
19 |
20 |
21 | def gen_waveform(samplerate=16000, n_channel=2, duration=120, device='cpu'):
22 | waveform = torch.randn(n_channel, samplerate * duration)
23 | waveform = waveform.to(device)
24 | return waveform, samplerate
25 |
26 |
27 | def save_waveform(path, wav, samplerate=16000):
28 | torchaudio.save(path, wav, samplerate)
29 |
--------------------------------------------------------------------------------
/cover/utils/noise.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def gaussian_noise(x, mean=0, std=0.1):
5 | noise = torch.normal(mean, std, x.size()).to(x.device)
6 | noisy_x = x + noise
7 | return noisy_x
8 |
9 |
10 | def markov_noise(x, transition_matrix):
11 | num_bins, num_frames = x.shape
12 | noisy_x = x.clone()
13 | state = torch.randint(0, num_bins, (1,)).item()
14 |
15 | for t in range(num_frames):
16 | noisy_x[:, t] = x[:, t] * (1 + transition_matrix[state])
17 | state = torch.multinomial(transition_matrix[state], 1).item()
18 |
19 | return noisy_x
20 |
21 |
22 | def random_walk_noise(x, step_size=0.1):
23 | noisy_x = x.clone()
24 | random_walk = torch.FloatTensor(x.size()).uniform_(-step_size, step_size).to(x.device)
25 | noisy_x += random_walk
26 | return noisy_x
27 |
28 |
29 | def spectral_folding(x, fold_frequency=0.5):
30 | num_bins, num_frames = x.shape
31 | fold_bin = int(fold_frequency * num_bins)
32 |
33 | folded_x = x.clone()
34 | folded_x[:fold_bin, :] += x[-fold_bin:, :]
35 | folded_x[-fold_bin:, :] = 0
36 |
37 | return folded_x
38 |
--------------------------------------------------------------------------------
/cover/utils/utils.py:
--------------------------------------------------------------------------------
1 | import uuid
2 |
3 |
4 | def generate_uuid(version=4, namespace=None, name=None):
5 | """
6 | Generate a UUID string based on the specified version.
7 |
8 | Parameters:
9 | - version (int): The UUID version (1, 3, 4, or 5). Default is 4.
10 | - namespace (str): The namespace for UUID3 and UUID5. Must be a valid UUID string.
11 | - name (str): The name for UUID3 and UUID5.
12 |
13 | Returns:
14 | - str: The generated UUID string.
15 | """
16 | if version == 1:
17 | return str(uuid.uuid1())
18 | elif version == 3:
19 | if namespace is None or name is None:
20 | raise ValueError("Namespace and name must be provided for UUID3.")
21 | namespace_uuid = uuid.UUID(namespace)
22 | return str(uuid.uuid3(namespace_uuid, name))
23 | elif version == 4:
24 | return str(uuid.uuid4())
25 | elif version == 5:
26 | if namespace is None or name is None:
27 | raise ValueError("Namespace and name must be provided for UUID5.")
28 | namespace_uuid = uuid.UUID(namespace)
29 | return str(uuid.uuid5(namespace_uuid, name))
30 | else:
31 | raise ValueError("Unsupported UUID version. Supported versions are 1, 3, 4, and 5.")
32 |
--------------------------------------------------------------------------------
/python/models/chord_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.transformers import BaseTransformer
4 |
5 |
6 | class ChordNet(nn.Module):
7 | def __init__(self,
8 | n_freq=2048,
9 | n_classes=122,
10 | n_group=32,
11 | f_layers=5,
12 | f_nhead=8,
13 | t_layers=5,
14 | t_nhead=8,
15 | d_layers=5,
16 | d_nhead=8,
17 | dropout=0.5,
18 | *args, **kwargs):
19 | super().__init__()
20 |
21 | self.transformers = BaseTransformer(n_freq=n_freq, n_group=n_group,
22 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout,
23 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout,
24 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout)
25 | self.dropout = nn.Dropout(dropout)
26 | self.fc = nn.Linear(n_freq, n_classes)
27 |
28 | def forward(self, x, weight=None):
29 | # x shape: (batch, channel, time, freq)
30 | output, weight_logits = self.transformers(x, weight)
31 | output = self.dropout(output)
32 | output = self.fc(output)
33 | output = output.argmax(dim=-1)
34 | return output, weight_logits
35 |
36 |
37 | if __name__ == '__main__':
38 | model = ChordNet()
39 | print(model)
40 | x = torch.randn(6, 2, 256, 2048)
41 | y, weight = model(x)
42 | print(y.shape, weight.shape)
43 |
--------------------------------------------------------------------------------
/python/spec.py:
--------------------------------------------------------------------------------
1 | import torch as th
2 |
3 |
4 | def stft(x, n_fft=4096, hop_length=1024, pad=True):
5 | z = th.stft(x,
6 | n_fft,
7 | hop_length or n_fft // 4,
8 | window=th.hann_window(n_fft).to(device=x.device),
9 | win_length=n_fft,
10 | normalized=True,
11 | center=pad,
12 | return_complex=True,
13 | pad_mode='reflect')
14 | z = th.transpose(z, 1, 2)
15 | return z
16 |
17 |
18 | def istft(stft_feature, n_fft=4096, hop_length=1024, pad=True):
19 | stft_feature = th.transpose(stft_feature, 1, 2)
20 | waveform = th.istft(stft_feature,
21 | n_fft,
22 | hop_length,
23 | window=th.hann_window(n_fft).to(device=stft_feature.device),
24 | win_length=n_fft,
25 | normalized=True,
26 | center=pad)
27 | return waveform
28 |
29 |
30 | def get_spec(waveform, cfg):
31 | spec = stft(waveform, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad'])
32 | return spec # channel, freq, time
33 |
34 |
35 | def get_specs(waveforms, cfg):
36 | # waveforms shape: sources, channel, time
37 | S, C, T = waveforms.shape
38 | _waveforms = waveforms.view(S * C, T)
39 | specs = stft(_waveforms, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad'])
40 |     return specs.view(S, C, specs.shape[-2], specs.shape[-1])  # sources, channel, time, freq
41 |
42 |
43 | def get_mixed_spec(waveforms, cfg):
44 | mixed_waveform = th.sum(waveforms, dim=0)
45 | mixed_spec = stft(mixed_waveform, n_fft=cfg['n_fft'], hop_length=cfg['hop_length'], pad=cfg['pad'])
46 |     return mixed_spec  # channel, time, freq
47 |
--------------------------------------------------------------------------------
/python/models/beat_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.transformers import BaseTransformer
4 |
5 |
6 | class BeatNet(nn.Module):
7 | def __init__(self,
8 | source=3,
9 | n_classes=3,
10 | weights=(0.4, 0.3, 0.3),
11 | n_freq=2048,
12 | n_group=32,
13 | f_layers=2,
14 | f_nhead=4,
15 | t_layers=2,
16 | t_nhead=4,
17 | d_layers=2,
18 | d_nhead=8,
19 | dropout=0.5,
20 | *args, **kwargs
21 | ):
22 | super().__init__()
23 | self.weights = weights
24 | self.transformer_layers = nn.ModuleList()
25 | for _ in range(source):
26 | _layer = BaseTransformer(n_freq=n_freq, n_group=n_group,
27 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout,
28 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout,
29 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout)
30 | self.transformer_layers.append(_layer)
31 | self.dropout = nn.Dropout(dropout)
32 | self.beat_fc = nn.Linear(n_freq, n_classes)
33 |
34 | self.reset_parameters(0.05)
35 |
36 | def reset_parameters(self, confidence):
37 | self.beat_fc.bias.data.fill_(-torch.log(torch.tensor(1 / confidence - 1)))
38 |
39 | def forward(self, inp):
40 | # shape: (batch, source, channel, time, freq)
41 |
42 | y_list = []
43 | logits_weight_list = []
44 | for i, layer in enumerate(self.transformer_layers):
45 | x = inp[:, i, :, :, :]
46 | x, _f = layer(x)
47 | w = self.weights[i]
48 | x = x * w
49 | y_list.append(x)
50 | logits_weight_list.append(_f * w)
51 | y = torch.sum(torch.stack(y_list, dim=0), dim=0)
52 | logits_weight = torch.sum(torch.stack(logits_weight_list, dim=0), dim=0)
53 |
54 | y = self.dropout(y)
55 | beats = self.beat_fc(y)
56 | beats = torch.argmax(beats, dim=-1)
57 | return beats, logits_weight
58 |
59 |
60 | if __name__ == '__main__':
61 | model = BeatNet()
62 | print(model)
63 | x = torch.randn(6, 3, 2, 256, 2048)
64 | b, weight = model(x)
65 | print(x.shape, b.shape, weight.shape)
66 |
--------------------------------------------------------------------------------
/python/common.py:
--------------------------------------------------------------------------------
1 | CHORD_LABELS = [
2 | 'N', 'X',
3 | # maj
4 | 'C:maj', 'C#:maj', 'D:maj', 'D#:maj', 'E:maj', 'F:maj', 'F#:maj', 'G:maj', 'G#:maj', 'A:maj', 'A#:maj', 'B:maj',
5 | # min
6 | 'C:min', 'C#:min', 'D:min', 'D#:min', 'E:min', 'F:min', 'F#:min', 'G:min', 'G#:min', 'A:min', 'A#:min', 'B:min',
7 | # 7
8 | 'C:7', 'C#:7', 'D:7', 'D#:7', 'E:7', 'F:7', 'F#:7', 'G:7', 'G#:7', 'A:7', 'A#:7', 'B:7',
9 | # maj7
10 | 'C:maj7', 'C#:maj7', 'D:maj7', 'D#:maj7', 'E:maj7', 'F:maj7', 'F#:maj7', 'G:maj7', 'G#:maj7', 'A:maj7', 'A#:maj7',
11 | 'B:maj7',
12 | # min7
13 | 'C:min7', 'C#:min7', 'D:min7', 'D#:min7', 'E:min7', 'F:min7', 'F#:min7', 'G:min7', 'G#:min7', 'A:min7', 'A#:min7',
14 | 'B:min7',
15 | # 6
16 | 'C:6', 'C#:6', 'D:6', 'D#:6', 'E:6', 'F:6', 'F#:6', 'G:6', 'G#:6', 'A:6', 'A#:6', 'B:6',
17 | # m6
18 | 'C:m6', 'C#:m6', 'D:m6', 'D#:m6', 'E:m6', 'F:m6', 'F#:m6', 'G:m6', 'G#:m6', 'A:m6', 'A#:m6', 'B:m6',
19 | # sus2
20 | 'C:sus2', 'C#:sus2', 'D:sus2', 'D#:sus2', 'E:sus2', 'F:sus2', 'F#:sus2', 'G:sus2', 'G#:sus2', 'A:sus2', 'A#:sus2',
21 | 'B:sus2',
22 | # sus4
23 | 'C:sus4', 'C#:sus4', 'D:sus4', 'D#:sus4', 'E:sus4', 'F:sus4', 'F#:sus4', 'G:sus4', 'G#:sus4', 'A:sus4', 'A#:sus4',
24 | 'B:sus4',
25 | # 5
26 | 'C:5', 'C#:5', 'D:5', 'D#:5', 'E:5', 'F:5', 'F#:5', 'G:5', 'G#:5', 'A:5', 'A#:5', 'B:5',
27 | ]
28 |
29 | SEGMENT_LABELS = [
30 | 'start',
31 | 'end',
32 | 'intro',
33 | 'outro',
34 | 'verse',
35 | 'chorus',
36 | 'solo',
37 | 'break',
38 | 'bridge',
39 | 'inst',
40 | ]
41 |
42 | MAJOR_CHORDS = [
43 | ["C", "Dm", "Em", "F", "G", "Am", "G7"],
44 | ["G", "Am", "Bm", "C", "D", "Em", "D7"],
45 | ["D", "Em", "F#m", "G", "A", "Bm", "A7"],
46 | ["A", "Bm", "C#m", "D", "E", "F#m", "E7"],
47 | ["E", "F#m", "G#m", "A", "B", "C#m", "B7"],
48 | ["F", "Gm", "Am", "Bb", "C", "Dm", "C7"],
49 | ["B", "C#m", "D#m", "E", "F#", "G#m", "F#7"],
50 | ["Db", "Ebm", "Fm", "Gb", "Ab", "Bbm", "Ab7"],
51 | ["Eb", "Fm", "Gm", "Ab", "Bb", "Cm", "Bb7"],
52 | ["Gb", "Abm", "Bbm", "Cb", "Db", "Ebm", "Db7"],
53 | ["Ab", "Bbm", "Cm", "Db", "Eb", "Fm", "Eb7"],
54 | ["Bb", "Cm", "Dm", "Eb", "F", "Gm", "F7"],
55 | ["C#", "D#m", "E#m", "F#", "G#", "A#m", "G#7"],
56 | ["Cb", "Dbm", "Ebm", "Fb", "Gb", "Abm", "Gb7"],
57 |     ["F#", "G#m", "A#m", "B", "C#", "D#m", "C#7"],
58 | ]
59 |
60 | MINOR_CHORDS = [
61 | ["Am", "Bdim", "C", "Dm", "Em", "F", "G"],
62 | ["Em", "F#dim", "G", "Am", "Bm", "C", "D"],
63 | ["Bm", "C#dim", "D", "Em", "F#m", "G", "A"],
64 | ["F#m", "G#dim", "A", "Bm", "C#m", "D", "E"],
65 | ["C#m", "D#dim", "E", "F#m", "G#m", "A", "B"],
66 | ["Dm", "Edim", "F", "Gm", "Am", "Bb", "C"],
67 | ["G#m", "A#dim", "B", "C#m", "D#m", "E", "F#"],
68 | ["Bbm", "Cdim", "Db", "Ebm", "Fm", "Gb", "Ab"],
69 | ["Cm", "Ddim", "Eb", "Fm", "Gm", "Ab", "Bb"],
70 | ["Ebm", "Fdim", "Gb", "Abm", "Bbm", "Cb", "Db"],
71 | ["Fm", "Gdim", "Ab", "Bbm", "Cm", "Db", "Eb"],
72 | ["Gm", "Adim", "Bb", "Cm", "Dm", "Eb", "F"],
73 | ["A#m", "B#dim", "C#", "D#m", "E#m", "F#", "G#"],
74 | ["Abm", "Bbdim", "Cb", "Dbm", "Ebm", "Fb", "Gb"],
75 | ["D#m", "E#dim", "F#", "G#m", "A#m", "B", "C#"],
76 | ]
77 |
78 | MAJ2MIN_MAP = {
79 | "G": "Em",
80 | "C": "Am",
81 | "D": "Bm",
82 | "E": "C#m",
83 | "F": "Dm",
84 | "A": "F#m",
85 | "B": "G#m",
86 | "Db": "Bbm",
87 | "Eb": "Cm",
88 | "Gb": "Ebm",
89 | "Ab": "Fm",
90 | "Bb": "Gm",
91 | "C#": "A#m",
92 | "Cb": "Abm",
93 | "F#": "D#m"
94 | }
95 | MIN2MAJ_MAP = {v: k for k, v in MAJ2MIN_MAP.items()}
96 |
--------------------------------------------------------------------------------
/python/models/segment_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.transformers import BaseTransformer
4 |
5 |
6 | class SegmentEmbeddings(nn.Module):
7 | def __init__(self, n_channel=2, n_hidden=128, d_model=2048, dropout=0.5):
8 | super().__init__()
9 | self.n_channel = n_channel
10 | self.n_hidden = n_hidden
11 |
12 | self.act_fn = nn.ReLU()
13 |
14 | self.conv1 = nn.Conv2d(n_channel, n_hidden // 2, kernel_size=(1, 1))
15 | self.pool0 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 0))
16 | self.drop1 = nn.Dropout(dropout)
17 |
18 | self.conv2 = nn.Conv2d(n_hidden // 2, n_hidden, kernel_size=(1, 1))
19 | self.drop2 = nn.Dropout(dropout)
20 |
21 | self.conv3 = nn.Conv2d(n_hidden, n_hidden // 2, kernel_size=(1, 1))
22 | self.pool3 = nn.AvgPool2d(kernel_size=(1, 2), stride=(1, 2), padding=(0, 0))
23 | self.drop3 = nn.Dropout(dropout)
24 |
25 | self.conv4 = nn.Conv2d(n_hidden // 2, n_channel, kernel_size=(1, 1))
26 | self.drop4 = nn.Dropout(dropout)
27 |
28 | self.norm = nn.LayerNorm(d_model // 4)
29 | self.drop = nn.Dropout(dropout)
30 |
31 | def forward(self, x):
32 | # x: batch, n_channel, n_time, n_freq
33 | x = self.conv1(x)
34 | x = self.pool0(x)
35 | x = self.act_fn(x)
36 | x = self.drop1(x)
37 |
38 | x = self.conv2(x)
39 | x = self.act_fn(x)
40 | x = self.drop2(x)
41 |
42 | x = self.conv3(x)
43 | x = self.pool3(x)
44 | x = self.act_fn(x)
45 | x = self.drop3(x)
46 |
47 | x = self.conv4(x)
48 | x = self.act_fn(x)
49 | x = self.drop4(x)
50 |
51 | x = self.norm(x)
52 | x = self.drop(x)
53 |
54 | return x
55 |
56 |
57 | class SegmentNet(nn.Module):
58 | def __init__(self,
59 | n_freq=2048,
60 | n_channel=2,
61 | n_classes=10,
62 | emb_hidden=128,
63 | n_group=32,
64 | f_layers=2,
65 | f_nhead=8,
66 | t_layers=2,
67 | t_nhead=8,
68 | d_layers=2,
69 | d_nhead=8,
70 | dropout=0.5,
71 | *args, **kwargs):
72 | super().__init__()
73 |
74 | divisor = 4
75 | d_model = n_freq // divisor
76 | self.embeddings = SegmentEmbeddings(n_channel=n_channel, n_hidden=emb_hidden, d_model=n_freq, dropout=dropout)
77 | self.wfc = nn.Linear(n_freq, n_freq // divisor)
78 |
79 | self.transformer = BaseTransformer(n_channel=n_channel, n_freq=d_model, n_group=n_group,
80 | f_layers=f_layers, f_nhead=f_nhead, f_dropout=dropout,
81 | t_layers=t_layers, t_nhead=t_nhead, t_dropout=dropout,
82 | d_layers=d_layers, d_nhead=d_nhead, d_dropout=dropout,
83 | )
84 |
85 | self.norm = nn.LayerNorm(d_model)
86 |
87 | self.segment_classifier = nn.Linear(d_model, n_classes)
88 |
89 | def forward(self, x, weight=None):
90 | # x: batch, n_channel, n_time, n_freq
91 | if weight is None:
92 | B, C, T, F = x.shape
93 | weight = torch.ones(B, T, F, device=x.device)
94 | x = self.embeddings(x)
95 | weight = self.wfc(weight)
96 |
97 | x, _ = self.transformer(x, weight=weight)
98 | x = self.norm(x)
99 | x = self.segment_classifier(x)
100 | x = torch.argmax(x, dim=-1)
101 | return x
102 |
103 |
104 | if __name__ == '__main__':
105 | net = SegmentNet(n_freq=2048)
106 | print(net)
107 | x = torch.randn(2, 2, 1024, 2048)
108 | y = net(x)
109 | print(y.shape)
110 |
--------------------------------------------------------------------------------
/cover/utils/spec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from librosa import feature
3 |
4 |
5 | def stft(signal, fft_size=512, hop_length=None, pad=True):
6 | """
7 | Perform Short-Time Fourier Transform (STFT) on the input signal.
8 |
9 | Parameters:
10 | - signal (torch.Tensor): The input signal tensor.
11 | - fft_size (int): The size of the FFT window. Default is 512.
12 | - hop_length (int): The number of samples between successive frames. Default is fft_size // 4.
13 | - pad (bool, optional): whether to pad :attr:`input` on both sides
14 | so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
15 | Default: ``True``
16 |
17 | Returns:
18 | - torch.Tensor: The STFT of the input signal.
19 | """
20 | *ot_shape, n_time = signal.shape
21 | signal = signal.reshape(-1, n_time)
22 | result = torch.stft(signal,
23 | fft_size,
24 | hop_length or fft_size // 4,
25 | window=torch.hann_window(fft_size).to(signal),
26 | win_length=fft_size,
27 | normalized=True,
28 | center=pad,
29 | return_complex=True,
30 | pad_mode='reflect')
31 | _, freqs, frames = result.shape
32 | return result.view(*ot_shape, freqs, frames)
33 |
34 |
35 | def istft(stft_matrix, hop_length=None, signal_length=None, pad=True):
36 | """
37 | Perform Inverse Short-Time Fourier Transform (ISTFT) on the input STFT matrix.
38 |
39 | Parameters:
40 | - stft_matrix (torch.Tensor): The input STFT matrix tensor.
41 | - hop_length (int): The number of samples between successive frames. Default is None.
42 | - signal_length (int): The length of the original signal. Default is None.
43 | - pad (bool, optional): whether to pad :attr:`input` on both sides
44 | so that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
45 | Default: ``True``
46 |
47 | Returns:
48 | - torch.Tensor: The reconstructed time-domain signal.
49 | """
50 | *ot_shape, n_freqs, n_frames = stft_matrix.shape
51 | fft_size = 2 * n_freqs - 2
52 | stft_matrix = stft_matrix.view(-1, n_freqs, n_frames)
53 | win_length = fft_size
54 |
55 | result = torch.istft(stft_matrix,
56 | fft_size,
57 | hop_length,
58 | window=torch.hann_window(win_length).to(stft_matrix.real),
59 | win_length=win_length,
60 | normalized=True,
61 | length=signal_length,
62 | center=pad)
63 | _, length = result.shape
64 | return result.view(*ot_shape, length)
65 |
66 |
67 | def get_spectrogram(waveform, config):
68 | """
69 | Get the spectrogram of the input waveform based on the provided configuration.
70 |
71 | Parameters:
72 | - waveform (torch.Tensor): The input waveform tensor.
73 | - config (dict): The configuration dictionary containing 'n_fft', 'hop_length', and 'pad' keys.
74 |
75 | Returns:
76 | - torch.Tensor: The spectrogram of the input waveform.
77 | """
78 | spectrogram = stft(waveform, fft_size=config.get('n_fft', 4096),
79 | hop_length=config.get('hop_length', 1024), pad=config.get('pad', True))
80 | spectrogram = spectrogram.transpose(-1, -2)
81 |     return spectrogram  # channel, time, freq
82 |
83 |
84 | def chroma(waveform, n_chroma=12, sample_rate=44100, hop_length=512, bins_per_octave=24):
85 | dtype = waveform.dtype
86 | if isinstance(waveform, torch.Tensor):
87 | waveform = waveform.cpu().numpy()
88 | y = feature.chroma_cqt(y=waveform, sr=sample_rate, n_chroma=n_chroma, hop_length=hop_length, n_octaves=12,
89 | bins_per_octave=bins_per_octave, cqt_mode='hybrid')
90 | y = torch.from_numpy(y).to(dtype)
91 | return y
92 |
--------------------------------------------------------------------------------
/cover/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import nn
5 |
6 |
7 | class UPad(nn.Module):
8 | def __init__(self, padding_setting=(1, 2, 1, 2)):
9 | super().__init__()
10 | self.padding_setting = padding_setting
11 |
12 | def forward(self, x):
13 | return F.pad(x, self.padding_setting, "constant", 0)
14 |
15 |
16 | class UTransposedPad(nn.Module):
17 | def __init__(self, padding_setting=(1, 2, 1, 2)):
18 | super().__init__()
19 | self.padding_setting = padding_setting
20 |
21 | def forward(self, x):
22 | l, r, t, b = self.padding_setting
23 | return x[:, :, l:-r, t:-b]
24 |
25 |
26 | def get_activation(activation):
27 | if activation == "ReLU":
28 | activation_fn = nn.ReLU()
29 | elif activation == "ELU":
30 | activation_fn = nn.ELU()
31 | else:
32 | activation_fn = nn.LeakyReLU(0.2)
33 | return activation_fn
34 |
35 |
36 | class UNet(nn.Module):
37 | def __init__(self,
38 | n_channel=2,
39 | conv_n_filters=(16, 32, 64, 128, 256, 512),
40 | down_activation="ELU",
41 | up_activation="ELU",
42 | down_dropouts=None,
43 | up_dropouts=None):
44 | super().__init__()
45 |
46 | conv_num = len(conv_n_filters)
47 |
48 | down_activation_fn = get_activation(down_activation)
49 | up_activation_fn = get_activation(up_activation)
50 |
51 | down_dropouts = [0] * conv_num if down_dropouts is None else down_dropouts
52 | up_dropouts = [0] * conv_num if up_dropouts is None else up_dropouts
53 |
54 | self.down_layers = nn.ModuleList()
55 | for i in range(conv_num):
56 | in_ch = n_channel if i == 0 else conv_n_filters[i - 1]
57 | out_ch = conv_n_filters[i]
58 | dropout = down_dropouts[i]
59 |
60 | _down_layers = [
61 | UPad(),
62 | nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0),
63 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01),
64 | down_activation_fn
65 | ]
66 | if dropout > 0:
67 | _down_layers.append(nn.Dropout(dropout))
68 | self.down_layers.append(nn.Sequential(*_down_layers))
69 |
70 | self.up_layers = nn.ModuleList()
71 | for i in range(conv_num - 1, -1, -1):
72 | in_ch = conv_n_filters[conv_num - 1] if i == conv_num - 1 else conv_n_filters[i + 1]
73 | out_ch = 1 if i == 0 else conv_n_filters[i - 1]
74 | dropout = up_dropouts[i]
75 |
76 | _up_layer = [
77 | nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0),
78 | UTransposedPad(),
79 | up_activation_fn,
80 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01)
81 | ]
82 | if dropout > 0:
83 | _up_layer.append(nn.Dropout(dropout))
84 | self.up_layers.append(nn.Sequential(*_up_layer))
85 |
86 | self.last_layer = nn.Conv2d(1, 2, kernel_size=1, stride=1, padding=0)
87 |
88 | def forward(self, x):
89 |
90 | d_convs = []
91 | for layer in self.down_layers:
92 | x = layer(x)
93 | d_convs.append(x)
94 |
95 | n = len(self.up_layers)
96 | for i, layer in enumerate(self.up_layers):
97 | if i == 0:
98 | x = layer(x)
99 | else:
100 | x1 = d_convs[n - i - 1]
101 | x = torch.cat([x, x1], axis=1)
102 | x = layer(x)
103 | x = self.last_layer(x)
104 | return x
105 |
106 |
107 | class UNets(nn.Module):
108 | def __init__(self,
109 | sources,
110 | n_channel=2,
111 | conv_n_filters=(16, 32, 64, 128, 256, 512),
112 | down_activation="ELU",
113 | up_activation="ELU",
114 | down_dropouts=None,
115 | up_dropouts=None,
116 | *args, **kwargs):
117 | super().__init__()
118 |
119 | self.unet_layers = nn.ModuleList()
120 | for i in range(sources):
121 | layer = UNet(n_channel=n_channel,
122 | conv_n_filters=conv_n_filters,
123 | down_activation=down_activation,
124 | up_activation=up_activation,
125 | down_dropouts=down_dropouts,
126 | up_dropouts=up_dropouts)
127 | self.unet_layers.append(layer)
128 |
129 | def forward(self, x):
130 | # x shape: (batch, channel, time, freq)
131 | _layers = []
132 | for layer in self.unet_layers:
133 | y = layer(x)
134 | _layers.append(y)
135 |
136 | y = torch.stack(_layers, dim=1)
137 | y = F.softmax(y, dim=1) # shape: (batch, sources, channel, time, freq)
138 | return y
139 |
140 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fish
2 |
3 | YouTube video to chords, lyrics, beat and melody.
4 |
5 | A transformer-based hybrid multimodal model: several transformer networks address different problems in music information retrieval, and the information they produce forms dependencies that mutually influence one another.
6 |
7 | An AI-powered multimodal project focused on music that generates chords, beats, lyrics, melody, and tabs for any song.
8 |
9 | > For the online experience, [see the site here](https://lamucal.com).
10 |
11 |
12 |
13 |
14 |
15 | A `U-Net` model handles audio source separation, while `Pitch-Net`, `Beat-Net`, `Chord-Net`, and `Segment-Net` are built on the transformer model. Beyond modeling the correlation between frequency and time, the most important aspect is modeling the mutual influence between the different networks.
16 |
17 | The entire AI-powered process is implemented in `aitabs.py`, while the various network models can be found in the `models` folder.
18 | > **Note**: `U-Net` and `Segment-Net` take the STFT spectrogram of the audio as input. `Beat-Net` takes three spectrograms (drums, bass, and other instruments) as input, and `Chord-Net` takes a single spectrogram of the background music.
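
As a rough sketch of how such an input is prepared with the helpers in `python/spec.py` and `python/models/load.py` (mirroring the slicing in `aitabs.py`; the config values and the randomly initialized model below are assumptions for illustration only):

```python
import torch as th
from spec import get_spec          # python/spec.py
from models import get_model       # python/models/load.py

spec_cfg = {'n_fft': 4096, 'hop_length': 1024, 'pad': True}   # assumed values
waveform = th.randn(2, 44100 * 10)                             # stereo dummy audio, 10 s

spec = get_spec(waveform, spec_cfg)                  # complex STFT, (channel, time, freq)
mag = th.abs(spec)[:, :, :spec_cfg['n_fft'] // 2]    # keep the first 2048 frequency bins
mag = mag.unsqueeze(0)                               # (batch, channel, time, freq)

chord_net = get_model('chord', {'n_freq': 2048}, is_train=False)  # untrained weights here
with th.no_grad():
    chords, chord_logits = chord_net(mag)
```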
19 |
20 |
21 | ## Features
22 | - **Chord**, music chord detection covering major, minor, 7, maj7, min7, 6, m6, sus2, sus4, 5, and inverted chords, plus determination of the song's **key** (see the sketch after this list).
23 |
24 | - **Beat**, music beat and downbeat detection, plus **tempo** (BPM) tracking.
25 |
26 | - **Pitch**, tracking the pitch of the melody in the vocal track.
27 |
28 | - **Music Structure**, music segment boundaries and labels, including intro, verse, chorus, bridge, etc.
29 |
30 | - **Lyrics**, music lyrics recognition and automatic lyrics-to-audio alignment. ASR (Whisper) recognizes the lyrics of the vocal track, and the alignment of lyrics and audio is achieved by fine-tuning the wav2vec2 pre-trained model. It currently supports dozens of languages, including English, Spanish, Portuguese, Russian, Japanese, Korean, Arabic, Chinese, and more.
31 |
32 | - **AI Tabs**, generates playable sheet music, including chord charts and six-line staves, from chords, beats, music structure information, lyrics, rhythm, etc. Editing of chords, rhythm, and lyrics is supported.
33 |
34 | - **Other**, audio source separation, speed adjustment, pitch shifting, etc.
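
The chord vocabulary lives in `CHORD_LABELS` in `python/common.py` (122 entries, matching `ChordNet`'s default `n_classes`), and `python/utils.py` maps predicted indices back to names. A minimal sketch with hand-picked indices:

```python
from common import CHORD_LABELS
from utils import get_chord_name   # python/utils.py

print(len(CHORD_LABELS))               # 122 classes
print(get_chord_name([0, 2, 14, 26]))  # ['N', 'C:maj', 'C:min', 'C:7']
```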
35 |
36 | For more AI-powered feature experiences, see the [website](https://lamucal.com).
37 |
38 | ## Cover
39 | The cover pipeline combines audio STFT, MFCC, and chroma features and uses a Transformer model for timbre
40 | feature modeling and high-level abstraction. Compared with relying on a single feature, this approach is far
41 | less prone to overfitting and underfitting and generalizes better, achieving good results with a small amount
42 | of data and minimal training.
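
A minimal sketch of extracting the three feature types; the window sizes and sample rate used for training are not given in this repo, so the values below are assumptions:

```python
import torch
import torchaudio
import librosa

sample_rate = 16000
waveform = torch.randn(1, sample_rate * 5)   # mono dummy audio, 5 s

# Time-frequency, cepstral, and pitch-class views of the same signal
stft_mag = torchaudio.transforms.Spectrogram(n_fft=512, hop_length=128)(waveform)
mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=20)(waveform)
chroma = librosa.feature.chroma_stft(y=waveform.numpy()[0], sr=sample_rate, hop_length=128)

print(stft_mag.shape, mfcc.shape, chroma.shape)
```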
43 |
44 | > For the online experience, [see the site here](https://lamucal.com/ai-cover).
45 |
46 |
47 |
48 |
49 |
50 | The model begins by processing the audio signal through a `U-Net`, which isolates the vocal track.
51 | The vocal track is then simultaneously fed into `PitchNet` and `HuBERT` (Wav2Vec2). `PitchNet` is
52 | responsible for extracting pitch features, while `HuBERT` captures detailed features of the vocals.
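
For the HuBERT branch, `cover/load.py` wraps the torchaudio bundles. A minimal sketch (16 kHz mono input is what the bundles expect; the waveform here is a stand-in for the separated vocal):

```python
import torch
from load import get_hubert_model   # cover/load.py

hubert = get_hubert_model(name='base', device='cpu')
vocal = torch.randn(1, 16000 * 3)    # dummy 3 s vocal track, 16 kHz mono
with torch.no_grad():
    feats, _ = hubert(vocal)          # (batch, frames, 768) hidden features
```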
53 |
54 | The core of the model is `CombineNet`, which receives features from the `Features` module. This
55 | module consists of three spectrograms: STFT, MFCC, and Chroma, each extracting different aspects
56 | of the audio. These features are enhanced by the TimbreBlock before being passed to the Encoder.
57 | During this process, noise is introduced via STFT transformation and combined with the features
58 | before entering the Encoder for processing. The processed features are then passed to the Decoder,
59 | where they are combined with the previous features to generate the final audio output.
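
The noise helpers live in `cover/utils/noise.py`. A rough sketch of perturbing a magnitude spectrogram before it enters the Encoder; which noise type and strength training actually uses is not specified here:

```python
import torch
from utils.spec import stft                       # cover/utils/spec.py
from utils.noise import gaussian_noise, random_walk_noise

waveform = torch.randn(1, 16000 * 5)              # dummy mono audio
spec_mag = stft(waveform, fft_size=512).abs()[0]  # (freq, frames) magnitude

noisy = gaussian_noise(spec_mag, std=0.05)        # additive Gaussian perturbation
noisy = random_walk_noise(noisy, step_size=0.02)  # plus a bounded random walk
```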
60 |
61 | `CombineNet` is based on an encoder-decoder architecture and is trained to generate a mask that
62 | is used to extract and replace the timbre, ultimately producing the final output audio.
63 |
64 | The entire AI-powered process is implemented in `run.py`, while the various network models can be
65 | found in the `models` folder.
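
A minimal usage sketch of `AudioGenerate` in `cover/run.py`; the config keys mirror what that class reads, but the values and checkpoint paths below are placeholders:

```python
from run import AudioGenerate   # cover/run.py

config = {
    'n_channel': 2,
    'sources': 2,                # e.g. vocal + accompaniment
    'samplerate': 44100,
    'separate': {'model': {}, 'model_path': 'ckpt/unet.pt',
                 'spec': {'n_fft': 4096, 'hop_length': 1024, 'n_time': 512, 'pad': True}},
    'f0': {'model': {}, 'model_path': 'ckpt/pitch.pt'},
    'hubert': {'name': 'base', 'download': None},
    'generate': {'model': {}, 'model_path': 'ckpt/combine.pt', 'volume_list': [1.0, 0.8]},
}

generator = AudioGenerate(config, device='cpu')
generator.generate('song.wav', save_path='./output')
```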
66 |
67 | ## Demo
68 | The results of training on a 1-minute speech by Donald Trump are as follows:
69 |
70 |
71 | | **Train 10 epoch (Hozier's Too Sweet)** | **Train 100 epoch (Hozier's Too Sweet)** |
72 | | --- | --- |
73 | | [Train 10 epoch.webm](https://github.com/user-attachments/assets/992747d6-3e47-442c-ab63-0742c83933ee) | [Train 100 epoch.webm](https://github.com/user-attachments/assets/877d2cae-d7b7-4355-807f-424ada7df3a1) |
74 |
98 | You can experience creating your own voice cover online; [see the site here](https://lamucal.com/ai-cover).
99 |
--------------------------------------------------------------------------------
/python/models/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from torch import nn
5 |
6 |
7 | class UPad(nn.Module):
8 | def __init__(self, padding_setting=(1, 2, 1, 2)):
9 | super().__init__()
10 | self.padding_setting = padding_setting
11 |
12 | def forward(self, x):
13 | return F.pad(x, self.padding_setting, "constant", 0)
14 |
15 |
16 | class UTransposedPad(nn.Module):
17 | def __init__(self, padding_setting=(1, 2, 1, 2)):
18 | super().__init__()
19 | self.padding_setting = padding_setting
20 |
21 | def forward(self, x):
22 | l, r, t, b = self.padding_setting
23 | return x[:, :, l:-r, t:-b]
24 |
25 |
26 | def get_activation(activation):
27 | if activation == "ReLU":
28 | activation_fn = nn.ReLU()
29 | elif activation == "ELU":
30 | activation_fn = nn.ELU()
31 | else:
32 | activation_fn = nn.LeakyReLU(0.2)
33 | return activation_fn
34 |
35 |
36 | class UNet(nn.Module):
37 | def __init__(self,
38 | n_channel=2,
39 | conv_n_filters=(16, 32, 64, 128, 256, 512),
40 | down_activation="ELU",
41 | up_activation="ELU",
42 | down_dropouts=None,
43 | up_dropouts=None):
44 | super().__init__()
45 |
46 | conv_num = len(conv_n_filters)
47 |
48 | down_activation_fn = get_activation(down_activation)
49 | up_activation_fn = get_activation(up_activation)
50 |
51 | down_dropouts = [0] * conv_num if down_dropouts is None else down_dropouts
52 | up_dropouts = [0] * conv_num if up_dropouts is None else up_dropouts
53 |
54 | self.down_layers = nn.ModuleList()
55 | for i in range(conv_num):
56 | in_ch = n_channel if i == 0 else conv_n_filters[i - 1]
57 | out_ch = conv_n_filters[i]
58 | dropout = down_dropouts[i]
59 |
60 | _down_layers = [
61 | UPad(),
62 | nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0),
63 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01),
64 | down_activation_fn
65 | ]
66 | if dropout > 0:
67 | _down_layers.append(nn.Dropout(dropout))
68 | self.down_layers.append(nn.Sequential(*_down_layers))
69 |
70 | self.up_layers = nn.ModuleList()
71 | for i in range(conv_num - 1, -1, -1):
72 | in_ch = conv_n_filters[conv_num - 1] if i == conv_num - 1 else conv_n_filters[i + 1]
73 | out_ch = 1 if i == 0 else conv_n_filters[i - 1]
74 | dropout = up_dropouts[i]
75 |
76 | _up_layer = [
77 | nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2, padding=0),
78 | UTransposedPad(),
79 | up_activation_fn,
80 | nn.BatchNorm2d(out_ch, track_running_stats=True, eps=1e-3, momentum=0.01)
81 | ]
82 | if dropout > 0:
83 | _up_layer.append(nn.Dropout(dropout))
84 | self.up_layers.append(nn.Sequential(*_up_layer))
85 |
86 | self.last_layer = nn.Conv2d(1, 2, kernel_size=1, stride=1, padding=0)
87 |
88 | def forward(self, x):
89 |
90 | d_convs = []
91 | for layer in self.down_layers:
92 | x = layer(x)
93 | d_convs.append(x)
94 |
95 | n = len(self.up_layers)
96 | for i, layer in enumerate(self.up_layers):
97 | if i == 0:
98 | x = layer(x)
99 | else:
100 | x1 = d_convs[n - i - 1]
101 | x = torch.cat([x, x1], axis=1)
102 | x = layer(x)
103 | x = self.last_layer(x)
104 | return x
105 |
106 |
107 | class UNets(nn.Module):
108 | def __init__(self,
109 | sources,
110 | n_channel=2,
111 | conv_n_filters=(16, 32, 64, 128, 256, 512),
112 | down_activation="ELU",
113 | up_activation="ELU",
114 | down_dropouts=None,
115 | up_dropouts=None,
116 | *args, **kwargs):
117 | super().__init__()
118 |
119 | self.unet_layers = nn.ModuleList()
120 | for i in range(sources):
121 | layer = UNet(n_channel=n_channel,
122 | conv_n_filters=conv_n_filters,
123 | down_activation=down_activation,
124 | up_activation=up_activation,
125 | down_dropouts=down_dropouts,
126 | up_dropouts=up_dropouts)
127 | self.unet_layers.append(layer)
128 |
129 | def forward(self, x):
130 | # x shape: (batch, channel, time, freq)
131 | _layers = []
132 | for layer in self.unet_layers:
133 | y = layer(x)
134 | _layers.append(y)
135 |
136 | y = torch.stack(_layers, axis=1)
137 | y = F.softmax(y, dim=1) # shape: (batch, sources, channel, time, freq)
138 | return y
139 |
140 |
141 | if __name__ == '__main__':
142 | model_params = {
143 | 'sources': 4,
144 | 'n_channel': 2,
145 | 'conv_n_filters': [8, 16, 32, 64, 128, 256, 512, 1024],
146 | 'down_activation': "ELU",
147 | 'up_activation': "ELU",
148 | # 'down_dropouts': [0, 0, 0, 0, 0, 0, 0, 0],
149 | # 'up_dropouts': [0, 0, 0, 0, 0, 0, 0, 0],
150 | }
151 | net = UNets(**model_params)
152 | print(net)
153 | print(net(torch.rand(1, 2, 256, 2049)).shape)
154 |
--------------------------------------------------------------------------------
/cover/run.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import torch as th
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from load import get_model, get_hubert_model
8 | from utils.audio import load_waveform, save_waveform
9 | from utils.spec import get_spectrogram, istft
10 | from utils.utils import generate_uuid
11 | from models import UNets, PitchNet, CombineNet
12 |
13 |
14 | def build_masked_stft(masks, stft_feature, n_fft=4096):
15 | out = []
16 | for i in range(len(masks)):
17 | mask = masks[i, :, :, :]
18 | pad_num = n_fft // 2 + 1 - mask.size(-1)
19 | mask = F.pad(mask, (0, pad_num, 0, 0, 0, 0))
20 | inst_stft = mask.type(stft_feature.dtype) * stft_feature
21 | out.append(inst_stft)
22 | return out
23 |
24 |
25 | def merge_wav(stem_list, volume_list=None):
26 | stem_num = len(stem_list)
27 | length = min([stem.shape[0] for stem in stem_list])
28 |
29 | for i in range(stem_num):
30 | stem = stem_list[i][:length] # shape: time, channel
31 | if stem.ndim == 1:
32 | stem = np.tile(np.expand_dims(stem, axis=1), (1, 2))
33 | stem = stem / np.abs(stem).max()
34 | stem_list[i] = stem
35 | if volume_list:
36 | stem_list[i] = stem_list[i] * volume_list[i]
37 |
38 | mix = sum(stem_list) / stem_num
39 | mix = mix / np.abs(mix).max()
40 | return mix
41 |
42 |
43 | class AudioGenerate(object):
44 | def __init__(self, config, device='cpu'):
45 | self.config = config
46 | self.device = device
47 | self.n_channel = self.config['n_channel']
48 | self.sources = self.config['sources']
49 | self.samplerate = self.config['samplerate']
50 | self.separate_config = self.config['separate']
51 | self.f0_config = self.config['f0']
52 | self.hubert_config = self.config['hubert']
53 | self.generate_config = self.config['generate']
54 |
55 | self.separate_model_cfg = self.separate_config['model']
56 | self.separate_model_cfg['sources'] = self.sources
57 | self.separate_model_cfg['n_channel'] = self.n_channel
58 | self.unet = get_model(UNets, self.separate_model_cfg,
59 | model_path=self.separate_config['model_path'],
60 | is_train=False, device=device)
61 |
62 | self.f0_model_cfg = self.f0_config['model']
63 | self.f0_extractor = get_model(PitchNet, self.f0_model_cfg,
64 | model_path=self.f0_config['model_path'],
65 | is_train=False, device=device)
66 | self.hubert_model = get_hubert_model(self.hubert_config['name'],
67 | dl_kwargs=self.hubert_config['download'],
68 | device=device)
69 |
70 | self.generate_model_cfg = self.generate_config['model']
71 | self.ai_cover_model = get_model(CombineNet, self.generate_model_cfg,
72 | model_path=self.generate_config['model_path'],
73 |                                         is_train=False, device=device)
74 |
75 | def separate(self, waveform, samplerate):
76 | assert samplerate == self.samplerate
77 | wav_len = waveform.shape[-1]
78 |
79 | spec_config = self.separate_config['spec']
80 | n_fft = spec_config['n_fft']
81 | hop_length = spec_config['hop_length']
82 | n_time = spec_config['n_time']
83 |
84 | split_len = (n_time - 5) * hop_length + n_fft
85 |
86 | output_waveforms = [[] for _ in range(self.sources)]
87 | for i in range(0, wav_len, split_len):
88 | with th.no_grad():
89 | x = waveform[:, i:i + split_len]
90 | pad_num = 0
91 | if x.shape[-1] < split_len:
92 | pad_num = split_len - (wav_len - i)
93 | x = F.pad(x, (0, pad_num))
94 |
95 | # separator
96 | z = get_spectrogram(x, spec_config)
97 | mag_z = th.abs(z).unsqueeze(0)
98 | masks = self.unet(mag_z)
99 | masks = masks.squeeze(0)
100 | _masked_stfts = build_masked_stft(masks, z, n_fft=n_fft)
101 | # build waveform
102 | for j, _masked_stft in enumerate(_masked_stfts):
103 | _masked_stft = _masked_stft.transpose(-1, -2)
104 | _waveform = istft(_masked_stft, hop_length=hop_length)
105 | if pad_num > 0:
106 | _waveform = _waveform[:, :-pad_num]
107 | output_waveforms[j].append(_waveform)
108 |
109 | inst_waveforms = []
110 | for waveform_list in output_waveforms:
111 | inst_waveforms.append(th.cat(waveform_list, dim=-1))
112 | return th.stack(inst_waveforms, dim=0)
113 |
114 | def get_feat(self, audio):
115 | feats, _ = self.hubert_model(audio)
116 | feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
117 | return feats
118 |
119 | def generate(self, audio_fp, save_path):
120 | waveform, samplerate = load_waveform(audio_fp, samplerate=self.samplerate,
121 | n_channel=self.n_channel, device=self.device)
122 |
123 | waveforms = self.separate(waveform, samplerate)
124 |
125 | vocal_waveform = waveforms[0]
126 | other_waveform = waveforms[1]
127 |
128 | with th.no_grad():
129 | feat = self.get_feat(vocal_waveform)
130 | f0 = self.f0_extractor(vocal_waveform)
131 | out_waveform, _, _ = self.ai_cover_model(vocal_waveform, feat, f0)
132 |
133 |         final_waveform = merge_wav([out_waveform, other_waveform],
134 | volume_list=self.generate_config['volume_list'])
135 |
136 | fp = os.path.join(save_path, f'{generate_uuid()}.wav')
137 | save_waveform(fp, final_waveform, self.samplerate)
138 |
139 |
--------------------------------------------------------------------------------
/python/models/pitch_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.transformers import EncoderFre, EncoderTime, FeedForward, positional_encoding
4 |
5 |
6 | class PitchEmbedding(nn.Module):
7 | def __init__(self, n_channel, d_model, n_hidden=32, dropout=0.5):
8 | super().__init__()
9 | self.n_channel = n_channel
10 | self.n_hidden = n_hidden
11 |
12 | self.act_fn = nn.ReLU()
13 |
14 | self.conv1 = nn.Conv2d(n_channel, n_hidden // 4, kernel_size=(1, 1))
15 | self.pool1 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
16 | self.drop1 = nn.Dropout(dropout)
17 |
18 | self.conv2 = nn.Conv2d(n_hidden // 4, n_hidden // 2, kernel_size=(1, 1))
19 | self.pool2 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
20 | self.drop2 = nn.Dropout(dropout)
21 |
22 | self.conv3 = nn.Conv2d(n_hidden // 2, n_hidden, kernel_size=(1, 1))
23 | self.pool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
24 | self.drop3 = nn.Dropout(dropout)
25 |
26 | self.conv4 = nn.Conv2d(n_hidden, n_hidden, kernel_size=(1, 1))
27 | self.pool4 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
28 | self.drop4 = nn.Dropout(dropout)
29 |
30 | self.norm = nn.LayerNorm(d_model // 16)
31 | self.drop = nn.Dropout(dropout)
32 |
33 | def forward(self, x):
34 | # x: batch, n_channel, n_time, n_freq
35 | x = self.conv1(x)
36 | x = self.pool1(x)
37 | x = self.act_fn(x)
38 | x = self.drop1(x)
39 |
40 | x = self.conv2(x)
41 | x = self.pool2(x)
42 | x = self.act_fn(x)
43 | x = self.drop2(x)
44 |
45 | x = self.conv3(x)
46 | x = self.pool3(x)
47 | x = self.act_fn(x)
48 | x = self.drop3(x)
49 |
50 | x = self.conv4(x)
51 | x = self.pool4(x)
52 | x = self.act_fn(x)
53 | x = self.drop4(x)
54 |
55 | x = self.norm(x)
56 | x = self.drop(x)
57 |
58 | return x
59 |
60 |
61 | class PitchEncoder(nn.Module):
62 | def __init__(self,
63 | n_freq=2048,
64 | n_group=32,
65 | weights=(0.6, 0.4),
66 | f_layers=2,
67 | f_nhead=8,
68 | f_pr=0.01,
69 | t_layers=2,
70 | t_nhead=8,
71 | t_pr=0.01,
72 | dropout=0.5):
73 | super().__init__()
74 | self.weights = weights
75 |
76 | self.encoder_fre = EncoderFre(n_freq=n_freq, n_group=n_group, nhead=f_nhead, n_layers=f_layers, dropout=dropout,
77 | pr=f_pr)
78 | self.encoder_time = EncoderTime(n_freq=n_freq, nhead=t_nhead, n_layers=t_layers, dropout=dropout, pr=t_pr)
79 |
80 | self.dropout = nn.Dropout(dropout)
81 | self.norm = nn.LayerNorm(n_freq)
82 |
83 | def forward(self, x):
84 | # x: batch, channel, n_time, n_freq
85 | y1 = self.encoder_fre(x)
86 | y2 = self.encoder_time(x)
87 | y = y1 * self.weights[0] + y2 * self.weights[1]
88 | y = self.dropout(y)
89 | y = self.norm(y)
90 | return y
91 |
92 |
93 | class PitchNet(nn.Module):
94 | def __init__(self,
95 | n_freq=2048,
96 | n_channel=2,
97 | n_classes=850, # 85 * 10
98 | emb_hidden=16,
99 | wr=0.3,
100 | pr=0.02,
101 | n_group=32,
102 | f_layers=2,
103 | f_nhead=8,
104 | t_layers=2,
105 | t_nhead=8,
106 | enc_weights=(0.6, 0.4),
107 | d_layers=2,
108 | d_nhead=8,
109 | dropout=0.5,
110 | *args, **kwargs):
111 | super().__init__()
112 | self.wr = wr
113 | self.pr = pr
114 | self.embedding = PitchEmbedding(n_channel, n_freq, emb_hidden, dropout)
115 | d_model = n_freq // 16 * emb_hidden
116 | self.encoder = PitchEncoder(n_freq=d_model, n_group=n_group,
117 | weights=enc_weights,
118 | f_layers=f_layers, f_nhead=f_nhead,
119 | t_layers=t_layers, t_nhead=t_nhead,
120 | dropout=dropout)
121 |
122 | self.attn1_layer = nn.ModuleList()
123 | self.attn2_layer = nn.ModuleList()
124 | self.ff_layer = nn.ModuleList()
125 | for _ in range(d_layers):
126 | _layer1 = nn.MultiheadAttention(d_model, d_nhead, batch_first=True)
127 | _layer2 = nn.MultiheadAttention(d_model, d_nhead, batch_first=True)
128 | _layer3 = FeedForward(d_model, dropout=dropout)
129 | self.attn1_layer.append(_layer1)
130 | self.attn2_layer.append(_layer2)
131 | self.ff_layer.append(_layer3)
132 |
133 | self.dropout = nn.Dropout(dropout)
134 | self.classifier = nn.Linear(d_model, n_classes)
135 |
136 | def forward(self, x, weight=None):
137 | # x: batch, channel, n_time, n_freq
138 | y = self.embedding(x)
139 |
140 | B, C, T, F = y.shape
141 | y = y.permute(0, 2, 1, 3) # B, C, T, F => B, T, C, F
142 | y = y.reshape(B, T, C * F) # B, T, C * F
143 | y = self.encoder(y)
144 |
145 | if weight is None:
146 | weight = torch.zeros_like(y)
147 | y_w = y + weight * self.wr
148 | B, T, F = y_w.shape
149 | y_w += positional_encoding(B, T, F) * self.pr
150 |
151 | for _attn1_layer, _attn2_layer, _ff_layer in zip(self.attn1_layer, self.attn2_layer, self.ff_layer):
152 | y, _ = _attn1_layer(y, y, y, need_weights=False)
153 |             y, _ = _attn2_layer(y, y_w, y_w, need_weights=False)
154 | y = _ff_layer(y)
155 |
156 | output = self.dropout(y)
157 | output = self.classifier(output)
158 | output = torch.argmax(output, dim=-1)
159 | return output, y
160 |
161 |
162 | if __name__ == '__main__':
163 | model = PitchNet()
164 | print(model)
165 | x = torch.randn(6, 2, 256, 2048)
166 | y1, y2 = model(x, None)
167 | print(y1.shape, y2.shape)
168 |
--------------------------------------------------------------------------------
/python/aitabs.py:
--------------------------------------------------------------------------------
1 | import torchaudio as ta
2 | import torch as th
3 | import torch.nn.functional as F
4 | from librosa.feature import tempo
5 |
6 | from audio import read_wav, write_wav, gen_wav
7 | from utils import build_masked_stft, get_chord_name, get_segment_name, get_lyrics
8 | from spec import istft, get_spec, get_specs, get_mixed_spec
9 | from modulation import search_key
10 | from models import get_model
11 |
12 |
13 | class AITabTranscription(object):
14 | def __init__(self, config):
15 | self.config = config
16 | self.n_channel = self.config['n_channel']
17 | self.sources = self.config['sources']
18 | self.sample_rate = self.config['sample_rate']
19 | self.sep_config = self.config['separate']
20 | self.lyrics_cfg = self.config['lyrics']
21 | self.beat_cfg = self.config['beat']
22 | self.chord_cfg = self.config['chord']
23 | self.segment_cfg = self.config['segment']
24 | self.pitch_cfg = self.config['pitch']
25 | self.spec_cfg = self.config['spec']
26 | self.tempo_cfg = self.config['tempo']
27 |
28 | def separate(self, waveform, sample_rate, device='cpu'):
29 | assert sample_rate == self.sample_rate
30 | wav_len = waveform.shape[-1]
31 |
32 | model_config = self.sep_config['model']
33 | spec_config = self.sep_config['spec']
34 | n_fft = self.sep_config['spec']['n_fft']
35 | hop_length = self.sep_config['spec']['hop_length']
36 | n_time = self.sep_config['spec']['n_time']
37 |
38 | _model_cfg = {
39 | 'sources': self.sources,
40 | 'n_channel': self.n_channel,
41 | }
42 | _model_cfg.update(model_config)
43 | unet = get_model(self.sep_config['model_name'], _model_cfg, model_path=self.sep_config['model_path'],
44 | is_train=False, device=device)
45 |
46 | split_len = (n_time - 5) * hop_length + n_fft
47 |
48 | output_waveforms = [[] for _ in range(self.sources)]
49 | for i in range(0, wav_len, split_len):
50 | with th.no_grad():
51 | x = waveform[:, i:i + split_len]
52 | pad_num = 0
53 | if x.shape[-1] < split_len:
54 | pad_num = split_len - (wav_len - i)
55 | x = F.pad(x, (0, pad_num))
56 |
57 | # separator
58 | z = get_spec(x, spec_config)
59 | mag_z = th.abs(z).unsqueeze(0)
60 | masks = unet(mag_z)
61 | masks = masks.squeeze(0)
62 | _masked_stfts = build_masked_stft(masks, z, n_fft=n_fft)
63 | # build waveform
64 | for j, _masked_stft in enumerate(_masked_stfts):
65 | _waveform = istft(_masked_stft, n_fft=n_fft, hop_length=hop_length, pad=True)
66 | if pad_num > 0:
67 | _waveform = _waveform[:, :-pad_num]
68 | output_waveforms[j].append(_waveform)
69 |
70 | inst_waveforms = []
71 | for waveform_list in output_waveforms:
72 | inst_waveforms.append(th.cat(waveform_list, dim=-1))
73 | return th.stack(inst_waveforms, dim=0)
74 |
75 | def transcribe(self, wav_fp, device='cpu'):
76 |
77 | waveform, sample_rate = read_wav(wav_fp, sample_rate=self.sample_rate, n_channel=self.n_channel, device=device)
78 | # print(waveform.shape, sample_rate)
79 |
80 |         inst_waveforms = self.separate(waveform, sample_rate, device=device)
81 | # print(inst_waveforms.shape)
82 |
83 |         # load models
84 | beat_net = get_model(self.beat_cfg['model_name'], self.beat_cfg['model'],
85 | model_path=self.beat_cfg['model_path'], is_train=False, device=device)
86 | chord_net = get_model(self.chord_cfg['model_name'], self.chord_cfg['model'],
87 | model_path=self.chord_cfg['model_path'], is_train=False, device=device)
88 | segment_net = get_model(self.segment_cfg['model_name'], self.segment_cfg['model'],
89 | model_path=self.segment_cfg['model_path'], is_train=False, device=device)
90 | pitch_net = get_model(self.pitch_cfg['model_name'], self.pitch_cfg['model'],
91 | model_path=self.pitch_cfg['model_path'], is_train=False, device=device)
92 |
93 | vocal_waveform = inst_waveforms[0].numpy()
94 | orig_spec = get_spec(waveform, self.spec_cfg)
95 | inst_specs = get_specs(inst_waveforms, self.spec_cfg) # vocal, bass, drum, other
96 | vocal_spec = get_spec(inst_waveforms[0], self.spec_cfg) # vocal
97 | other_spec = get_mixed_spec(inst_waveforms[1:], self.spec_cfg) # bass + drum + other
98 |
99 | # pred lyrics
100 | lyrics, lyrics_matrix = get_lyrics(vocal_waveform, sample_rate, self.lyrics_cfg)
101 |
102 | with th.no_grad():
103 | # pred beat
104 | beat_features = inst_specs[:, :, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0) # B, S, C, T, F
105 | beat_features_mag = th.abs(beat_features)
106 |
107 | beat_pred, beat_logist = beat_net(beat_features_mag)
108 | print('beat info', beat_pred.shape, beat_logist.shape)
109 |
110 | # pred chord
111 | chord_features = other_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0)
112 | chord_features_mag = th.abs(chord_features)
113 |
114 | chord_pred, chord_logist = chord_net(chord_features_mag, beat_logist)
115 | print('chord info', chord_pred.shape, chord_logist.shape)
116 |
117 | # pred segment
118 | segment_features = orig_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0)
119 | segment_features_mag = th.abs(segment_features)
120 | segment_pred = segment_net(segment_features_mag, chord_logist)
121 | print('segment info', segment_pred.shape)
122 |
123 | # pred pitch
124 | pitch_features = vocal_spec[:, :, :self.spec_cfg['n_fft'] // 2].unsqueeze(0)
125 | pitch_features_mag = th.abs(pitch_features)
126 | pitch_pred, pitch_logist = pitch_net(pitch_features_mag, lyrics_matrix)
127 | print('pitch info', pitch_pred.shape, pitch_logist.shape)
128 |
129 | beats = beat_pred.squeeze(0).numpy()
130 | bpm = tempo(onset_envelope=beats, hop_length=self.tempo_cfg['hop_length']).tolist()
131 | chord_pred = chord_pred.squeeze(0)
132 | chords = get_chord_name(chord_pred)
133 | song_key = search_key(chords)
134 | segment_pred = segment_pred.squeeze(0)
135 | segment = get_segment_name(segment_pred)
136 | beats = beats.tolist()
137 | pitch_list = pitch_pred.squeeze(0).tolist()
138 |
139 | ret = {
140 | 'bpm': bpm,
141 | 'key': song_key,
142 | 'chords': chords,
143 | 'beat': beats,
144 | 'segment': segment,
145 | 'pitch': pitch_list,
146 | 'lyrics': lyrics,
147 | }
148 | return ret, inst_waveforms
149 |
150 |
--------------------------------------------------------------------------------
/python/models/transformers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def positional_encoding(batch_size, n_time, n_feature, zero_pad=False, scale=False, dtype=torch.float32):
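    # sinusoidal positional encoding: entry (p, i) is p / 10000**(2 * i / n_feature), with sin applied to even
    # feature indices and cos to odd ones; the table is gathered per batch position via F.embedding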
7 | pos_indices = torch.tile(torch.unsqueeze(torch.arange(n_time), 0), [batch_size, 1])
8 |
9 | pos = torch.arange(n_time, dtype=dtype).reshape(-1, 1)
10 | pos_enc = pos / torch.pow(10000, 2 * torch.arange(0, n_feature, dtype=dtype) / n_feature)
11 | pos_enc[:, 0::2] = torch.sin(pos_enc[:, 0::2])
12 | pos_enc[:, 1::2] = torch.cos(pos_enc[:, 1::2])
13 |
14 | if zero_pad:
15 | pos_enc = torch.cat([torch.zeros(size=[1, n_feature]), pos_enc[1:, :]], 0)
16 |
17 | outputs = F.embedding(pos_indices, pos_enc)
18 |
19 | if scale:
20 | outputs = outputs * (n_feature ** 0.5)
21 |
22 | return outputs
23 |
24 |
25 | class FeedForward(nn.Module):
26 | def __init__(self, n_feature=2048, n_hidden=512, dropout=0.5):
27 | super().__init__()
28 | self.linear1 = nn.Linear(n_feature, n_hidden)
29 | self.linear2 = nn.Linear(n_hidden, n_feature)
30 |
31 | self.dropout1 = nn.Dropout(dropout)
32 | self.dropout2 = nn.Dropout(dropout)
33 | self.norm_layer = nn.LayerNorm(n_feature)
34 |
35 | def forward(self, x):
36 | y = self.linear1(x)
37 | y = F.relu(y)
38 | y = self.dropout1(y)
39 | y = self.linear2(y)
40 | y = self.dropout2(y)
41 | y = self.norm_layer(y)
42 | return y
43 |
44 |
45 | class EncoderFre(nn.Module):
46 | def __init__(self, n_freq, n_group, nhead=8, n_layers=5, dropout=0.5, pr=0.01):
47 | super().__init__()
48 | assert n_freq % n_group == 0
49 | self.d_model = d_model = n_freq // n_group
50 | self.n_freq = n_freq
51 | self.n_group = n_group
52 | self.pr = pr
53 | self.n_layers = n_layers
54 |
55 | self.attn_layer = nn.ModuleList()
56 | self.ff_layer = nn.ModuleList()
57 |
58 | for _ in range(self.n_layers):
59 | _attn_layer = nn.MultiheadAttention(d_model, nhead, batch_first=True)
60 | _ff_layer = FeedForward(d_model, dropout=dropout)
61 |
62 | self.attn_layer.append(_attn_layer)
63 | self.ff_layer.append(_ff_layer)
64 |
65 | self.dropout = nn.Dropout(dropout)
66 | self.fc = nn.Linear(n_freq, n_freq)
67 | self.norm_layer = nn.LayerNorm(n_freq)
68 |
69 | def forward(self, x):
70 | # x: batch, n_time, n_freq
71 | B, T, Fr = x.shape
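        # split the n_freq axis into n_group groups of d_model bins; self-attention then runs across the groups of each time frame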
72 | x_reshape = x.reshape(B * T, self.n_group, self.d_model)
73 |         x_reshape = x_reshape + positional_encoding(  # avoid an in-place update of the input view
74 | batch_size=x_reshape.shape[0], n_time=x_reshape.shape[1], n_feature=x_reshape.shape[2]
75 | ) * self.pr
76 |
77 | for _attn_layer, _ff_layer in zip(self.attn_layer, self.ff_layer):
78 | x_reshape, _ = _attn_layer(x_reshape, x_reshape, x_reshape, need_weights=False)
79 | x_reshape = _ff_layer(x_reshape)
80 |
81 | y = x_reshape.reshape(B, T, self.n_freq)
82 |
83 | y = self.dropout(y)
84 | y = self.fc(y)
85 | y = self.norm_layer(y)
86 | return y
87 |
88 |
89 | class EncoderTime(nn.Module):
90 | def __init__(self, n_freq, nhead=8, n_layers=5, dropout=0.5, pr=0.02):
91 | super().__init__()
92 | self.n_freq = n_freq
93 | self.n_layers = n_layers
94 | self.pr = pr
95 |
96 | self.attn_layer = nn.ModuleList()
97 | self.ff_layer = nn.ModuleList()
98 | for _ in range(self.n_layers):
99 | _attn_layer = nn.MultiheadAttention(n_freq, nhead, batch_first=True)
100 | _ff_layer = FeedForward(n_freq, dropout=dropout)
101 |
102 | self.attn_layer.append(_attn_layer)
103 | self.ff_layer.append(_ff_layer)
104 |
105 | self.norm_layer = nn.LayerNorm(n_freq)
106 | self.dropout = nn.Dropout(dropout)
107 | self.fc = nn.Linear(n_freq, n_freq)
108 |
109 | def forward(self, x):
110 | # x: batch, n_time, n_freq
111 | B, T, Fr = x.shape
112 |         x = x + positional_encoding(B, T, Fr) * self.pr  # avoid an in-place update of the input
113 |
114 | for _attn_layer, _ff_layer in zip(self.attn_layer, self.ff_layer):
115 | x, _ = _attn_layer(x, x, x, need_weights=False)
116 | x = _ff_layer(x)
117 |
118 | x = self.dropout(x)
119 | x = self.fc(x)
120 | x = self.norm_layer(x)
121 | return x
122 |
123 |
124 | class Decoder(nn.Module):
125 |
126 | def __init__(self, d_model=512, nhead=8, n_layers=5, dropout=0.5, r1=1.0, r2=1.0, wr=1.0, pr=0.01):
127 | super().__init__()
128 |
129 | self.r1 = r1
130 | self.r2 = r2
131 | self.wr = wr
132 | self.n_layers = n_layers
133 | self.pr = pr
134 |
135 | self.attn1_layer = nn.ModuleList()
136 | self.attn2_layer = nn.ModuleList()
137 | self.ff_layer = nn.ModuleList()
138 | for _ in range(n_layers):
139 | _layer1 = nn.MultiheadAttention(d_model, nhead, batch_first=True)
140 | _layer2 = nn.MultiheadAttention(d_model, nhead, batch_first=True)
141 | _layer3 = FeedForward(d_model, dropout=dropout)
142 | self.attn1_layer.append(_layer1)
143 | self.attn2_layer.append(_layer2)
144 | self.ff_layer.append(_layer3)
145 | self.dropout = nn.Dropout(dropout)
146 | self.fc = nn.Linear(d_model, d_model)
147 |
148 | def forward(self, x1, x2, weight=None):
149 | y = x1 * self.r1 + x2 * self.r2
150 | if weight is not None:
151 | y += weight * self.wr
152 |
153 | y += positional_encoding(y.shape[0], y.shape[1], y.shape[2]) * self.pr + self.pr
154 |
155 | for i in range(self.n_layers):
156 | _attn1_layer = self.attn1_layer[i]
157 | _attn2_layer = self.attn2_layer[i]
158 | _ff_layer = self.ff_layer[i]
159 | y, _ = _attn1_layer(y, y, y, need_weights=False)
160 | y, _ = _attn2_layer(y, x2, x2, need_weights=False)
161 | y = _ff_layer(y)
162 | output = self.dropout(y)
163 | output = self.fc(output)
164 | return output, y
165 |
166 |
167 | class BaseTransformer(nn.Module):
168 | def __init__(self,
169 | n_channel=2,
170 | n_freq=2048,
171 | # EncoderFre
172 | n_group=32,
173 | f_layers=2,
174 | f_nhead=4,
175 | f_dropout=0.5,
176 | f_pr=0.01,
177 | # EncoderTime
178 | t_layers=2,
179 | t_nhead=4,
180 | t_dropout=0.5,
181 | t_pr=0.01,
182 | # Decoder
183 | d_layers=2,
184 | d_nhead=4,
185 | d_dropout=0.5,
186 | d_pr=0.02,
187 | r1=1.0,
188 | r2=1.0,
189 | wr=0.2,
190 | ):
191 | super().__init__()
192 | self.n_channel = n_channel
193 | self.encoder_fre_layers = nn.ModuleList()
194 | self.encoder_time_layers = nn.ModuleList()
195 | for _ in range(n_channel):
196 | _encoder_fre_layer = EncoderFre(n_freq=n_freq, n_group=n_group, nhead=f_nhead, n_layers=f_layers,
197 | dropout=f_dropout, pr=f_pr)
198 | _encoder_time_layer = EncoderTime(n_freq=n_freq, nhead=t_nhead, n_layers=t_layers, dropout=t_dropout,
199 | pr=t_pr)
200 | self.encoder_fre_layers.append(_encoder_fre_layer)
201 | self.encoder_time_layers.append(_encoder_time_layer)
202 |
203 | self.decoder = Decoder(d_model=n_freq, nhead=d_nhead, n_layers=d_layers, dropout=d_dropout,
204 | r1=r1, r2=r2, wr=wr, pr=d_pr)
205 |
206 | def forward(self, x, weight=None):
207 | # x: batch, channel, n_time, n_freq
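        # every input channel has its own frequency-axis and time-axis encoder; the channel outputs are summed and fused in the decoder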
208 | ff_list = []
209 | tf_list = []
210 | for i in range(self.n_channel):
211 | x1 = self.encoder_fre_layers[i](x[:, i, :, :])
212 | x2 = self.encoder_time_layers[i](x[:, i, :, :])
213 | ff_list.append(x1)
214 | tf_list.append(x2)
215 |
216 | y1 = torch.sum(torch.stack(ff_list, dim=0), dim=0)
217 | y2 = torch.sum(torch.stack(tf_list, dim=0), dim=0)
218 | y, w = self.decoder(y1, y2, weight=weight)
219 | return y, w
220 |
221 |
222 | if __name__ == '__main__':
223 | net = BaseTransformer(n_freq=2048)
224 | # net = EncoderTime(2048)
225 | # net = EncoderFre(2048, 32)
226 |
227 | # net = EncoderFre(2048, 8)
228 | x = torch.randn(1, 2, 1024, 2048)
229 |
230 | y, logits = net(x)
231 | y1, logits = net(x)
232 |
233 | # print(np.allclose(x, y))
234 | print(y.shape, logits.shape)
235 | print(y[0, 0, 0], y[0, 1, 0])
236 | print(y1[0, 0, 0], y1[0, 1, 0])
237 |
--------------------------------------------------------------------------------
/python/modulation.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from itertools import product
3 |
4 | from common import MAJOR_CHORDS, MINOR_CHORDS, MAJ2MIN_MAP
5 |
6 |
7 | def get_chord_tonic(chord):
8 | chord = chord.replace('Dim', 'dim').replace('Aug', 'aug')
9 | if len(chord) >= 3:
10 | if chord[1] in ["#", 'b']:
11 | if 'maj7' in chord:
12 | return chord[:2]
13 | # elif 'dim' in chord:
14 | # return chord[:2] + 'm'
15 | elif chord[2] in ['m']:
16 | return chord[:3]
17 | elif 'sus2' in chord or 'sus4' in chord or 'dim' in chord or 'aug' in chord:
18 | return chord
19 | else:
20 | return chord[:2]
21 | elif chord[1] == "5":
22 | return chord[:2]
23 | # elif 'dim' in chord:
24 | # return chord[:1] + 'm'
25 | elif 'sus2' in chord or 'sus4' in chord or 'dim' in chord or 'aug' in chord:
26 | return chord
27 | elif 'maj7' in chord:
28 | return chord.replace('maj7', '')
29 | elif '7M' in chord:
30 | return chord.replace('7M', '')
31 | elif chord[1] in ['m']:
32 | return chord[:2]
33 | else:
34 | return chord[:1]
35 | elif len(chord) == 2:
36 | if chord[1] in ['#', 'b', 'm', '5']:
37 | return chord[:2]
38 | else:
39 | return chord[:1]
40 | else:
41 | return chord[:1]
42 |
43 |
44 | def chord_name_filter(chords):
45 | search_name_list = []
46 | for i in chords:
47 | if not i:
48 | continue
49 | search_name_list.append(get_chord_tonic(i))
50 | return search_name_list
51 |
52 |
53 | def chord_to_tone(chord):
54 | tone = chord[0]
55 | if len(chord) > 2:
56 | if chord[1] in ['b', '#']:
57 | tone += chord[1]
58 | return tone
59 |
60 |
61 | def get_chords_repeat_num3(chords):
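    # score each candidate key by how many of the given chords appear in its row of MAJOR_CHORDS
    # (each row is assumed to start with the key name followed by that key's chords)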
62 | chords_score_map = {}
63 | for ex_chords in MAJOR_CHORDS:
64 | chords_score_map[ex_chords[0]] = len([i for i in chords if i in ex_chords])
65 | return chords_score_map
66 |
67 |
68 | def process_sus_chords(filter_total_chords, sus_chords):
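    # ambiguous chords (sus/dim/aug/power) are tried as both a major (0) and a minor (1) tonic; every
    # 0/1 assignment is scored against the candidate keys and the combinations are returned sorted by score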
69 | tone_count_dic = defaultdict(int)
70 | for chord in sus_chords:
71 | tone_count_dic[chord] += 1
72 | keys = list(tone_count_dic)
73 | product_list = list(product(*([[0, 1]] * len(keys))))
74 | chords_score_list = []
75 | for item in product_list:
76 | sus_conv_chord = []
77 | for i, x in enumerate(item):
78 | key = keys[i]
79 | tone = chord_to_tone(key)
80 | if x == 0:
81 | sus_conv_chord.extend([tone] * tone_count_dic[key])
82 | else:
83 | sus_conv_chord.extend([tone + 'm'] * tone_count_dic[key])
84 | chords_score_map = get_chords_repeat_num3(filter_total_chords + sus_conv_chord)
85 | chords_score_list.extend([(k, i, item) for k, i in chords_score_map.items() if i != 0])
86 | chords_score_list = sorted(chords_score_list, key=lambda x: x[1], reverse=True)
87 | return chords_score_list, keys
88 |
89 |
90 | def get_ret_key(best_key, chord, keys, maj_min_dic, is_sus, ret_key, sus_dic):
91 | # print(best_key, "best_key", chord)
92 | for k, i, item in best_key:
93 | if 'sus' in chord or 'dim' in chord or 'aug' in chord or '5' in chord:
94 | tone = chord_to_tone(chord)
95 | if item[keys.index(chord)] == 1:
96 | tone += 'm'
97 | chord = tone
98 | if chord == maj_min_dic[k][0] or chord == maj_min_dic[k][5]:
99 | ret_key = k
100 | if is_sus:
101 | sus_dic = {keys[i]: chord_to_tone(keys[i]) if v == 0 else chord_to_tone(keys[i]) + 'm' for i, v in
102 | enumerate(item)}
103 | break
104 | return ret_key, sus_dic
105 |
106 |
107 | def get_ret_key2(best_key, head_chord, tail_chord, maj_min_dic, ret_key):
108 | for k, i in best_key:
109 | if head_chord == maj_min_dic[k][0] or head_chord == maj_min_dic[k][5]:
110 | ret_key = k
111 | break
112 | if not ret_key:
113 | for k, i in best_key:
114 | if tail_chord == maj_min_dic[k][0] or tail_chord == maj_min_dic[k][5]:
115 | ret_key = k
116 | break
117 | return ret_key
118 |
119 |
120 | def process_best_key(chords_score_list, head_chord, tail_chord, keys, maj_min_dic, ret_key, sus_dic, is_sus=True):
121 |     # determine the key
122 | if not chords_score_list:
123 | return None, None
124 | if is_sus:
125 | best_key = [x for x in chords_score_list if x[1] == chords_score_list[0][1]]
126 | else:
127 | best_key = [x for x in chords_score_list if x[-1] == chords_score_list[0][-1]]
128 | if len(best_key) > 1 and is_sus:
129 | ret_key, sus_dic = get_ret_key(best_key, head_chord, keys, maj_min_dic, is_sus, ret_key, sus_dic)
130 | if not ret_key:
131 | ret_key, sus_dic = get_ret_key(best_key, tail_chord, keys, maj_min_dic, is_sus, ret_key, sus_dic)
132 | elif len(best_key) > 1 and not is_sus:
133 | ret_key = get_ret_key2(best_key, head_chord, tail_chord, maj_min_dic, ret_key)
134 | # else:
135 | # return None, None
136 | if not ret_key:
137 | if len(best_key[0]) == 2:
138 | ret_key = best_key[0][0]
139 | else:
140 | ret_key = best_key[0][0]
141 | sus_dic = {keys[i]: chord_to_tone(keys[i]) if v == 0 else chord_to_tone(keys[i]) + 'm' for i, v in
142 | enumerate(best_key[0][2])}
143 | return ret_key, sus_dic
144 |
145 |
146 | def get_maj_or_min_key(ret_key, maj_min_dic, head_chord, tail_chord, sus_dic):
147 | # check maj or min
148 | maj_key = ret_key
149 | min_key = MAJ2MIN_MAP[ret_key]
150 |
151 | maj_chords = maj_min_dic[maj_key]
152 |
153 | if 'sus' in head_chord or 'dim' in head_chord or 'aug' in head_chord or '5' in head_chord:
154 | head_chord = sus_dic[head_chord]
155 |
156 | if head_chord == maj_chords[0]:
157 | ret_key = maj_key
158 | elif head_chord == maj_chords[5]:
159 | ret_key = min_key
160 | else:
161 | if 'sus' in tail_chord or 'dim' in tail_chord or 'aug' in tail_chord or '5' in tail_chord:
162 | tail_chord = sus_dic[tail_chord]
163 |
164 | if tail_chord == maj_chords[0]:
165 | ret_key = maj_key
166 | elif tail_chord == maj_chords[5]:
167 | ret_key = min_key
168 | else:
169 | ret_key = maj_key
170 | return ret_key
171 |
172 |
173 | def search_key(total_chords):
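    # estimate the song key and return (key, score): reduce chords to tonics, score candidate keys by
    # diatonic overlap, break ties with the first/last chord, then pick between the major key and its relative minor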
174 | if not total_chords:
175 | return "C", 0
176 | maj_min_dic = {x[0]: x for x in MAJOR_CHORDS + MINOR_CHORDS}
177 | ret_key = None
178 | keys = []
179 | sus_dic = {}
180 | total_chords = chord_name_filter(total_chords)
181 | head_chord = total_chords[0]
182 | tail_chord = total_chords[-1]
183 |
184 | sus_chords = [x for x in total_chords if 'sus' in x or 'dim' in x or 'aug' in x or '5' in x]
185 | filter_total_chords = [x for x in total_chords if
186 | 'sus' not in x and 'dim' not in x and 'aug' not in x and '5' not in x]
187 | if not filter_total_chords and not sus_chords:
188 | return "C", 0
189 |
190 | if sus_chords:
191 | chords_score_list, keys = process_sus_chords(filter_total_chords, sus_chords)
192 |
193 | ret_key, sus_dic = process_best_key(chords_score_list, head_chord, tail_chord, keys, maj_min_dic, ret_key,
194 | sus_dic)
195 | else:
196 | chords_score_map = get_chords_repeat_num3(filter_total_chords)
197 | chords_score_list = [(k, i) for k, i in chords_score_map.items() if i != 0]
198 | chords_score_list = sorted(chords_score_list, key=lambda x: x[1], reverse=True)
199 | print(chords_score_list)
200 | ret_key, sus_dic = process_best_key(chords_score_list, head_chord, tail_chord, keys, maj_min_dic, ret_key,
201 | sus_dic, is_sus=False)
202 | key_score = 0
203 | for i in chords_score_list:
204 | if i[0] == ret_key:
205 | key_score = i[1]
206 | if not ret_key:
207 | return "C", 0
208 | ret_key = get_maj_or_min_key(ret_key, maj_min_dic, head_chord, tail_chord, sus_dic)
209 |
210 | return ret_key, key_score
211 |
212 |
213 | def chord_convert(chord):
214 | chord_convert_map = {
215 | "C": "Db",
216 | "D": "Eb",
217 | "E": "F",
218 | "F": "Gb",
219 | "G": "Ab",
220 | "A": "Bb",
221 | "B": "C"
222 | }
223 | if "#" in chord:
224 | a, b = chord.split('#')
225 | new_chord = chord_convert_map[a] + b
226 | else:
227 | new_chord = chord
228 | return new_chord
229 |
230 |
231 | if __name__ == '__main__':
232 | demo_chords = ['D', 'D', 'Bm7', 'Bm7', 'Asus4', 'A', 'G', 'G', 'D', 'D', 'D', 'Bm7', 'A', 'Asus4', 'A', 'G', 'D',
233 | 'D', 'D', 'D', 'Bm7', 'A', 'Asus4', 'A', 'G', 'D', 'E', 'D', 'D', 'Bm', 'Bm7', 'A', 'A', 'G', 'G',
234 | 'D', 'D', 'Bm7', 'Bm', 'Bm7', 'A', 'A', 'D', 'G', 'G', 'Bm7', 'Bm', 'A', 'G', 'G', 'A', 'D', 'D',
235 | 'D', 'D', 'A', 'Asus4', 'A', 'Asus4', 'Em7', 'G', 'D', 'D', 'D', 'D', 'D', 'A', 'A7', 'A7', 'G', 'G',
236 | 'D', 'D', 'Bm7', 'Bm7', 'A', 'A', 'G', 'G', 'D', 'D', 'Bm', 'Bm7', 'A', 'Asus4', 'A', 'D', 'G', 'G',
237 | 'Bm', 'A', 'G', 'G', 'A', 'G', 'G', 'Abm', 'Dbm', 'B', 'A', 'Abm7', 'Ab7', 'A', 'Gbm7', 'E', 'Gbm7',
238 | 'Ab7', 'Dbm7', 'B', 'A', 'E', 'Ab7', 'A', 'E', 'Gbm7', 'Ab7', 'Dbm', 'B', 'A', 'A', 'B', 'E', 'B',
239 | 'Abm', 'A', 'E', 'B', 'Abm7', 'Abm', 'A', 'F', 'C', 'A', 'Bb', 'F', 'C', 'A', 'F', 'F', 'Bb', 'Bb',
240 | 'F', 'F', 'Dm7', 'D', 'Dm7', 'C', 'C', 'Bb', 'Bb', 'F', 'F', 'C', 'Dm7', 'D', 'Dm7', 'F', 'Gsus4',
241 | 'C', 'Bb', 'Bb', 'Bb', 'Bb', 'F', 'F']
242 | demo_chords1 = []
243 | for i in demo_chords:
244 | demo_chords1.append(chord_convert(i))
245 | print(search_key(demo_chords))
246 | print(search_key(demo_chords1))
247 |
--------------------------------------------------------------------------------
/cover/models/pitch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchaudio
3 | from torch import nn
4 | import torch.nn.functional as F
5 |
6 | from utils.spec import stft
7 |
8 |
9 | class ResConvBlock(nn.Module):
10 | def __init__(self,
11 | in_channels,
12 | out_channels,
13 | kernel_size=(3, 3),
14 | stride=(1, 1),
15 | padding=(1, 1),
16 | momentum=0.01,
17 | bias=True):
18 | super().__init__()
19 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
20 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
21 | self.relu1 = nn.ReLU()
22 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, bias=bias)
23 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
24 | self.relu2 = nn.ReLU()
25 |
26 | if in_channels != out_channels or stride != (1, 1):
27 | self.residual_connection = nn.Conv2d(in_channels,
28 | out_channels,
29 | kernel_size=(1, 1),
30 | stride=stride,
31 | padding=(0, 0),
32 | bias=bias)
33 | else:
34 | self.residual_connection = nn.Identity()
35 |
36 | def forward(self, x):
37 | identity = self.residual_connection(x)
38 | x = self.conv1(x)
39 | x = self.bn1(x)
40 | x = self.relu1(x)
41 | x = self.conv2(x)
42 | x = self.bn2(x)
43 | x += identity
44 | x = self.relu2(x)
45 | return x
46 |
47 |
48 | class ResEncoderBlock(nn.Module):
49 | def __init__(self, in_channels, out_channels, n_blocks=1, momentum=0.01):
50 | super().__init__()
51 | self.n_blocks = n_blocks
52 | self.res_conv_layers = nn.ModuleList()
53 | for i in range(self.n_blocks):
54 | if i == 0:
55 | self.res_conv_layers.append(ResConvBlock(in_channels, out_channels, momentum=momentum))
56 | else:
57 | self.res_conv_layers.append(ResConvBlock(out_channels, out_channels, momentum=momentum))
58 |
59 | def forward(self, x):
60 | for i in range(self.n_blocks):
61 | x = self.res_conv_layers[i](x)
62 | return x
63 |
64 |
65 | class Encoder(nn.Module):
66 | def __init__(self,
67 | in_channels,
68 | out_channels,
69 | n_layers,
70 | pool_size,
71 | n_blocks=1,
72 | momentum=0.01):
73 | super().__init__()
74 | self.n_layers = n_layers
75 |
76 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
77 | self.layers = nn.ModuleList()
78 | self.pool_layers = nn.ModuleList()
79 | for i in range(self.n_layers):
80 | self.layers.append(ResEncoderBlock(in_channels, out_channels, n_blocks=n_blocks, momentum=momentum))
81 | self.pool_layers.append(nn.AvgPool2d(kernel_size=pool_size))
82 | in_channels = out_channels
83 | out_channels *= 2
84 | self.out_channels = out_channels
85 |
86 | def forward(self, x):
87 | x = self.bn(x)
88 | h = []
89 | for i in range(self.n_layers):
90 | t = self.layers[i](x)
91 | x = self.pool_layers[i](t)
92 | h.append(t)
93 | return x, h
94 |
95 |
96 | class Enhancer(nn.Module):
97 | def __init__(self, in_channels, out_channels, n_layers, n_blocks, momentum=0.01):
98 | super().__init__()
99 |
100 | self.n_layers = n_layers
101 | self.layers = nn.ModuleList()
102 | for i in range(self.n_layers):
103 | if i == 0:
104 | self.layers.append(ResEncoderBlock(in_channels, out_channels, n_blocks=n_blocks, momentum=momentum))
105 | else:
106 | self.layers.append(ResEncoderBlock(out_channels, out_channels, n_blocks=n_blocks, momentum=momentum))
107 |
108 | def forward(self, x):
109 | for i in range(self.n_layers):
110 | x = self.layers[i](x)
111 | return x
112 |
113 |
114 | class ResDecoderBlock(nn.Module):
115 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
116 | super().__init__()
117 | self.n_blocks = n_blocks
118 | self.conv = nn.Sequential(
119 | nn.ConvTranspose2d(
120 | in_channels=in_channels,
121 | out_channels=out_channels,
122 | kernel_size=(3, 3),
123 | stride=stride,
124 | padding=(1, 1),
125 | output_padding=(0, 1) if stride == (1, 2) else (1, 1),
126 | bias=False,
127 | ),
128 | nn.BatchNorm2d(out_channels, momentum=momentum),
129 | nn.ReLU()
130 | )
131 | self.res_conv_layers = nn.ModuleList()
132 | for i in range(self.n_blocks):
133 | if i == 0:
134 | self.res_conv_layers.append(ResConvBlock(out_channels * 2, out_channels, momentum=momentum))
135 | else:
136 | self.res_conv_layers.append(ResConvBlock(out_channels, out_channels, momentum=momentum))
137 |
138 | def forward(self, x, h):
139 | x = self.conv(x)
140 | x = torch.cat((x, h), dim=1)
141 | for i in range(self.n_blocks):
142 | x = self.res_conv_layers[i](x)
143 | return x
144 |
145 |
146 | class Decoder(nn.Module):
147 | def __init__(self,
148 | in_channels,
149 | n_layers,
150 | stride,
151 | n_blocks,
152 | momentum=0.01):
153 | super().__init__()
154 | self.n_layers = n_layers
155 | self.layers = nn.ModuleList()
156 | for i in range(self.n_layers):
157 | self.layers.append(
158 | ResDecoderBlock(in_channels, in_channels // 2, stride, n_blocks=n_blocks, momentum=momentum))
159 | in_channels //= 2
160 |
161 | def forward(self, x, h):
162 | for i in range(self.n_layers):
163 | x = self.layers[i](x, h[-i - 1])
164 | return x
165 |
166 |
167 | class BiGRU(nn.Module):
168 | def __init__(self, input_size, hidden_size, num_layers, dropout=0.25):
169 | super().__init__()
170 | self.gru = nn.GRU(input_size,
171 | hidden_size,
172 | num_layers=num_layers,
173 | batch_first=True,
174 | bidirectional=True,
175 | dropout=dropout)
176 |
177 | def forward(self, x):
178 | x, _ = self.gru(x)
179 | return x
180 |
181 |
182 | class MelSpectrogram(nn.Module):
183 | def __init__(self,
184 | nfft,
185 | n_mels,
186 |                  hop_length=None,
187 | mel_f_min=30.0,
188 | mel_f_max=8000.0,
189 | samplerate=44100):
190 | super().__init__()
191 | self.samplerate = samplerate
192 | self.nfft = nfft
193 | self.n_stft = nfft // 2 + 1
194 | self.hop_length = nfft // 4 if hop_length is None else hop_length
195 | self.n_mels = n_mels
196 | self.mel_scale = torchaudio.transforms.MelScale(
197 | n_mels=n_mels, sample_rate=samplerate, n_stft=self.n_stft, f_min=mel_f_min, f_max=mel_f_max)
198 |
199 | def _stft(self, x):
200 | hl = self.hop_length
201 | nfft = self.nfft
202 | return stft(x, fft_size=nfft, hop_length=hl)
203 |
204 | def _mel(self, stft):
205 | magnitude = stft.abs().pow(2)
206 | mel = self.mel_scale(magnitude)
207 | return mel
208 |
209 | def forward(self, x):
210 | stft = self._stft(x)
211 | mel = self._mel(stft)
212 | return mel
213 |
214 |
215 | class PitchNet(nn.Module):
216 |
217 | def __init__(self,
218 | nfft,
219 | mel_f_min=30.0,
220 | mel_f_max=8000.0,
221 | hop_length=None,
222 | samplerate=16000,
223 | en_layers=5,
224 | re_layers=4,
225 | de_layers=5,
226 | n_blocks=5,
227 | in_channels=1,
228 | en_out_channels=16,
229 | pool_size=(2, 2),
230 | n_gru=1,
231 | n_mels=128,
232 | n_classes=360):
233 | super().__init__()
234 |
235 | self.mel_spectrogram = MelSpectrogram(n_mels=n_mels,
236 | samplerate=samplerate,
237 | hop_length=hop_length,
238 | nfft=nfft,
239 | mel_f_min=mel_f_min,
240 | mel_f_max=mel_f_max)
241 | self.encoder = Encoder(in_channels, en_out_channels, en_layers, n_blocks=n_blocks, pool_size=pool_size)
242 | self.enhancer = Enhancer(self.encoder.out_channels // 2, self.encoder.out_channels, re_layers,
243 | n_blocks=n_blocks)
244 | self.decoder = Decoder(self.encoder.out_channels, de_layers, stride=pool_size, n_blocks=n_blocks)
245 |
246 | self.conv = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
247 | if n_gru > 0:
248 | self.fc = nn.Sequential(
249 | BiGRU(3 * n_mels, 256, num_layers=n_gru),
250 | nn.Linear(256 * 2, n_classes),
251 | nn.Dropout(0.25),
252 | nn.Sigmoid()
253 | )
254 | else:
255 | self.fc = nn.Sequential(
256 | nn.Linear(3 * n_mels, n_classes),
257 | nn.Dropout(0.25),
258 | nn.Sigmoid()
259 | )
260 |
261 | def forward(self, x):
262 | x = self.mel_spectrogram(x)
263 | n_frames = x.size(-1)
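        # pad the frame axis up to a multiple of 32 so the encoder's repeated 2x pooling (2**5 with the default five layers) divides it evenly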
264 | n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
265 | if n_pad > 0:
266 | x = F.pad(x, (0, n_pad), mode="constant")
267 | x = x.transpose(-1, -2).unsqueeze(1)
268 | x, h = self.encoder(x)
269 | x = self.enhancer(x)
270 | x = self.decoder(x, h)
271 | x = self.conv(x)
272 | x = x.transpose(1, 2).flatten(-2)
273 | x = self.fc(x)
274 | x = x[:, :n_frames]
275 | return x
276 |
277 |
--------------------------------------------------------------------------------
/cover/models/generate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | import torchaudio.transforms as transforms
5 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
6 |
7 | from utils.spec import stft, istft, chroma
8 | from utils.noise import gaussian_noise, markov_noise, random_walk_noise, spectral_folding
9 |
10 |
11 | def positional_encoding(batch_size, n_time, n_feature, zero_pad=False, scale=False, dtype=torch.float32):
12 | pos_indices = torch.tile(torch.unsqueeze(torch.arange(n_time), 0), [batch_size, 1])
13 |
14 | pos = torch.arange(n_time, dtype=dtype).reshape(-1, 1)
15 | pos_enc = pos / torch.pow(10000, 2 * torch.arange(0, n_feature, dtype=dtype) / n_feature)
16 | pos_enc[:, 0::2] = torch.sin(pos_enc[:, 0::2])
17 | pos_enc[:, 1::2] = torch.cos(pos_enc[:, 1::2])
18 |
19 | if zero_pad:
20 | pos_enc = torch.cat([torch.zeros(size=[1, n_feature]), pos_enc[1:, :]], 0)
21 |
22 | outputs = F.embedding(pos_indices, pos_enc)
23 |
24 | if scale:
25 | outputs = outputs * (n_feature ** 0.5)
26 |
27 | return outputs
28 |
29 |
30 | class FeedForward(nn.Module):
31 | def __init__(self, n_feature=2048, n_hidden=512, dropout=0.5):
32 | super().__init__()
33 | self.linear1 = nn.Linear(n_feature, n_hidden)
34 | self.linear2 = nn.Linear(n_hidden, n_feature)
35 |
36 | self.dropout1 = nn.Dropout(dropout)
37 | self.dropout2 = nn.Dropout(dropout)
38 | self.norm_layer = nn.LayerNorm(n_feature)
39 |
40 | def forward(self, x):
41 | y = self.linear1(x)
42 | y = F.relu(y)
43 | y = self.dropout1(y)
44 | y = self.linear2(y)
45 | y = self.dropout2(y)
46 | y = self.norm_layer(y)
47 | return y
48 |
49 |
50 | class MeanStdNormalization(nn.Module):
51 | def __init__(self, dim=0, eps=1e-5):
52 | super(MeanStdNormalization, self).__init__()
53 | self.dim = dim
54 | self.eps = eps
55 |
56 | def forward(self, x):
57 | mean = x.mean(dim=self.dim, keepdim=True)
58 | std = x.std(dim=self.dim, keepdim=True) + self.eps
59 | x_normalized = (x - mean) / std
60 | return x_normalized
61 |
62 |
63 | class FLEXLayer(nn.Module):
64 | def __init__(self, input_dim, output_dim, num_layers=3, dropout_rate=0.3):
65 | super().__init__()
66 |
67 | self.layers = nn.ModuleList()
68 | self.dropout = nn.Dropout(dropout_rate)
69 |
70 | if input_dim != output_dim:
71 | self.projection = nn.Conv1d(input_dim, output_dim, kernel_size=1)
72 | else:
73 | self.projection = None
74 |
75 | for _ in range(num_layers):
76 | self.layers.append(nn.Conv1d(output_dim, output_dim, kernel_size=3, padding=1))
77 | self.layers.append(nn.ReLU())
78 | self.layers.append(nn.BatchNorm1d(output_dim))
79 |
80 | def forward(self, x):
81 |
82 | if self.projection is not None:
83 | x = self.projection(x)
84 |
85 | x1 = x
86 | for layer in self.layers:
87 | x1 = layer(x1)
88 |
89 |         x = x + x1  # residual connection (avoid in-place when x is the block input)
90 | x = self.dropout(x)
91 |
92 | return x
93 |
94 |
95 | class AFALayer(nn.Module):
96 | def __init__(self, d_model=2048, hidden_size=2048, n_conv_layers=3):
97 | super().__init__()
98 |
99 | self.lstm = nn.LSTM(input_size=d_model, hidden_size=hidden_size,
100 | batch_first=True, bidirectional=False)
101 |
102 | conv_layers = []
103 | for _ in range(n_conv_layers):
104 | conv_layers.append(nn.Conv1d(in_channels=hidden_size, out_channels=hidden_size,
105 | kernel_size=3, padding=1))
106 | conv_layers.append(nn.ReLU())
107 |
108 | self.conv = nn.Sequential(*conv_layers)
109 |
110 | self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=8, batch_first=True)
111 |
112 | def forward(self, x1, x2, x3):
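        # fuse the three feature streams: move each to (batch, time, features), concatenate along the feature
        # axis, pass through an LSTM, a small 1-D conv stack, and finally self-attention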
113 | x1 = x1.permute(0, 2, 1)
114 | x2 = x2.permute(0, 2, 1)
115 | x3 = x3.permute(0, 2, 1)
116 |
117 | combined_features = torch.cat((x1, x2, x3), dim=-1)
118 |
119 | lstm_out, _ = self.lstm(combined_features)
120 |
121 | lstm_out = lstm_out.transpose(1, 2)
122 | conv_out = self.conv(lstm_out)
123 | conv_out = conv_out.transpose(1, 2)
124 |
125 | attn_output, _ = self.attention(conv_out, conv_out, conv_out)
126 |
127 | return attn_output
128 |
129 |
130 | class TimbreBlock(nn.Module):
131 | def __init__(self,
132 | n_fs=2049,
133 | n_fm=13,
134 | n_fc=12,
135 | n_out=768,
136 | nhead=8,
137 | num_layers=4,
138 | d_model=2048):
139 | super().__init__()
140 |
141 | self.auto_feature_x1 = FLEXLayer(input_dim=n_fs, output_dim=d_model)
142 | self.auto_feature_x2 = FLEXLayer(input_dim=n_fm, output_dim=d_model)
143 | self.auto_feature_x3 = FLEXLayer(input_dim=n_fc, output_dim=d_model)
144 |
145 | self.attention_fusion = AFALayer(d_model=d_model * 3)
146 |
147 | encoder_layers = TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
148 | self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=num_layers)
149 |
150 | self.output_layer = nn.Linear(d_model, n_out)
151 |
152 | self.norm = MeanStdNormalization(dim=-1)
153 |
154 | def forward(self, x1, x2, x3):
155 | x1_transformed = self.auto_feature_x1(x1)
156 | x2_transformed = self.auto_feature_x2(self.norm(x2))
157 | x3_transformed = self.auto_feature_x3(self.norm(x3))
158 |
159 | fused_features = self.attention_fusion(x1_transformed, x2_transformed,
160 | x3_transformed)
161 |
162 | encoded_features = self.transformer_encoder(fused_features)
163 | output = self.output_layer(encoded_features)
164 | return output
165 |
166 |
167 | class EncoderBlock(nn.Module):
168 | def __init__(self, n_feat, nhead=8, n_layers=5, dropout=0.5, pr=0.02):
169 | super().__init__()
170 | self.n_feat = n_feat
171 | self.n_layers = n_layers
172 | self.pr = pr
173 |
174 | self.attn_layer1 = nn.ModuleList()
175 | self.attn_layer2 = nn.ModuleList()
176 | self.ff_layer1 = nn.ModuleList()
177 | self.ff_layer2 = nn.ModuleList()
178 | for _ in range(self.n_layers):
179 | _attn_layer1 = nn.MultiheadAttention(n_feat, nhead, batch_first=True)
180 | _attn_layer2 = nn.MultiheadAttention(n_feat, nhead, batch_first=True)
181 | _ff_layer1 = FeedForward(n_feat, dropout=dropout)
182 | _ff_layer2 = FeedForward(n_feat, dropout=dropout)
183 | self.attn_layer1.append(_attn_layer1)
184 | self.attn_layer2.append(_attn_layer2)
185 | self.ff_layer1.append(_ff_layer1)
186 | self.ff_layer2.append(_ff_layer2)
187 |
188 | self.norm_layer = nn.LayerNorm(n_feat)
189 | self.dropout = nn.Dropout(dropout)
190 | self.fc = nn.Linear(n_feat, n_feat)
191 |
192 | def forward(self, x, x1):
193 | x = x.permute(0, 2, 1)
194 | x1 = x1.permute(0, 2, 1)
195 |
196 | B, T, Ft = x.shape
197 |         x = x + positional_encoding(B, T, Ft) * self.pr  # avoid an in-place update of the permuted input
198 | B1, T1, Ft1 = x1.shape
199 |         x1 = x1 + positional_encoding(B1, T1, Ft1) * self.pr
200 |
201 | for _attn_layer1, _attn_layer2, _ff_layer1, _ff_layer2 in zip(self.attn_layer1,
202 | self.attn_layer2,
203 | self.ff_layer1,
204 | self.ff_layer2):
205 | x, _ = _attn_layer1(x, x, x, need_weights=False)
206 | x1, _ = _attn_layer2(x, x1, x1, need_weights=False)
207 | x = _ff_layer1(x)
208 | x1 = _ff_layer2(x1)
209 |
210 | x = torch.stack([x, x1], dim=1)
211 | x = self.dropout(x)
212 | x = self.fc(x)
213 | x = self.norm_layer(x)
214 |         x1 = x[:, 1].permute(0, 2, 1)
215 |         x = x[:, 0].permute(0, 2, 1)
216 | return x, x1
217 |
218 |
219 | class Encoder(nn.Module):
220 | def __init__(self,
221 | n_feat,
222 | n_stft,
223 | out_feat,
224 | n_block=5,
225 | nhead=8,
226 | n_layers=5,
227 | dropout=0.25,
228 | pr=0.02):
229 | super().__init__()
230 |
231 | self.n_block = n_block
232 |
233 | self.conv_pre1 = nn.Sequential(
234 | nn.ConvTranspose1d(n_stft, n_feat, kernel_size=1),
235 | nn.BatchNorm1d(n_feat),
236 | nn.LeakyReLU(0.01),
237 | )
238 | self.conv_pre2 = nn.Sequential(
239 | nn.ConvTranspose1d(n_feat, n_feat, kernel_size=1),
240 | nn.BatchNorm1d(n_feat),
241 | nn.LeakyReLU(0.01),
242 | )
243 | self.layers = nn.ModuleList()
244 | for i in range(n_block):
245 | self.layers.append(EncoderBlock(n_feat, nhead=nhead, n_layers=n_layers, dropout=dropout, pr=pr))
246 |
247 | self.dropout = nn.Dropout(dropout)
248 | self.conv = nn.Sequential(
249 | nn.Conv1d(n_feat * 2, n_feat, kernel_size=1),
250 | nn.BatchNorm1d(n_feat),
251 | nn.ReLU(),
252 | nn.Conv1d(n_feat, n_feat, kernel_size=1),
253 | nn.BatchNorm1d(n_feat),
254 | nn.ReLU(),
255 | )
256 | self.linear = nn.Linear(n_feat, out_feat)
257 |
258 | def forward(self, x, x1, x2):
259 | x = self.conv_pre1(x)
260 | x1 = x1 + x2
261 | x1 = x1.transpose(-1, -2)
262 | x1 = self.conv_pre2(x1)
263 |
264 | for layer in self.layers:
265 | x, x1 = layer(x, x1)
266 |
267 | z = torch.hstack([x, x1])
268 | z = self.conv(z)
269 | z = z.transpose(-2, -1)
270 | z = self.dropout(z)
271 | logits = self.linear(z)
272 | return logits, z
273 |
274 |
275 | class DecoderBlock(nn.Module):
276 |
277 | def __init__(self, d_model, out_feat, nhead, n_layers, dropout=0.5):
278 | super().__init__()
279 | self.d_model = d_model
280 | self.nhead = nhead
281 | self.n_layers = n_layers
282 |
283 | self.attn1_layer = nn.ModuleList()
284 | self.attn2_layer = nn.ModuleList()
285 | self.ff_layer = nn.ModuleList()
286 | for _ in range(n_layers):
287 | _layer1 = nn.MultiheadAttention(d_model, nhead, batch_first=True)
288 | _layer2 = nn.MultiheadAttention(d_model, nhead, batch_first=True)
289 | _layer3 = FeedForward(d_model, dropout=dropout)
290 | self.attn1_layer.append(_layer1)
291 | self.attn2_layer.append(_layer2)
292 | self.ff_layer.append(_layer3)
293 |
294 | self.dropout = nn.Dropout(dropout)
295 | self.fc = nn.Sequential(
296 | nn.Linear(d_model, d_model),
297 | nn.Dropout(dropout),
298 | nn.LeakyReLU(0.01),
299 | nn.Linear(d_model, out_feat)
300 | )
301 |
302 | def forward(self, x, x1):
303 | for i in range(self.n_layers):
304 | _attn1_layer = self.attn1_layer[i]
305 | _attn2_layer = self.attn2_layer[i]
306 | _ff_layer = self.ff_layer[i]
307 | x, _ = _attn1_layer(x, x, x, need_weights=False)
308 | x, _ = _attn2_layer(x, x1, x1, need_weights=False)
309 | x = _ff_layer(x)
310 |
311 | output = self.dropout(x)
312 | output = self.fc(output)
313 | return output
314 |
315 |
316 | class Decoder(nn.Module):
317 |
318 | def __init__(self,
319 | n_feat,
320 | n_timbre,
321 | n_pitch,
322 | nfft=1024,
323 | d_model=512,
324 | nhead=8,
325 | n_layers=5,
326 | dropout=0.5,
327 | pr=0.01):
328 | super().__init__()
329 |
330 | self.n_layers = n_layers
331 | self.d_model = d_model
332 | self.pr = pr
333 |
334 | self.emb_feat = nn.Linear(n_feat, d_model)
335 | self.emb_timbre = nn.Linear(n_timbre, d_model)
336 | self.emb_pitch = nn.Linear(n_pitch, d_model)
337 |
338 | self.lrelu = nn.LeakyReLU(0.01, inplace=True)
339 |
340 | n_out = nfft // 2 + 1
341 | self.dec = DecoderBlock(d_model, n_out, nhead, n_layers, dropout=dropout)
342 |
343 | self.exp_layer = nn.Sequential(
344 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True),
345 | nn.BatchNorm1d(d_model),
346 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True),
347 | nn.LeakyReLU(0.01),
348 | nn.Conv1d(d_model, d_model, kernel_size=1, bias=True),
349 | nn.BatchNorm1d(d_model),
350 | nn.LeakyReLU(0.01),
351 | nn.Conv1d(d_model, n_out, kernel_size=1, bias=True),
352 | nn.BatchNorm1d(n_out),
353 | nn.LeakyReLU(0.01),
354 | )
355 | self.mask_layer = nn.Sequential(
356 | nn.Linear(n_out, n_out),
357 | nn.LeakyReLU(0.01),
358 | nn.Linear(n_out, n_out),
359 | nn.LeakyReLU(0.01),
360 | nn.Sigmoid()
361 | )
362 |
363 | def forward(self, x1, x2, pitch):
364 | x = self.emb_timbre(x1) + self.emb_feat(x2)
365 | x = x * self.d_model ** (1 / 2)
366 | x = self.lrelu(x)
367 | y = x.transpose(-2, -1)
368 | y = self.exp_layer(y)
369 |
370 | x2 = self.emb_pitch(pitch)
371 | x += positional_encoding(x.shape[0], x.shape[1], x.shape[2]) * self.pr + self.pr
372 | x = self.dec(x, x2)
373 |
374 | y = y.transpose(-2, -1)
375 | z = y * x
376 |
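        # the sigmoid output m is used as a time-frequency mask; CombineNet.forward multiplies it with the input STFT (z = x * m)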
377 | m = self.mask_layer(z)
378 | m = m.transpose(-2, -1)
379 | return m, z, x
380 |
381 |
382 | class CombineNet(nn.Module):
383 |
384 | def __init__(self,
385 | n_hidden,
386 | n_timbre,
387 | nfft=1024,
388 | hop_length=160,
389 | samplerate=16000,
390 | n_feat=768,
391 | n_pitch=360,
392 | n_mfcc=40,
393 | n_chroma=12,
394 | n_cqt=84,
395 | n_tb_head=8,
396 | n_tb_layers=5,
397 | tb_dropout=0.5,
398 | pr=0.02,
399 | noise_type=None,
400 | noise_kw=None,
401 | ):
402 | super().__init__()
403 |
404 | self.hop_length = hop_length
405 | self.nfft = nfft
406 | self.n_cqt = n_cqt
407 | self.n_chroma = n_chroma
408 | self.samplerate = samplerate
409 | self.noise_type = noise_type
410 | self.noise_kw = noise_kw
411 |
412 | self.mfcc_layer = transforms.MFCC(sample_rate=samplerate, n_mfcc=n_mfcc,
413 | melkwargs={"n_fft": nfft, "hop_length": hop_length})
414 |
415 | self.timber = TimbreBlock(n_fs=(nfft // 2) + 1, n_fc=n_chroma, n_fm=n_mfcc, n_out=n_feat)
416 |
417 | self.encoder = Encoder(
418 | n_feat,
419 | (nfft // 2) + 1,
420 | n_timbre,
421 | nhead=n_tb_head,
422 | n_layers=n_tb_layers,
423 | dropout=tb_dropout,
424 | pr=pr)
425 |
426 | self.dec_p = Decoder(
427 | n_feat=n_feat,
428 | n_timbre=n_timbre,
429 | n_pitch=n_pitch,
430 | d_model=n_hidden,
431 | nhead=8,
432 | n_layers=5,
433 | dropout=0.5,
434 | )
435 |
436 | def _stft(self, x):
437 | hl = self.hop_length
438 | nfft = self.nfft
439 | return stft(x, fft_size=nfft, hop_length=hl)
440 |
441 | def _chroma(self, x):
442 | hl = self.hop_length
443 | return chroma(x, n_chroma=self.n_chroma, sample_rate=self.samplerate, hop_length=hl)
444 |
445 | def _mfcc(self, x):
446 | return self.mfcc_layer(x)
447 |
448 | def _istft(self, x, n_time):
449 | hl = self.hop_length
450 | return istft(x, hop_length=hl, signal_length=n_time)
451 |
452 | def _magnitude(self, z):
453 | # return the magnitude of the spectrogram, except when cac is True,
454 | # in which case we just move the complex dimension to the channel one.
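        # note: self.cac is not set in __init__, so this helper assumes the flag is attached to the module elsewhere before it is called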
455 | if self.cac:
456 | B, C, Fr, T = z.shape
457 | m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
458 | m = m.reshape(B, C * 2, Fr, T)
459 | else:
460 | m = z.abs()
461 | return m
462 |
463 | def _noise(self, x):
464 | if self.noise_type is None:
465 | x = x
466 | else:
467 | noise_kw = self.noise_kw or {}
468 | if self.noise_type == 'gaussian':
469 | x = gaussian_noise(x, **noise_kw)
470 | elif self.noise_type == 'markov':
471 | x = markov_noise(x, **noise_kw)
472 | elif self.noise_type == 'random_walk':
473 | x = random_walk_noise(x, **noise_kw)
474 | elif self.noise_type == 'spectral_folding':
475 | x = spectral_folding(x, **noise_kw)
476 | return x
477 |
478 | def forward(self, audio, feat, f0, is_noise=False):
479 | n_len = audio.shape[-1]
480 | x = self._stft(audio)
481 | x_mfcc = self._mfcc(audio)
482 | x_chroma = self._chroma(audio)
483 |
484 | _mag = x.abs()
485 | if is_noise:
486 | _mag = self._noise(_mag)
487 |
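        # align the conditioning feature sequence (feat) with the F0 sequence along time by truncating or zero-padding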
488 | _, f0_t, f0_f = f0.shape
489 | _, ft_t, ft_f = feat.shape
490 | if f0_t < ft_t:
491 | feat = feat[:, :f0_t, :]
492 | elif f0_t > ft_t:
493 | feat = F.pad(feat, (0, 0, 0, f0_t - ft_t))
494 |
495 | timber_feat = self.timber(_mag, x_mfcc, x_chroma)
496 | logits, h = self.encoder(_mag, feat, timber_feat)
497 | m, z, y = self.dec_p(logits, feat, f0)
498 |
499 | z = x * m
500 | o = self._istft(z, n_len)
501 | o = o[..., :n_len]
502 | return o, z, (y, m, logits, h)
503 |
--------------------------------------------------------------------------------