├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── build_svp.py
├── infer.py
├── inference
│   ├── __init__.py
│   ├── base_infer.py
│   ├── me_infer.py
│   └── me_quant_infer.py
├── modules
│   ├── __init__.py
│   ├── attention
│   │   ├── __init__.py
│   │   └── base_attention.py
│   ├── conform
│   │   ├── Gconform.py
│   │   └── __init__.py
│   ├── conv
│   │   └── base_conv.py
│   ├── model
│   │   ├── Gmidi_conform.py
│   │   └── __init__.py
│   └── rmvpe
│       ├── __init__.py
│       ├── constants.py
│       ├── deepunet.py
│       ├── inference.py
│       ├── model.py
│       ├── seq.py
│       ├── spec.py
│       └── utils.py
├── requirements.txt
├── template.json
├── utils
│   ├── __init__.py
│   ├── infer_utils.py
│   ├── pitch_utils.py
│   └── slicer2.py
├── webui.py
└── weights
    └── config.yaml
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ckpt
2 | *.pt
3 | __pycache__
4 | input/
5 | results/
6 | workenv/
7 | temp.txt
8 | *.svp
9 | *.bat
10 | *.7z
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Sucial
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # wav2svp: Waveform to Synthesizer V Project
4 |
5 |
6 |
7 | ### Description
8 |
9 | wav2svp is a project that converts a waveform into a Synthesizer V Project (SVP) file. It is based on [SOME](https://github.com/openvpi/SOME) and [RMVPE](https://github.com/Dream-High/RMVPE). In addition to automatically extracting MIDI, this project can simultaneously extract **pitch data**, tension data (experimental) and breathiness data (experimental). Unfortunately, it currently cannot extract lyrics.
10 |
11 | ### Usage
12 |
13 | You can download the **One click startup package** from [releases](https://github.com/SUC-DriverOld/wav2svp/releases), unzip it, and double-click `go-webui.bat` to start the WebUI.
14 |
15 | ### Run from Code
16 |
17 | 1. Clone this repository and install the dependencies. We recommend using Python 3.10.
18 |
19 | ```shell
20 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
21 | pip install -r requirements.txt
22 | ```
23 |
24 | 2. Download pre-trained models:
25 |
26 |    - Download [0119_continuous128_5spk](https://github.com/openvpi/SOME/releases/download/v1.0.0-baseline/0119_continuous128_5spk.zip) and unzip it into `weights`.
27 |    - Download [rmvpe](https://github.com/yxlllc/RMVPE/releases/download/230917/rmvpe.zip), unzip it into `weights`, and rename the model file to `rmvpe.pt`.
28 |    - Organize the `weights` folder as follows:
29 |
30 | ```shell
31 | weights
32 |    ├── config.yaml
33 |    ├── model_steps_64000_simplified.ckpt
34 |    └── rmvpe.pt
35 | ```
36 |
37 | 3. Run the following command to start the WebUI:
38 |
39 | ```shell
40 | python webui.py
41 | ```
42 |
43 | 4. You can download the inference results from the WebUI interface or from the `results` folder.
44 |
45 | ### Command Line Usage
46 |
47 | Use `infer.py`:
48 |
49 | ```shell
50 | usage: infer.py [-h] [--model_path MODEL_PATH] [--tempo TEMPO] [--extract_pitch] [--extract_tension] [--extract_breathiness] audio_path
51 |
52 | Inference for wav2svp
53 |
54 | positional arguments:
55 | audio_path Path to the input audio file
56 |
57 | options:
58 | -h, --help show this help message and exit
59 | --model_path MODEL_PATH
60 | Path to the model file, default: weights/model_steps_64000_simplified.ckpt
61 | --tempo TEMPO Tempo value for the midi file, default: 120
62 | --extract_pitch Whether to extract pitch from the audio file, default: False
63 | --extract_tension Whether to extract tension from the audio file, default: False
64 | --extract_breathiness
65 | Whether to extract breathiness from the audio file, default: False
66 | ```
67 |
68 | You can find the results in the `results` folder.
69 |
70 | ### Thanks
71 |
72 | - [openvpi/SOME](https://github.com/openvpi/SOME)
73 | - [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE)
74 | - [yxlllc/RMVPE](https://github.com/yxlllc/RMVPE)
--------------------------------------------------------------------------------
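A minimal example of the command-line usage described in the README above, assuming a hypothetical input file `vocals.wav` and the default model path:

```shell
python infer.py vocals.wav --tempo 128 --extract_pitch --extract_tension
# writes results/vocals.svp and results/vocals.mid
```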
/build_svp.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | import os
4 | import math
5 |
6 |
7 | per_dur = 705600000  # duration of one beat in Synthesizer V ticks
8 | time_per_frame = 0.02  # seconds per frame: hop_size / sample_rate
9 |
10 |
11 | def build_svp(template, midis, arguments, tempo, basename, extract_pitch, extract_tension, extract_breathiness) -> str:
12 |     notes = []  # note data to be written into the template
13 |     datas = []  # note records kept for pitch matching later
14 | new_uuid = str(uuid.uuid4()).lower()
15 |
16 |     per_time = 60 / tempo  # seconds per beat
17 | template["time"]["tempo"] = [{"position": 0, "bpm": tempo}]
18 |
19 | index = 0
20 | for midi in midis:
21 |         offset = int(arguments[index]["offset"] / per_time * per_dur)  # segment start time converted to SV ticks
22 |
23 |         dur = midi["note_dur"]  # note durations
24 |         pitch = midi["note_midi"]  # note pitches
25 |         rest = midi["note_rest"]  # whether each note is a rest
26 |         midi_duration = 0  # accumulated duration of this segment
27 |
28 | for i in range(len(pitch)):
29 |             current_duration = dur[i] / per_time * per_dur  # current note duration in SV ticks
30 |             onset = midi_duration + offset  # note onset in SV ticks
31 | midi_duration += int(current_duration)
32 |             if rest[i]:  # rest note, skip
33 | continue
34 | current_pitch = round(pitch[i])
35 |
36 | note = {
37 | "musicalType": "singing",
38 | "onset": int(onset),
39 | "duration": int(current_duration),
40 | "lyrics": "la",
41 | "phonemes": "",
42 | "accent": "",
43 | "pitch": int(current_pitch),
44 | "detune": 0,
45 | "instantMode": False,
46 | "attributes": {"evenSyllableDuration": True},
47 | "systemAttributes": {"evenSyllableDuration": True},
48 | "pitchTakes": {"activeTakeId": 0,"takes": [{"id": 0,"expr": 0,"liked": False}]},
49 | "timbreTakes": {"activeTakeId": 0,"takes": [{"id": 0,"expr": 0,"liked": False}]}
50 | }
51 | notes.append(note)
52 |
53 | data = {"start": int(onset),"finish": int(current_duration + onset),"pitch": int(current_pitch)}
54 | datas.append(data)
55 | index += 1
56 |
57 | template["tracks"][0]["mainGroup"]["notes"] = notes
58 | template["tracks"][0]["mainGroup"]["uuid"] = new_uuid
59 | template["tracks"][0]["mainRef"]["groupID"] = new_uuid
60 |
61 | pitch, tension, breathiness = [], [], []
62 |
63 | if extract_pitch:
64 | pitch = build_pitch(datas, arguments, tempo)
65 | template["tracks"][0]["mainGroup"]["parameters"]["pitchDelta"]["points"] = pitch
66 |
67 | if extract_tension:
68 | tension = build_arguments(arguments, "tension", tempo)
69 | template["tracks"][0]["mainGroup"]["parameters"]["tension"]["points"] = tension
70 |
71 | if extract_breathiness:
72 | breathiness = build_arguments(arguments, "breathiness", tempo)
73 | template["tracks"][0]["mainGroup"]["parameters"]["breathiness"]["points"] = breathiness
74 |
75 | if extract_pitch and not extract_tension and not extract_breathiness:
76 | template["tracks"][0]["mainRef"]["voice"]["dF0Vbr"] = 0
77 |
78 | file_path = os.path.join("results", f"{basename}.svp")
79 | with open(file_path, "w", encoding="utf-8") as f:
80 | json.dump(template, f)
81 |
82 | return file_path
83 |
84 |
85 | def build_pitch(datas: list, arguments: list, tempo: int) -> list:
86 |     pitch = []  # pitch curve points to be saved
87 |     per_time = 60 / tempo  # seconds per beat
88 |
89 | for f0 in arguments:
90 | offset = f0["offset"]
91 | f0_data = f0["f0"]
92 | for i in range(len(f0_data)):
93 | pitch_onset, pitch_cents = None, None
94 |             f0_value = f0_data[i]  # f0 value of the current frame
95 | if f0_value == 0.0:
96 | continue
97 |             onset_time = offset + i * time_per_frame  # start time of the current frame in seconds
98 |             onset = (onset_time / per_time) * per_dur  # start time converted to SV ticks
99 | pitch_onset = int(onset)
100 | for data in datas:
101 |                 if data["start"] <= onset < data["finish"]:  # the current frame falls within this note
102 | pitch_cents = calculate_cents_difference(data["pitch"], f0_value)
103 | break
104 | if pitch_onset is None or pitch_cents is None:
105 | continue
106 | pitch.append(pitch_onset)
107 | pitch.append(pitch_cents)
108 | return pitch
109 |
110 |
111 | def build_arguments(arguments: list, data: str, tempo: int, rate=1.0, argu_offset=0.0):
112 |     args = []
113 |     per_time = 60 / tempo  # seconds per beat
114 |
115 | for arg in arguments:
116 | offset = arg["offset"]
117 | datas = arg[data]
118 | for i in range(len(datas)):
119 |             onset_time = offset + i * time_per_frame  # start time of the current frame in seconds
120 |             onset = (onset_time / per_time) * per_dur  # start time converted to SV ticks
121 | args.append(onset)
122 | args.append(datas[i] * rate + argu_offset)
123 | return args
124 |
125 |
126 | def calculate_cents_difference(midi_note, f0):
127 | def midi_to_freq(midi_note):
128 | A4 = 440.0
129 | return A4 * (2 ** ((midi_note - 69) / 12))
130 |
131 | def cents_difference(f0, midi_note):
132 | midi_freq = midi_to_freq(midi_note)
133 | return 1200 * math.log2(f0 / midi_freq)
134 |
135 | cents_diff = cents_difference(f0, midi_note)
136 | return round(cents_diff, 5)
137 |
--------------------------------------------------------------------------------
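The conversions in `build_svp.py` can be checked by hand: one beat is 705,600,000 SV ticks (`per_dur`), a time in seconds maps to ticks via `seconds / (60 / tempo) * per_dur`, and `calculate_cents_difference` measures how far an f0 value deviates from the equal-tempered frequency of its rounded MIDI note. A small sketch with illustrative values:

```python
import math

per_dur = 705600000      # SV ticks per beat, as in build_svp.py
tempo = 120
per_time = 60 / tempo    # 0.5 seconds per beat

# 1.25 s into the audio is 2.5 beats, i.e. 2.5 * per_dur ticks
onset_ticks = int(1.25 / per_time * per_dur)
print(onset_ticks)       # 1764000000

# cents difference between f0 = 445 Hz and MIDI note 69 (A4 = 440 Hz)
midi_freq = 440.0 * (2 ** ((69 - 69) / 12))
print(round(1200 * math.log2(445 / midi_freq), 5))  # ~19.56, i.e. about 20 cents sharp
```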
/infer.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import os
3 | import yaml
4 | import json
5 | import inference
6 | import importlib
7 | import numpy as np
8 | from tqdm import tqdm
9 | from scipy.ndimage import gaussian_filter
10 |
11 | from modules.rmvpe.inference import RMVPE
12 | from utils.slicer2 import Slicer
13 | from utils.infer_utils import build_midi_file
14 | from build_svp import build_svp
15 |
16 |
17 | def load_config(config_path: str) -> dict:
18 | if config_path.endswith('.yaml'):
19 | with open(config_path, 'r', encoding='utf8') as f:
20 | config = yaml.safe_load(f)
21 | elif config_path.endswith('.json'):
22 | with open(config_path, 'r', encoding='utf8') as f:
23 | config = json.load(f)
24 | else:
25 | raise ValueError(f'Unsupported config file format: {config_path}')
26 | return config
27 |
28 |
29 | config = load_config('weights/config.yaml')
30 | sr = config['audio_sample_rate']
31 |
32 |
33 | def audio_slicer(audio_path: str) -> list:
34 | """
35 | Returns:
36 | list of dict: [{
37 | "offset": np.float64,
38 | "waveform": array of float, dtype=float32,
39 | }, ...]
40 | """
41 | waveform, _ = librosa.load(audio_path, sr=sr, mono=True)
42 | slicer = Slicer(sr=sr, max_sil_kept=1000)
43 | chunks = slicer.slice(waveform)
44 | for c in chunks:
45 | c['waveform_16k'] = librosa.resample(y=c['waveform'], orig_sr=sr, target_sr=16000)
46 | return chunks
47 |
48 |
49 | def get_midi(chunks: list, model_path: str) -> list:
50 | """
51 | Args:
52 | chunks (list): results from audio_slicer
53 |
54 | Returns:
55 | list of dict: [{
56 | "note_midi": array of float, dtype=float32,
57 | "note_dur": array of float,
58 | "note_rest": array of bool,
59 | }, ...]
60 | """
61 | infer_cls = inference.task_inference_mapping[config['task_cls']]
62 | pkg = ".".join(infer_cls.split(".")[:-1])
63 | cls_name = infer_cls.split(".")[-1]
64 | infer_cls = getattr(importlib.import_module(pkg), cls_name)
65 | assert issubclass(infer_cls, inference.BaseInference), \
66 | f'Inference class {infer_cls} is not a subclass of {inference.BaseInference}.'
67 | infer_ins = infer_cls(config=config, model_path=model_path)
68 | midis = infer_ins.infer([c['waveform'] for c in chunks])
69 | return midis
70 |
71 |
72 | def save_midi(midis: list, tempo: int, chunks: list, midi_path: str) -> None:
73 | midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=tempo)
74 | midi_file.save(midi_path)
75 |
76 |
77 | def get_f0(chunks: list):
78 | rmvpe = RMVPE(model_path='weights/rmvpe.pt') # hop_size=160
79 | for chunk in tqdm(chunks, desc='Extracting F0'):
80 | chunk['f0'] = rmvpe.infer_from_audio(chunk['waveform_16k'], sample_rate=16000)[::2].astype(float)
81 | return chunks
82 |
83 |
84 | def get_energy_librosa(waveform, hop_size, win_size):
85 | energy = librosa.feature.rms(y=waveform, frame_length=win_size, hop_length=hop_size)[0]
86 | return energy
87 |
88 |
89 | def get_breathiness(chunks, hop_size, win_size, sigma=1.0):
90 | for chunk in tqdm(chunks, desc='Extracting Breathiness'):
91 | waveform = chunk['waveform_16k']
92 | waveform_ap = librosa.effects.percussive(waveform)
93 | breathiness = get_energy_librosa(waveform_ap, hop_size, win_size)
94 | breathiness = (2 / max(abs(np.max(breathiness)), abs(np.min(breathiness)))) * breathiness
95 | breathiness = np.tanh(breathiness - np.mean(breathiness))
96 | breathiness_smoothed = gaussian_filter(breathiness, sigma=sigma)
97 | chunk['breathiness'] = breathiness_smoothed[::2].astype(float)
98 | return chunks
99 |
100 |
101 | def get_tension(chunks, hop_size, win_size, sigma=1.0):
102 | for chunk in tqdm(chunks, desc='Extracting Tension'):
103 | waveform = chunk['waveform_16k']
104 | waveform_h = librosa.effects.harmonic(waveform)
105 | waveform_base_h = librosa.effects.harmonic(waveform, power=0.5)
106 | energy_base_h = get_energy_librosa(waveform_base_h, hop_size, win_size)
107 | energy_h = get_energy_librosa(waveform_h, hop_size, win_size)
108 | tension = np.sqrt(np.clip(energy_h ** 2 - energy_base_h ** 2, 0, None)) / (energy_h + 1e-5)
109 | tension = (2 / max(abs(np.max(tension)), abs(np.min(tension)))) * tension
110 | tension = np.tanh(tension - np.mean(tension))
111 | tension_smoothed = gaussian_filter(tension, sigma=sigma)
112 | chunk['tension'] = tension_smoothed[::2].astype(float)
113 | return chunks
114 |
115 |
116 | def get_arguments(chunks, hop_size, win_size, extract_pitch=False, extract_tension=False, extract_breathiness=False):
117 | if extract_pitch:
118 | chunks = get_f0(chunks)
119 | if extract_tension:
120 | chunks = get_tension(chunks, hop_size, win_size)
121 | if extract_breathiness:
122 | chunks = get_breathiness(chunks, hop_size, win_size)
123 | return chunks
124 |
125 |
126 | def wav2svp(audio_path, model_path, tempo=120, extract_pitch=False, extract_tension=False, extract_breathiness=False):
127 | os.makedirs('results', exist_ok=True)
128 | basename = os.path.basename(audio_path).split('.')[0]
129 |
130 | chunks = audio_slicer(audio_path)
131 | midis = get_midi(chunks, model_path)
132 | arguments = get_arguments(
133 | chunks, hop_size=160, win_size=1024,
134 | extract_pitch=extract_pitch, extract_tension=extract_tension, extract_breathiness=extract_breathiness
135 | )
136 |
137 | template = load_config('template.json')
138 |
139 | print("building svp file")
140 | svp_path = build_svp(template, midis, arguments, tempo, basename, extract_pitch, extract_tension, extract_breathiness)
141 |
142 | print("building midi file")
143 | midi_path = os.path.join('results', f'{basename}.mid')
144 | save_midi(midis, tempo, chunks, midi_path)
145 |
146 | print("Success")
147 | return svp_path, midi_path
148 |
149 |
150 | if __name__ == '__main__':
151 | import argparse
152 | parser = argparse.ArgumentParser(description='Inference for wav2svp')
153 | parser.add_argument('audio_path', type=str, help='Path to the input audio file')
154 | parser.add_argument('--model_path', type=str, default="weights/model_steps_64000_simplified.ckpt", help='Path to the model file, default: weights/model_steps_64000_simplified.ckpt')
155 | parser.add_argument('--tempo', type=float, default=120.0, help='Tempo value for the midi file, default: 120')
156 | parser.add_argument('--extract_pitch', action='store_true', help='Whether to extract pitch from the audio file, default: False')
157 | parser.add_argument('--extract_tension', action='store_true', help='Whether to extract tension from the audio file, default: False')
158 | parser.add_argument('--extract_breathiness', action='store_true', help='Whether to extract breathiness from the audio file, default: False')
159 | args = parser.parse_args()
160 |
161 | assert os.path.isfile("weights/rmvpe.pt"), "RMVPE model not found"
162 | assert os.path.isfile(args.model_path), "SOME Model not found"
163 |
164 | wav2svp(args.audio_path, args.model_path, args.tempo, args.extract_pitch, args.extract_tension, args.extract_breathiness)
--------------------------------------------------------------------------------
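The same pipeline can be driven from Python instead of the command line. A minimal sketch, assuming the weights are already in place and `input/vocals.wav` is a hypothetical audio file:

```python
from infer import wav2svp

# slices the audio, runs SOME MIDI extraction and (optionally) pitch extraction,
# then writes results/<basename>.svp and results/<basename>.mid
svp_path, midi_path = wav2svp(
    "input/vocals.wav",
    "weights/model_steps_64000_simplified.ckpt",
    tempo=120,
    extract_pitch=True,
)
print(svp_path, midi_path)
```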
/inference/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_infer import BaseInference
2 | from .me_infer import MIDIExtractionInference
3 | from .me_quant_infer import QuantizedMIDIExtractionInference
4 |
5 | task_inference_mapping = {
6 | 'training.MIDIExtractionTask': 'inference.MIDIExtractionInference',
7 | 'training.QuantizedMIDIExtractionTask': 'inference.QuantizedMIDIExtractionInference',
8 | }
9 |
--------------------------------------------------------------------------------
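`infer.get_midi` resolves the inference class for `task_cls` in `weights/config.yaml` through this mapping with `importlib`. A condensed sketch of that lookup (the `task_cls` value here is assumed to match the released config):

```python
import importlib
import inference

task_cls = "training.MIDIExtractionTask"  # assumed value of config['task_cls']

dotted = inference.task_inference_mapping[task_cls]   # 'inference.MIDIExtractionInference'
module_name, cls_name = dotted.rsplit(".", 1)
infer_cls = getattr(importlib.import_module(module_name), cls_name)
assert issubclass(infer_cls, inference.BaseInference)
print(infer_cls.__name__)  # MIDIExtractionInference
```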
/inference/base_infer.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from collections import OrderedDict
3 | from typing import Dict, List
4 |
5 | import numpy as np
6 | import torch
7 | import tqdm
8 | from torch import nn
9 |
10 | from utils import build_object_from_class_name
11 |
12 |
13 | class BaseInference:
14 | def __init__(self, config: dict, model_path: pathlib.Path, device=None):
15 | if device is None:
16 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
17 | self.config = config
18 | self.model_path = model_path
19 | self.device = device
20 | self.timestep = self.config['hop_size'] / self.config['audio_sample_rate']
21 | self.model: torch.nn.Module = self.build_model()
22 |
23 | def build_model(self) -> nn.Module:
24 | model: nn.Module = build_object_from_class_name(
25 | self.config['model_cls'], nn.Module, config=self.config
26 | ).eval().to(self.device)
27 | state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)['state_dict']
28 | prefix_in_ckpt = 'model'
29 | state_dict = OrderedDict({
30 | k[len(prefix_in_ckpt) + 1:]: v
31 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.')
32 | })
33 | model.load_state_dict(state_dict, strict=True)
34 | print(f'load \'{prefix_in_ckpt}\' from \'{self.model_path}\'.')
35 | return model
36 |
37 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]:
38 | raise NotImplementedError()
39 |
40 | def forward_model(self, sample: Dict[str, torch.Tensor]):
41 | raise NotImplementedError()
42 |
43 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]:
44 | raise NotImplementedError()
45 |
46 | def infer(self, waveforms: List[np.ndarray]) -> List[Dict[str, np.ndarray]]:
47 | results = []
48 | for w in tqdm.tqdm(waveforms, desc='SOME Inference'):
49 | model_in = self.preprocess(w)
50 | model_out = self.forward_model(model_in)
51 | res = self.postprocess(model_out)
52 | results.append(res)
53 | return results
54 |
--------------------------------------------------------------------------------
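`build_model` assumes the training checkpoint stores weights under a `model.` prefix inside `state_dict` and strips that prefix before `load_state_dict`. The remapping in isolation, with toy keys:

```python
from collections import OrderedDict

state_dict = {
    "model.encoder.weight": 1,   # kept, prefix stripped
    "model.decoder.bias": 2,     # kept, prefix stripped
    "step": 3,                   # not under 'model.', dropped
}
prefix_in_ckpt = "model"
stripped = OrderedDict({
    k[len(prefix_in_ckpt) + 1:]: v
    for k, v in state_dict.items() if k.startswith(f"{prefix_in_ckpt}.")
})
print(list(stripped))  # ['encoder.weight', 'decoder.bias']
```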
/inference/me_infer.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from typing import Dict, List
3 |
4 | import numpy as np
5 | import torch
6 |
7 | import modules.rmvpe
8 | from utils.infer_utils import decode_bounds_to_alignment, decode_gaussian_blurred_probs, decode_note_sequence
9 | from .base_infer import BaseInference
10 |
11 |
12 | class MIDIExtractionInference(BaseInference):
13 | def __init__(self, config: dict, model_path: pathlib.Path, device=None):
14 | super().__init__(config, model_path, device=device)
15 | self.mel_spec = modules.rmvpe.MelSpectrogram(
16 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'],
17 | win_length=self.config['win_size'], hop_length=self.config['hop_size'],
18 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax']
19 | ).to(self.device)
20 | self.rmvpe = None
21 | self.midi_min = self.config['midi_min']
22 | self.midi_max = self.config['midi_max']
23 | self.midi_deviation = self.config['midi_prob_deviation']
24 | self.rest_threshold = self.config['rest_threshold']
25 |
26 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]:
27 | wav_tensor = torch.from_numpy(waveform).unsqueeze(0).to(self.device)
28 | units = self.mel_spec(wav_tensor).transpose(1, 2)
29 | length = units.shape[1]
30 |
31 | pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device)
32 | return {
33 | 'units': units,
34 | 'pitch': pitch,
35 | 'masks': torch.ones_like(pitch, dtype=torch.bool)
36 | }
37 |
38 | @torch.no_grad()
39 | def forward_model(self, sample: Dict[str, torch.Tensor]):
40 |
41 |         probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'], sig=True)
42 |
43 | return {
44 | 'probs': probs,
45 | 'bounds': bounds,
46 | 'masks': sample['masks'],
47 | }
48 |
49 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]:
50 | probs = results['probs']
51 | bounds = results['bounds']
52 | masks = results['masks']
53 | probs *= masks[..., None]
54 | bounds *= masks
55 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks
56 | midi_pred, rest_pred = decode_gaussian_blurred_probs(
57 | probs, vmin=self.midi_min, vmax=self.midi_max,
58 | deviation=self.midi_deviation, threshold=self.rest_threshold
59 | )
60 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
61 | unit2note_pred, midi_pred, ~rest_pred & masks
62 | )
63 | note_rest_pred = ~note_mask_pred
64 | return {
65 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(),
66 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep,
67 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy()
68 | }
69 |
--------------------------------------------------------------------------------
/inference/me_quant_infer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence
7 | from .me_infer import MIDIExtractionInference
8 |
9 |
10 | class QuantizedMIDIExtractionInference(MIDIExtractionInference):
11 | @torch.no_grad()
12 | def forward_model(self, sample: Dict[str, torch.Tensor]):
13 | probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'], softmax=True)
14 |
15 | return {
16 | 'probs': probs,
17 | 'bounds': bounds,
18 | 'masks': sample['masks'],
19 | }
20 |
21 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]:
22 | probs = results['probs']
23 | bounds = results['bounds']
24 | masks = results['masks']
25 | probs *= masks[..., None]
26 | bounds *= masks
27 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks
28 | midi_pred = probs.argmax(dim=-1)
29 | rest_pred = midi_pred == 128
30 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
31 | unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks
32 | )
33 | note_rest_pred = ~note_mask_pred
34 | return {
35 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(),
36 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep,
37 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy()
38 | }
39 |
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/__init__.py
--------------------------------------------------------------------------------
/modules/attention/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/attention/__init__.py
--------------------------------------------------------------------------------
/modules/attention/base_attention.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 |
7 |
8 | class Attention(nn.Module):
9 | def __init__(self, dim, heads=4, dim_head=32, conditiondim=None):
10 | super().__init__()
11 | if conditiondim is None:
12 | conditiondim = dim
13 |
14 | self.scale = dim_head ** -0.5
15 | self.heads = heads
16 | hidden_dim = dim_head * heads
17 | self.to_q = nn.Linear(dim, hidden_dim, bias=False)
18 | self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False)
19 |
20 | self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim, ),
21 | )
22 |
23 | def forward(self, q, kv=None, mask=None):
24 | # b, c, h, w = x.shape
25 | if kv is None:
26 | kv = q
27 | # q, kv = map(
28 | # lambda t: rearrange(t, "b c t -> b t c", ), (q, kv)
29 | # )
30 |
31 | q = self.to_q(q)
32 | k, v = self.to_kv(kv).chunk(2, dim=2)
33 |
34 | q, k, v = map(
35 | lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v)
36 | )
37 |
38 | if mask is not None:
39 | mask = mask.unsqueeze(1).unsqueeze(1)
40 |
41 |         # prefer PyTorch's fused scaled-dot-product attention kernels (math fallback disabled on CUDA)
42 |         with torch.backends.cuda.sdp_kernel(enable_math=False):
43 |             out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
44 |
45 | out = rearrange(out, "b h t c -> b t (h c) ", h=self.heads, )
46 | return self.to_out(out)
47 |
--------------------------------------------------------------------------------
/modules/conform/Gconform.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 |
8 | from modules.attention.base_attention import Attention
9 | from modules.conv.base_conv import conform_conv
10 | class GLU(nn.Module):
11 | def __init__(self, dim):
12 | super().__init__()
13 | self.dim = dim
14 |
15 | def forward(self, x):
16 | out, gate = x.chunk(2, dim=self.dim)
17 |
18 | return out * gate.sigmoid()
19 |
20 | class conform_ffn(nn.Module):
21 | def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1):
22 | super().__init__()
23 | self.ln1 = nn.Linear(dim, dim * 4)
24 | self.ln2 = nn.Linear(dim * 4, dim)
25 | self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0. else nn.Identity()
26 | self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0. else nn.Identity()
27 | self.act = nn.SiLU()
28 |
29 | def forward(self, x):
30 | x = self.ln1(x)
31 | x = self.act(x)
32 | x = self.drop1(x)
33 | x = self.ln2(x)
34 | return self.drop2(x)
35 |
36 |
37 | class conform_blocke(nn.Module):
38 | def __init__(self, dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1,
39 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4,
40 | attention_heads_dim: int = 64):
41 | super().__init__()
42 | self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
43 | self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop)
44 | self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim)
45 | self.attdrop = nn.Dropout(attention_drop) if attention_drop > 0. else nn.Identity()
46 |         self.conv = conform_conv(dim, kernel_size=kernel_size,
47 |                                  DropoutL=conv_drop)
48 |
49 | self.norm1 = nn.LayerNorm(dim)
50 | self.norm2 = nn.LayerNorm(dim)
51 | self.norm3 = nn.LayerNorm(dim)
52 | self.norm4 = nn.LayerNorm(dim)
53 | self.norm5 = nn.LayerNorm(dim)
54 |
55 |
56 | def forward(self, x, mask=None,):
57 | x = self.ffn1(self.norm1(x)) * 0.5 + x
58 |
59 |
60 | x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x
61 | x = self.conv(self.norm3(x)) + x
62 | x = self.ffn2(self.norm4(x)) * 0.5 + x
63 | return self.norm5(x)
64 |
65 | # return x
66 |
67 |
68 | class Gcf(nn.Module):
69 | def __init__(self,dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1,
70 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4,
71 | attention_heads_dim: int = 64):
72 | super().__init__()
73 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop,
74 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads,
75 | attention_heads_dim=attention_heads_dim)
76 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop,
77 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads,
78 | attention_heads_dim=attention_heads_dim)
79 |         self.glu1 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
80 |         self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2))
81 |
82 |     def forward(self, midi, bound):
83 |         midi = self.att1(midi)
84 |         bound = self.att2(bound)
85 |         midis = self.glu1(midi)
86 |         bounds = self.glu2(bound)
87 |         return midi + bounds, bound + midis
88 |
89 |
90 |
91 |
92 | class Gmidi_conform(nn.Module):
93 | def __init__(self, lay: int, dim: int, indim: int, outdim: int, use_lay_skip: bool, kernel_size: int = 31,
94 | conv_drop: float = 0.1,
95 | ffn_latent_drop: float = 0.1,
96 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4,
97 | attention_heads_dim: int = 64):
98 | super().__init__()
99 |
100 | self.inln = nn.Linear(indim, dim)
101 | self.inln1 = nn.Linear(indim, dim)
102 | self.outln = nn.Linear(dim, outdim)
103 | self.cutheard = nn.Linear(dim, 1)
104 | # self.cutheard = nn.Linear(dim, outdim)
105 | self.lay = lay
106 | self.use_lay_skip = use_lay_skip
107 | self.cf_lay = nn.ModuleList(
108 | [Gcf(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop,
109 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads,
110 | attention_heads_dim=attention_heads_dim) for _ in range(lay)])
111 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop,
112 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads,
113 | attention_heads_dim=attention_heads_dim)
114 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop,
115 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads,
116 | attention_heads_dim=attention_heads_dim)
117 |
118 |
119 |     def forward(self, x, pitch, mask=None):
120 |         # two parallel streams: x feeds the MIDI head, x1 feeds the boundary (cut) head
121 |         x1 = x.clone()
122 |
123 |         x = self.inln(x)
124 |         x1 = self.inln1(x1)
125 |         if mask is not None:
126 |             x = x.masked_fill(~mask.unsqueeze(-1), 0)
127 |         for idx, i in enumerate(self.cf_lay):
128 |             x, x1 = i(x, x1)
129 |
130 |         if mask is not None:
131 |             x = x.masked_fill(~mask.unsqueeze(-1), 0)
132 |         x, x1 = self.att1(x), self.att2(x1)
133 |
134 |
135 | cutprp = self.cutheard(x1)
136 | midiout = self.outln(x)
137 | cutprp = torch.sigmoid(cutprp)
138 | cutprp = torch.squeeze(cutprp, -1)
139 |
140 | return midiout, cutprp
141 |
--------------------------------------------------------------------------------
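A shape-level sketch of `Gmidi_conform`, with small, illustrative hyperparameters rather than the values used by the released checkpoint:

```python
import torch
from modules.conform.Gconform import Gmidi_conform

model = Gmidi_conform(
    lay=2, dim=64, indim=128, outdim=129, use_lay_skip=True,
    attention_heads=4, attention_heads_dim=16,
)
x = torch.randn(1, 100, 128)               # (batch, frames, units_dim)
pitch = torch.zeros(1, 100)                # placeholder f0, as in MIDIExtractionInference.preprocess
mask = torch.ones(1, 100, dtype=torch.bool)

midi_logits, cut_probs = model(x, pitch, mask)
print(midi_logits.shape, cut_probs.shape)  # shapes (1, 100, 129) and (1, 100)
```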
/modules/conform/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/conform/__init__.py
--------------------------------------------------------------------------------
/modules/conv/base_conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from einops import rearrange
5 |
6 |
7 | class GLU(nn.Module):
8 | def __init__(self, dim):
9 | super().__init__()
10 | self.dim = dim
11 |
12 | def forward(self, x):
13 | out, gate = x.chunk(2, dim=self.dim)
14 |
15 | return out * gate.sigmoid()
16 |
17 |
18 | class conform_conv(nn.Module):
19 | def __init__(self, channels: int,
20 | kernel_size: int = 31,
21 |
22 | DropoutL=0.1,
23 |
24 | bias: bool = True):
25 | super().__init__()
26 | self.act2 = nn.SiLU()
27 | self.act1 = GLU(1)
28 |
29 | self.pointwise_conv1 = nn.Conv1d(
30 | channels,
31 | 2 * channels,
32 | kernel_size=1,
33 | stride=1,
34 | padding=0,
35 | bias=bias)
36 |
37 | # self.lorder is used to distinguish if it's a causal convolution,
38 | # if self.lorder > 0:
39 | # it's a causal convolution, the input will be padded with
40 | # `self.lorder` frames on the left in forward (causal conv impl).
41 | # else: it's a symmetrical convolution
42 |
43 | assert (kernel_size - 1) % 2 == 0
44 | padding = (kernel_size - 1) // 2
45 |
46 | self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size,
47 | stride=1,
48 | padding=padding,
49 | groups=channels,
50 | bias=bias)
51 |
52 |
53 | self.norm = nn.BatchNorm1d(channels)
54 |
55 |
56 | self.pointwise_conv2 = nn.Conv1d(channels,
57 | channels,
58 | kernel_size=1,
59 | stride=1,
60 | padding=0,
61 | bias=bias)
62 |         self.drop = nn.Dropout(DropoutL) if DropoutL > 0. else nn.Identity()
63 |     def forward(self, x):
64 |         x = x.transpose(1, 2)
65 |         x = self.act1(self.pointwise_conv1(x))
66 |         x = self.depthwise_conv(x)
67 |         x = self.norm(x)
68 |         x = self.act2(x)
69 |         x = self.pointwise_conv2(x)
70 |         return self.drop(x).transpose(1, 2)
71 |
72 |
--------------------------------------------------------------------------------
/modules/model/Gmidi_conform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from modules.conform.Gconform import Gmidi_conform
5 |
6 |
7 |
8 | class midi_loss(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 | self.loss = nn.BCELoss()
12 |
13 | def forward(self, x, target):
14 | midiout, cutp = x
15 | midi_target, cutp_target = target
16 |
17 | cutploss = self.loss(cutp, cutp_target)
18 | midiloss = self.loss(midiout, midi_target)
19 | return midiloss, cutploss
20 |
21 |
22 | class midi_conforms(nn.Module):
23 | def __init__(self, config):
24 | super().__init__()
25 |
26 | cfg = config['midi_extractor_args']
27 | cfg.update({'indim': config['units_dim'], 'outdim': config['midi_num_bins']})
28 | self.model = Gmidi_conform(**cfg)
29 |
30 |     def forward(self, x, f0, mask=None, softmax=False, sig=False):
31 |
32 |         midi, bound = self.model(x, f0, mask)
33 |         if sig:
34 |             midi = torch.sigmoid(midi)
35 |
36 |         if softmax:
37 |             midi = F.softmax(midi, dim=2)
38 |
39 |
40 |         return midi, bound
41 |
42 | def get_loss(self):
43 | return midi_loss()
44 |
--------------------------------------------------------------------------------
/modules/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/model/__init__.py
--------------------------------------------------------------------------------
/modules/rmvpe/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import *
2 | from .model import E2E0
3 | from .utils import to_local_average_f0, to_viterbi_f0
4 | from .spec import MelSpectrogram
5 |
--------------------------------------------------------------------------------
/modules/rmvpe/constants.py:
--------------------------------------------------------------------------------
1 | SAMPLE_RATE = 16000
2 |
3 | N_CLASS = 360
4 |
5 | N_MELS = 128
6 | MEL_FMIN = 30
7 | MEL_FMAX = 8000
8 | WINDOW_LENGTH = 1024
9 | CONST = 1997.3794084376191
10 |
--------------------------------------------------------------------------------
/modules/rmvpe/deepunet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .constants import N_MELS
4 |
5 |
6 | class ConvBlockRes(nn.Module):
7 | def __init__(self, in_channels, out_channels, momentum=0.01):
8 | super(ConvBlockRes, self).__init__()
9 | self.conv = nn.Sequential(
10 | nn.Conv2d(in_channels=in_channels,
11 | out_channels=out_channels,
12 | kernel_size=(3, 3),
13 | stride=(1, 1),
14 | padding=(1, 1),
15 | bias=False),
16 | nn.BatchNorm2d(out_channels, momentum=momentum),
17 | nn.ReLU(),
18 |
19 | nn.Conv2d(in_channels=out_channels,
20 | out_channels=out_channels,
21 | kernel_size=(3, 3),
22 | stride=(1, 1),
23 | padding=(1, 1),
24 | bias=False),
25 | nn.BatchNorm2d(out_channels, momentum=momentum),
26 | nn.ReLU(),
27 | )
28 | if in_channels != out_channels:
29 | self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
30 | self.is_shortcut = True
31 | else:
32 | self.is_shortcut = False
33 |
34 | def forward(self, x):
35 | if self.is_shortcut:
36 | return self.conv(x) + self.shortcut(x)
37 | else:
38 | return self.conv(x) + x
39 |
40 |
41 | class ResEncoderBlock(nn.Module):
42 | def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
43 | super(ResEncoderBlock, self).__init__()
44 | self.n_blocks = n_blocks
45 | self.conv = nn.ModuleList()
46 | self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
47 | for i in range(n_blocks - 1):
48 | self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
49 | self.kernel_size = kernel_size
50 | if self.kernel_size is not None:
51 | self.pool = nn.AvgPool2d(kernel_size=kernel_size)
52 |
53 | def forward(self, x):
54 | for i in range(self.n_blocks):
55 | x = self.conv[i](x)
56 | if self.kernel_size is not None:
57 | return x, self.pool(x)
58 | else:
59 | return x
60 |
61 |
62 | class ResDecoderBlock(nn.Module):
63 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
64 | super(ResDecoderBlock, self).__init__()
65 | out_padding = (0, 1) if stride == (1, 2) else (1, 1)
66 | self.n_blocks = n_blocks
67 | self.conv1 = nn.Sequential(
68 | nn.ConvTranspose2d(in_channels=in_channels,
69 | out_channels=out_channels,
70 | kernel_size=(3, 3),
71 | stride=stride,
72 | padding=(1, 1),
73 | output_padding=out_padding,
74 | bias=False),
75 | nn.BatchNorm2d(out_channels, momentum=momentum),
76 | nn.ReLU(),
77 | )
78 | self.conv2 = nn.ModuleList()
79 | self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
80 | for i in range(n_blocks-1):
81 | self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
82 |
83 | def forward(self, x, concat_tensor):
84 | x = self.conv1(x)
85 | x = torch.cat((x, concat_tensor), dim=1)
86 | for i in range(self.n_blocks):
87 | x = self.conv2[i](x)
88 | return x
89 |
90 |
91 | class Encoder(nn.Module):
92 | def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
93 | super(Encoder, self).__init__()
94 | self.n_encoders = n_encoders
95 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
96 | self.layers = nn.ModuleList()
97 | self.latent_channels = []
98 | for i in range(self.n_encoders):
99 | self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
100 | self.latent_channels.append([out_channels, in_size])
101 | in_channels = out_channels
102 | out_channels *= 2
103 | in_size //= 2
104 | self.out_size = in_size
105 | self.out_channel = out_channels
106 |
107 | def forward(self, x):
108 | concat_tensors = []
109 | x = self.bn(x)
110 | for i in range(self.n_encoders):
111 | _, x = self.layers[i](x)
112 | concat_tensors.append(_)
113 | return x, concat_tensors
114 |
115 |
116 | class Intermediate(nn.Module):
117 | def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
118 | super(Intermediate, self).__init__()
119 | self.n_inters = n_inters
120 | self.layers = nn.ModuleList()
121 | self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
122 | for i in range(self.n_inters-1):
123 | self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
124 |
125 | def forward(self, x):
126 | for i in range(self.n_inters):
127 | x = self.layers[i](x)
128 | return x
129 |
130 |
131 | class Decoder(nn.Module):
132 | def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
133 | super(Decoder, self).__init__()
134 | self.layers = nn.ModuleList()
135 | self.n_decoders = n_decoders
136 | for i in range(self.n_decoders):
137 | out_channels = in_channels // 2
138 | self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
139 | in_channels = out_channels
140 |
141 | def forward(self, x, concat_tensors):
142 | for i in range(self.n_decoders):
143 | x = self.layers[i](x, concat_tensors[-1-i])
144 | return x
145 |
146 |
147 | class TimbreFilter(nn.Module):
148 | def __init__(self, latent_rep_channels):
149 | super(TimbreFilter, self).__init__()
150 | self.layers = nn.ModuleList()
151 | for latent_rep in latent_rep_channels:
152 | self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0]))
153 |
154 | def forward(self, x_tensors):
155 | out_tensors = []
156 | for i, layer in enumerate(self.layers):
157 | out_tensors.append(layer(x_tensors[i]))
158 | return out_tensors
159 |
160 |
161 | class DeepUnet0(nn.Module):
162 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
163 | super(DeepUnet0, self).__init__()
164 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels)
165 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
166 | self.tf = TimbreFilter(self.encoder.latent_channels)
167 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
168 |
169 | def forward(self, x):
170 | x, concat_tensors = self.encoder(x)
171 | x = self.intermediate(x)
172 | x = self.decoder(x, concat_tensors)
173 | return x
174 |
--------------------------------------------------------------------------------
/modules/rmvpe/inference.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torchaudio.transforms import Resample
5 |
6 | from utils.pitch_utils import interp_f0, resample_align_curve
7 | from .constants import *
8 | from .model import E2E0
9 | from .spec import MelSpectrogram
10 | from .utils import to_local_average_f0, to_viterbi_f0
11 |
12 |
13 | class RMVPE:
14 | def __init__(self, model_path, hop_length=160, device=None):
15 | self.resample_kernel = {}
16 | if device is None:
17 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
18 | else:
19 | self.device = device
20 | self.model = E2E0(4, 1, (2, 2)).eval().to(self.device)
21 | ckpt = torch.load(model_path, map_location=self.device, weights_only=True)
22 | self.model.load_state_dict(ckpt['model'], strict=False)
23 | self.mel_extractor = MelSpectrogram(
24 | N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX
25 | ).to(self.device)
26 |
27 | @torch.no_grad()
28 | def mel2hidden(self, mel):
29 | n_frames = mel.shape[-1]
30 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant')
31 | hidden = self.model(mel)
32 | return hidden[:, :n_frames]
33 |
34 | def decode(self, hidden, thred=0.03, use_viterbi=False):
35 | if use_viterbi:
36 | f0 = to_viterbi_f0(hidden, thred=thred)
37 | else:
38 | f0 = to_local_average_f0(hidden, thred=thred)
39 | return f0
40 |
41 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False):
42 | audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device)
43 | if sample_rate == 16000:
44 | audio_res = audio
45 | else:
46 | key_str = str(sample_rate)
47 | if key_str not in self.resample_kernel:
48 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128)
49 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device)
50 | audio_res = self.resample_kernel[key_str](audio)
51 | mel = self.mel_extractor(audio_res, center=True)
52 | hidden = self.mel2hidden(mel)
53 | f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi)
54 | return f0
55 |
56 | def get_pitch(self, waveform, sample_rate, hop_size, length, interp_uv=False):
57 | f0 = self.infer_from_audio(waveform, sample_rate=sample_rate)
58 | uv = f0 == 0
59 | f0, uv = interp_f0(f0, uv)
60 |
61 | time_step = hop_size / sample_rate
62 | f0_res = resample_align_curve(f0, 0.01, time_step, length)
63 | uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5
64 | if not interp_uv:
65 | f0_res[uv_res] = 0
66 | return f0_res, uv_res
67 |
--------------------------------------------------------------------------------
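How `infer.get_f0` uses this class, reduced to a standalone sketch (the random array is a stand-in for a real 16 kHz vocal recording, and `weights/rmvpe.pt` must already be downloaded):

```python
import numpy as np
from modules.rmvpe.inference import RMVPE

rmvpe = RMVPE(model_path="weights/rmvpe.pt", hop_length=160)
audio_16k = np.random.randn(16000).astype(np.float32)      # 1 second of stand-in audio
f0 = rmvpe.infer_from_audio(audio_16k, sample_rate=16000)   # one f0 value per 10 ms hop
print(f0.shape)
```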
/modules/rmvpe/model.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from .constants import *
4 | from .deepunet import DeepUnet0
5 | from .seq import BiGRU
6 |
7 |
8 | class E2E0(nn.Module):
9 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1,
10 | en_out_channels=16):
11 | super(E2E0, self).__init__()
12 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
13 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
14 | if n_gru:
15 | self.fc = nn.Sequential(
16 | BiGRU(3 * N_MELS, 256, n_gru),
17 | nn.Linear(512, N_CLASS),
18 | nn.Dropout(0.25),
19 | nn.Sigmoid()
20 | )
21 | else:
22 | self.fc = nn.Sequential(
23 | nn.Linear(3 * N_MELS, N_CLASS),
24 | nn.Dropout(0.25),
25 | nn.Sigmoid()
26 | )
27 |
28 | def forward(self, mel):
29 | mel = mel.transpose(-1, -2).unsqueeze(1)
30 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
31 | x = self.fc(x)
32 | return x
33 |
--------------------------------------------------------------------------------
/modules/rmvpe/seq.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class BiGRU(nn.Module):
5 | def __init__(self, input_features, hidden_features, num_layers):
6 | super(BiGRU, self).__init__()
7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
8 |
9 | def forward(self, x):
10 | return self.gru(x)[0]
11 |
--------------------------------------------------------------------------------
/modules/rmvpe/spec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from librosa.filters import mel
5 |
6 |
7 | class MelSpectrogram(torch.nn.Module):
8 | def __init__(
9 | self,
10 | n_mel_channels,
11 | sampling_rate,
12 | win_length,
13 | hop_length,
14 | n_fft=None,
15 | mel_fmin=0,
16 | mel_fmax=None,
17 | clamp=1e-5
18 | ):
19 | super().__init__()
20 | n_fft = win_length if n_fft is None else n_fft
21 | self.hann_window = {}
22 | mel_basis = mel(
23 | sr=sampling_rate,
24 | n_fft=n_fft,
25 | n_mels=n_mel_channels,
26 | fmin=mel_fmin,
27 | fmax=mel_fmax,
28 | htk=True)
29 | mel_basis = torch.from_numpy(mel_basis).float()
30 | self.register_buffer("mel_basis", mel_basis)
31 | self.n_fft = win_length if n_fft is None else n_fft
32 | self.hop_length = hop_length
33 | self.win_length = win_length
34 | self.sampling_rate = sampling_rate
35 | self.n_mel_channels = n_mel_channels
36 | self.clamp = clamp
37 |
38 | def forward(self, audio, keyshift=0, speed=1, center=True):
39 | factor = 2 ** (keyshift / 12)
40 | n_fft_new = int(np.round(self.n_fft * factor))
41 | win_length_new = int(np.round(self.win_length * factor))
42 | hop_length_new = int(np.round(self.hop_length * speed))
43 |
44 | keyshift_key = str(keyshift) + '_' + str(audio.device)
45 | if keyshift_key not in self.hann_window:
46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
47 | if center:
48 | pad_left = win_length_new // 2
49 | pad_right = (win_length_new + 1) // 2
50 | audio = F.pad(audio, (pad_left, pad_right))
51 |
52 | fft = torch.stft(
53 | audio,
54 | n_fft=n_fft_new,
55 | hop_length=hop_length_new,
56 | win_length=win_length_new,
57 | window=self.hann_window[keyshift_key],
58 | center=False,
59 | return_complex=True
60 | )
61 | magnitude = fft.abs()
62 |
63 | if keyshift != 0:
64 | size = self.n_fft // 2 + 1
65 | resize = magnitude.size(1)
66 | if resize < size:
67 | magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
68 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
69 |
70 | mel_output = torch.matmul(self.mel_basis, magnitude)
71 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
72 | return log_mel_spec
73 |
--------------------------------------------------------------------------------
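A short sketch of using this module on its own, with the RMVPE constants from `constants.py` (the random waveform is a stand-in for real audio):

```python
import torch
from modules.rmvpe.spec import MelSpectrogram

mel_extractor = MelSpectrogram(
    n_mel_channels=128, sampling_rate=16000,
    win_length=1024, hop_length=160,
    mel_fmin=30, mel_fmax=8000,
)
audio = torch.randn(1, 16000)            # 1 second of 16 kHz audio
mel = mel_extractor(audio, center=True)  # log-mel spectrogram
print(mel.shape)                         # torch.Size([1, 128, 101])
```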
/modules/rmvpe/utils.py:
--------------------------------------------------------------------------------
1 | import librosa
2 | import numpy as np
3 | import torch
4 |
5 | from .constants import *
6 |
7 |
8 | def to_local_average_f0(hidden, center=None, thred=0.03):
9 | idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N]
10 | idx_cents = idx * 20 + CONST # [B=1, N]
11 | if center is None:
12 | center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1]
13 | start = torch.clip(center - 4, min=0) # [B, T, 1]
14 | end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1]
15 | idx_mask = (idx >= start) & (idx < end) # [B, T, N]
16 | weights = hidden * idx_mask # [B, T, N]
17 | product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T]
18 | weight_sum = torch.sum(weights, dim=2) # [B, T]
19 | cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
20 | f0 = 10 * 2 ** (cents / 1200)
21 | uv = hidden.max(dim=2)[0] < thred # [B, T]
22 | f0 = f0 * ~uv
23 | return f0.squeeze(0).cpu().numpy()
24 |
25 |
26 | def to_viterbi_f0(hidden, thred=0.03):
27 | # Create viterbi transition matrix
28 | if not hasattr(to_viterbi_f0, 'transition'):
29 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS))
30 | transition = np.maximum(30 - abs(xx - yy), 0)
31 | transition = transition / transition.sum(axis=1, keepdims=True)
32 | to_viterbi_f0.transition = transition
33 |
34 | # Convert to probability
35 | prob = hidden.squeeze(0).cpu().numpy()
36 | prob = prob.T
37 | prob = prob / prob.sum(axis=0)
38 |
39 | # Perform viterbi decoding
40 | path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64)
41 | center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device)
42 |
43 | return to_local_average_f0(hidden, center=center, thred=thred)
44 |
--------------------------------------------------------------------------------
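`to_local_average_f0` places the 360 classes on a 20-cent grid (`idx * 20 + CONST`) and converts the weighted-average cents value to Hz with `f0 = 10 * 2 ** (cents / 1200)`. A quick sanity check of that mapping:

```python
CONST = 1997.3794084376191  # from modules/rmvpe/constants.py

def bin_to_hz(idx: int) -> float:
    cents = idx * 20 + CONST
    return 10 * 2 ** (cents / 1200)

print(bin_to_hz(0))    # ≈ 31.7 Hz, lowest pitch class
print(bin_to_hz(359))  # ≈ 2005 Hz, highest pitch class
```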
/requirements.txt:
--------------------------------------------------------------------------------
1 | # It's recommended to download and install torch manually.
2 | # pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
3 |
4 | gradio
5 | librosa
6 | mido
7 | einops
--------------------------------------------------------------------------------
/template.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": 153,
3 | "time": {
4 | "meter": [
5 | {
6 | "index": 0,
7 | "numerator": 4,
8 | "denominator": 4
9 | }
10 | ],
11 | "tempo": [
12 | {
13 | "position": 0,
14 | "bpm": 120
15 | }
16 | ]
17 | },
18 | "library": [],
19 | "tracks": [
20 | {
21 | "name": "未命名音轨",
22 | "dispColor": "ff7db235",
23 | "dispOrder": 0,
24 | "renderEnabled": false,
25 | "mixer": {
26 | "gainDecibel": 0,
27 | "pan": 0,
28 | "mute": false,
29 | "solo": false,
30 | "display": true
31 | },
32 | "mainGroup": {
33 | "name": "main",
34 | "uuid": "a6cbc361-ef84-4518-8da4-4ac7640aae8c",
35 | "parameters": {
36 | "pitchDelta": {
37 | "mode": "cubic",
38 | "points": []
39 | },
40 | "vibratoEnv": {
41 | "mode": "cubic",
42 | "points": []
43 | },
44 | "loudness": {
45 | "mode": "cubic",
46 | "points": []
47 | },
48 | "tension": {
49 | "mode": "cubic",
50 | "points": []
51 | },
52 | "breathiness": {
53 | "mode": "cubic",
54 | "points": []
55 | },
56 | "voicing": {
57 | "mode": "cubic",
58 | "points": []
59 | },
60 | "gender": {
61 | "mode": "cubic",
62 | "points": []
63 | },
64 | "toneShift": {
65 | "mode": "cubic",
66 | "points": []
67 | }
68 | },
69 | "vocalModes": {},
70 | "notes": [
71 | {
72 | "musicalType": "singing",
73 | "onset": 0,
74 | "duration": 705600000,
75 | "lyrics": "la",
76 | "phonemes": "",
77 | "accent": "",
78 | "pitch": 60,
79 | "detune": 0,
80 | "instantMode": false,
81 | "attributes": {
82 | "evenSyllableDuration": true
83 | },
84 | "systemAttributes": {
85 | "evenSyllableDuration": true
86 | },
87 | "pitchTakes": {
88 | "activeTakeId": 0,
89 | "takes": [
90 | {
91 | "id": 0,
92 | "expr": 0,
93 | "liked": false
94 | }
95 | ]
96 | },
97 | "timbreTakes": {
98 | "activeTakeId": 0,
99 | "takes": [
100 | {
101 | "id": 0,
102 | "expr": 0,
103 | "liked": false
104 | }
105 | ]
106 | }
107 | }
108 | ]
109 | },
110 | "mainRef": {
111 | "groupID": "a6cbc361-ef84-4518-8da4-4ac7640aae8c",
112 | "blickAbsoluteBegin": 0,
113 | "blickAbsoluteEnd": -1,
114 | "blickOffset": 0,
115 | "pitchOffset": 0,
116 | "isInstrumental": false,
117 | "systemPitchDelta": {
118 | "mode": "cubic",
119 | "points": []
120 | },
121 | "database": {
122 | "name": "",
123 | "language": "",
124 | "phoneset": "",
125 | "languageOverride": "",
126 | "phonesetOverride": "",
127 | "backendType": "",
128 | "version": "-2"
129 | },
130 | "dictionary": "",
131 | "voice": {
132 | "vocalModeInherited": true,
133 | "vocalModePreset": "",
134 | "vocalModeParams": {}
135 | },
136 | "pitchTakes": {
137 | "activeTakeId": 0,
138 | "takes": [
139 | {
140 | "id": 0,
141 | "expr": 0,
142 | "liked": false
143 | }
144 | ]
145 | },
146 | "timbreTakes": {
147 | "activeTakeId": 0,
148 | "takes": [
149 | {
150 | "id": 0,
151 | "expr": 0,
152 | "liked": false
153 | }
154 | ]
155 | }
156 | },
157 | "groups": []
158 | }
159 | ],
160 | "renderConfig": {
161 | "destination": "",
162 | "filename": "未命名",
163 | "numChannels": 1,
164 | "aspirationFormat": "noAspiration",
165 | "bitDepth": 16,
166 | "sampleRate": 44100,
167 | "exportMixDown": true,
168 | "exportPitch": false
169 | }
170 | }
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import pathlib
4 | import re
5 | import types
6 | from collections import OrderedDict
7 |
8 | import numpy as np
9 | import torch
10 |
11 | def tensors_to_scalars(metrics):
12 | new_metrics = {}
13 | for k, v in metrics.items():
14 | if isinstance(v, torch.Tensor):
15 | v = v.item()
16 | if type(v) is dict:
17 | v = tensors_to_scalars(v)
18 | new_metrics[k] = v
19 | return new_metrics
20 |
21 |
22 | def collate_nd(values, pad_value=0, max_len=None):
23 | """
24 | Pad a list of Nd tensors on their first dimension and stack them into a (N+1)d tensor.
25 | """
26 | size = ((max(v.size(0) for v in values) if max_len is None else max_len), *values[0].shape[1:])
27 | res = torch.full((len(values), *size), fill_value=pad_value, dtype=values[0].dtype, device=values[0].device)
28 |
29 | for i, v in enumerate(values):
30 | res[i, :len(v), ...] = v
31 | return res
32 |
33 |
34 | def random_continuous_masks(*shape: int, dim: int, device: str | torch.device = 'cpu'):
35 | start, end = torch.sort(
36 | torch.randint(
37 | low=0, high=shape[dim] + 1, size=(*shape[:dim], 2, *((1,) * (len(shape) - dim - 1))), device=device
38 | ).expand(*((-1,) * (dim + 1)), *shape[dim + 1:]), dim=dim
39 | )[0].split(1, dim=dim)
40 | idx = torch.arange(
41 | 0, shape[dim], dtype=torch.long, device=device
42 | ).reshape(*((1,) * dim), shape[dim], *((1,) * (len(shape) - dim - 1)))
43 | masks = (idx >= start) & (idx < end)
44 | return masks
45 |
46 |
47 | def _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size):
48 | if len(batch) == 0:
49 | return 0
50 | if len(batch) == max_batch_size:
51 | return 1
52 | if num_frames > max_batch_frames:
53 | return 1
54 | return 0
55 |
56 |
57 | def batch_by_size(
58 | indices, num_frames_fn, max_batch_frames=80000, max_batch_size=48,
59 | required_batch_size_multiple=1
60 | ):
61 | """
62 | Yield mini-batches of indices bucketed by size. Batches may contain
63 | sequences of different lengths.
64 |
65 | Args:
66 | indices (List[int]): ordered list of dataset indices
67 | num_frames_fn (callable): function that returns the number of frames at
68 | a given index
69 | max_batch_frames (int, optional): max number of frames in each batch
70 | (default: 80000).
71 | max_batch_size (int, optional): max number of sentences in each
72 | batch (default: 48).
73 | required_batch_size_multiple: require the batch size to be multiple
74 | of a given number
75 | """
76 | bsz_mult = required_batch_size_multiple
77 |
78 | if isinstance(indices, types.GeneratorType):
79 | indices = np.fromiter(indices, dtype=np.int64, count=-1)
80 |
81 | sample_len = 0
82 | sample_lens = []
83 | batch = []
84 | batches = []
85 | for i in range(len(indices)):
86 | idx = indices[i]
87 | num_frames = num_frames_fn(idx)
88 | sample_lens.append(num_frames)
89 | sample_len = max(sample_len, num_frames)
90 | assert sample_len <= max_batch_frames, (
91 |             "sentence at index {} of size {} exceeds max_batch_frames "
92 | "limit of {}!".format(idx, sample_len, max_batch_frames)
93 | )
94 | num_frames = (len(batch) + 1) * sample_len
95 |
96 | if _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size):
97 | mod_len = max(
98 | bsz_mult * (len(batch) // bsz_mult),
99 | len(batch) % bsz_mult,
100 | )
101 | batches.append(batch[:mod_len])
102 | batch = batch[mod_len:]
103 | sample_lens = sample_lens[mod_len:]
104 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
105 | batch.append(idx)
106 | if len(batch) > 0:
107 | batches.append(batch)
108 | return batches
109 |
110 |
111 | def unpack_dict_to_list(samples):
112 | samples_ = []
113 | bsz = samples.get('outputs').size(0)
114 | for i in range(bsz):
115 | res = {}
116 | for k, v in samples.items():
117 | try:
118 | res[k] = v[i]
119 |             except Exception:  # some values (e.g. scalars) cannot be indexed per sample
120 | pass
121 | samples_.append(res)
122 | return samples_
123 |
124 |
125 | def filter_kwargs(dict_to_filter, kwarg_obj):
126 | import inspect
127 |
128 | sig = inspect.signature(kwarg_obj)
129 | filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD]
130 | filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if
131 | filter_key in dict_to_filter}
132 | return filtered_dict
133 |
134 |
135 | def load_ckpt(
136 | cur_model, ckpt_base_dir, ckpt_steps=None,
137 | prefix_in_ckpt='model', key_in_ckpt='state_dict',
138 | strict=True, device='cpu'
139 | ):
140 | if not isinstance(ckpt_base_dir, pathlib.Path):
141 | ckpt_base_dir = pathlib.Path(ckpt_base_dir)
142 | if ckpt_base_dir.is_file():
143 | checkpoint_path = [ckpt_base_dir]
144 | elif ckpt_steps is not None:
145 | checkpoint_path = [ckpt_base_dir / f'model_ckpt_steps_{int(ckpt_steps)}.ckpt']
146 | else:
147 | base_dir = ckpt_base_dir
148 | checkpoint_path = sorted(
149 | [
150 | ckpt_file
151 | for ckpt_file in base_dir.iterdir()
152 | if ckpt_file.is_file() and re.fullmatch(r'model_ckpt_steps_\d+\.ckpt', ckpt_file.name)
153 | ],
154 | key=lambda x: int(re.search(r'\d+', x.name).group(0))
155 | )
156 | assert len(checkpoint_path) > 0, f'ckpt not found in {ckpt_base_dir}.'
157 | checkpoint_path = checkpoint_path[-1]
158 | ckpt_loaded = torch.load(checkpoint_path, map_location=device)
159 | if key_in_ckpt is None:
160 | state_dict = ckpt_loaded
161 | else:
162 | state_dict = ckpt_loaded[key_in_ckpt]
163 | if prefix_in_ckpt is not None:
164 | state_dict = OrderedDict({
165 | k[len(prefix_in_ckpt) + 1:]: v
166 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.')
167 | })
168 | if not strict:
169 | cur_model_state_dict = cur_model.state_dict()
170 | unmatched_keys = []
171 | for key, param in state_dict.items():
172 | if key in cur_model_state_dict:
173 | new_param = cur_model_state_dict[key]
174 | if new_param.shape != param.shape:
175 | unmatched_keys.append(key)
176 | print('Unmatched keys: ', key, new_param.shape, param.shape)
177 | for key in unmatched_keys:
178 | del state_dict[key]
179 | cur_model.load_state_dict(state_dict, strict=strict)
180 | shown_model_name = 'state dict'
181 | if prefix_in_ckpt is not None:
182 | shown_model_name = f'\'{prefix_in_ckpt}\''
183 | elif key_in_ckpt is not None:
184 | shown_model_name = f'\'{key_in_ckpt}\''
185 | print(f'load {shown_model_name} from \'{checkpoint_path}\'.')
186 |
187 |
188 | def remove_padding(x, padding_idx=0):
189 | if x is None:
190 | return None
191 | assert len(x.shape) in [1, 2]
192 | if len(x.shape) == 2: # [T, H]
193 | return x[np.abs(x).sum(-1) != padding_idx]
194 | elif len(x.shape) == 1: # [T]
195 | return x[x != padding_idx]
196 |
197 |
198 | def print_arch(model, model_name='model'):
199 | print(f"{model_name} Arch: ", model)
200 | # num_params(model, model_name=model_name)
201 |
202 |
203 | def num_params(model, print_out=True, model_name="model"):
204 | parameters = filter(lambda p: p.requires_grad, model.parameters())
205 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
206 | if print_out:
207 | print(f'{model_name} Trainable Parameters: %.3fM' % parameters)
208 | return parameters
209 |
210 |
211 | def build_object_from_class_name(cls_str, parent_cls, *args, **kwargs):
212 | import importlib
213 |
214 | pkg = ".".join(cls_str.split(".")[:-1])
215 | cls_name = cls_str.split(".")[-1]
216 | cls_type = getattr(importlib.import_module(pkg), cls_name)
217 | if parent_cls is not None:
218 | assert issubclass(cls_type, parent_cls), f'{cls_type} is not subclass of {parent_cls}.'
219 |
220 | return cls_type(*args, **filter_kwargs(kwargs, cls_type))
221 |
222 |
223 | def build_lr_scheduler_from_config(optimizer, scheduler_args):
224 | try:
225 | # PyTorch 2.0+
226 | from torch.optim.lr_scheduler import LRScheduler as LRScheduler
227 | except ImportError:
228 | # PyTorch 1.X
229 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
230 |
231 | def helper(params):
232 | if isinstance(params, list):
233 | return [helper(s) for s in params]
234 | elif isinstance(params, dict):
235 | resolved = {k: helper(v) for k, v in params.items()}
236 | if 'cls' in resolved:
237 | if (
238 | resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler"
239 | and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR"
240 | ):
241 | raise ValueError(f"ChainedScheduler cannot be part of a SequentialLR.")
242 | resolved['optimizer'] = optimizer
243 | obj = build_object_from_class_name(
244 | resolved['cls'],
245 | LRScheduler,
246 | **resolved
247 | )
248 | return obj
249 | return resolved
250 | else:
251 | return params
252 |
253 | resolved = helper(scheduler_args)
254 | resolved['optimizer'] = optimizer
255 | return build_object_from_class_name(
256 | scheduler_args['scheduler_cls'],
257 | LRScheduler,
258 | **resolved
259 | )
260 |
261 |
262 | def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1):
263 | optimizer = build_object_from_class_name(
264 | optimizer_args['optimizer_cls'],
265 | torch.optim.Optimizer,
266 | [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
267 | **optimizer_args
268 | )
269 | scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args)
270 | scheduler.optimizer._step_count = 1
271 | for _ in range(step_count):
272 | scheduler.step()
273 | return scheduler.state_dict()
274 |
275 |
276 | def remove_suffix(string: str, suffix: str):
277 |     # Just for Python 3.8 compatibility, since the `str.removesuffix()` API is only available since Python 3.9
278 | if string.endswith(suffix):
279 | string = string[:-len(suffix)]
280 | return string
281 |
--------------------------------------------------------------------------------
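A short usage sketch for two of the helpers above, with made-up tensors and frame counts (run from the repository root so that `utils` is importable):

```python
import torch
from utils import collate_nd, batch_by_size

# collate_nd pads variable-length tensors along their first dimension and stacks them.
padded = collate_nd([torch.tensor([1, 2, 3]), torch.tensor([4])], pad_value=0)
print(padded)  # tensor([[1, 2, 3], [4, 0, 0]])

# batch_by_size buckets indices so that the padded batch size (longest item x item count)
# stays within max_batch_frames and the item count within max_batch_size.
lengths = [100, 200, 4000, 50]
batches = batch_by_size(list(range(len(lengths))), lambda i: lengths[i],
                        max_batch_frames=4000, max_batch_size=2)
print(batches)  # [[0, 1], [2], [3]]
```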
/utils/infer_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict
2 |
3 | import mido
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
10 | num_bins = probs.shape[-1]
11 | interval = (vmax - vmin) / (num_bins - 1)
12 | width = int(3 * deviation / interval) # 3 * sigma
13 | idx = torch.arange(num_bins, device=probs.device)[None, None, :] # [1, 1, N]
14 | idx_values = idx * interval + vmin
15 | center = torch.argmax(probs, dim=-1, keepdim=True) # [B, T, 1]
16 | start = torch.clip(center - width, min=0) # [B, T, 1]
17 | end = torch.clip(center + width + 1, max=num_bins) # [B, T, 1]
18 | idx_masks = (idx >= start) & (idx < end) # [B, T, N]
19 | weights = probs * idx_masks # [B, T, N]
20 | product_sum = torch.sum(weights * idx_values, dim=2) # [B, T]
21 | weight_sum = torch.sum(weights, dim=2) # [B, T]
22 | values = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T]
23 | rest = probs.max(dim=-1)[0] < threshold # [B, T]
24 | return values, rest
25 |
26 |
27 | def decode_bounds_to_alignment(bounds):
28 | bounds_step = bounds.cumsum(dim=1).round().long()
29 | bounds_inc = torch.diff(
30 | bounds_step, dim=1, prepend=torch.full(
31 | (bounds.shape[0], 1), fill_value=-1,
32 | dtype=bounds_step.dtype, device=bounds_step.device
33 | )
34 | ) > 0
35 | frame2item = bounds_inc.long().cumsum(dim=1)
36 | return frame2item
37 |
38 |
39 | def decode_note_sequence(frame2item, values, masks, threshold=0.5):
40 | """
41 |     Aggregate frame-level values into note-level items.
42 |     :param frame2item: frame-to-item index map, e.g. [1, 1, 1, 1, 2, 2, 3, 3, 3]
43 |     :param values: frame-level values (e.g. MIDI pitch) of shape [B, T]
44 |     :param masks: frame-level boolean masks marking unmasked (non-rest) frames, shape [B, T]
45 | :param threshold: minimum ratio of unmasked frames required to be regarded as an unmasked item
46 | :return: item_values, item_dur, item_masks
47 | """
48 | b = frame2item.shape[0]
49 | space = frame2item.max() + 1
50 |
51 | item_dur = frame2item.new_zeros(b, space).scatter_add(
52 | 1, frame2item, torch.ones_like(frame2item)
53 | )[:, 1:]
54 | item_unmasked_dur = frame2item.new_zeros(b, space).scatter_add(
55 | 1, frame2item, masks.long()
56 | )[:, 1:]
57 | item_masks = item_unmasked_dur / item_dur >= threshold
58 |
59 | values_quant = values.round().long()
60 | histogram = frame2item.new_zeros(b, space * 128).scatter_add(
61 | 1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks
62 | ).unflatten(1, [space, 128])[:, 1:, :]
63 | item_values_center = histogram.argmax(dim=2).to(dtype=values.dtype)
64 | values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item)
65 | values_near_center = masks & (values >= values_center - 0.5) & (values <= values_center + 0.5)
66 | item_valid_dur = frame2item.new_zeros(b, space).scatter_add(
67 | 1, frame2item, values_near_center.long()
68 | )[:, 1:]
69 | item_values = values.new_zeros(b, space).scatter_add(
70 | 1, frame2item, values * values_near_center
71 | )[:, 1:] / (item_valid_dur + (item_valid_dur == 0))
72 |
73 | return item_values, item_dur, item_masks
74 |
75 |
76 | def build_midi_file(offsets: List[float], segments: List[Dict[str, np.ndarray]], tempo=120) -> mido.MidiFile:
77 | midi_file = mido.MidiFile(charset='utf8')
78 | midi_track = mido.MidiTrack()
79 | midi_track.append(mido.MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0))
80 | last_time = 0
81 |     offsets = [round(o * tempo * 8) for o in offsets]  # seconds -> MIDI ticks (mido default 480 ticks/beat => ticks/sec = tempo * 8)
82 | for i, (offset, segment) in enumerate(zip(offsets, segments)):
83 | note_midi = np.round(segment['note_midi']).astype(np.int64).tolist()
84 | note_tick = np.diff(np.round(np.cumsum(segment['note_dur']) * tempo * 8).astype(np.int64), prepend=0).tolist()
85 | note_rest = segment['note_rest'].tolist()
86 | start = offset
87 | for j in range(len(note_midi)):
88 | end = start + note_tick[j]
89 | if i < len(offsets) - 1 and end > offsets[i + 1]:
90 | end = offsets[i + 1]
91 | if start < end and not note_rest[j]:
92 | midi_track.append(mido.Message('note_on', note=note_midi[j], time=start - last_time))
93 | midi_track.append(mido.Message('note_off', note=note_midi[j], time=end - start))
94 | last_time = end
95 | start = end
96 | midi_file.tracks.append(midi_track)
97 | return midi_file
98 |
--------------------------------------------------------------------------------
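A minimal sketch of `build_midi_file` above with made-up segment data; the keys mirror the function body (`note_midi` in MIDI numbers, `note_dur` in seconds, `note_rest` rest flags), and `offsets` are segment start times in seconds:

```python
import numpy as np
from utils.infer_utils import build_midi_file

segment = {
    'note_midi': np.array([60.0, 62.0, 64.0]),    # C4, D4, E4
    'note_dur': np.array([0.5, 0.5, 1.0]),        # seconds
    'note_rest': np.array([False, True, False]),  # the middle entry is a rest
}
midi = build_midi_file(offsets=[0.0], segments=[segment], tempo=120)
midi.save('example.mid')  # hypothetical output path
# With mido's default 480 ticks per beat, the "tempo * 8" factor inside the function is
# simply ticks-per-second (480 * tempo / 60), so second-based durations map onto ticks.
```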
/utils/pitch_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | f0_bin = 256
5 | f0_max = 1100.0
6 | f0_min = 50.0
7 | f0_mel_min = 1127 * np.log(1 + f0_min / 700)
8 | f0_mel_max = 1127 * np.log(1 + f0_max / 700)
9 |
10 |
11 | def f0_to_coarse(f0):
12 | is_torch = isinstance(f0, torch.Tensor)
13 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
14 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
15 |
16 | f0_mel[f0_mel <= 1] = 1
17 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
18 |     f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64)  # np.int was removed in NumPy 1.24+
19 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
20 | return f0_coarse
21 |
22 |
23 | def norm_f0(f0, uv=None):
24 | if uv is None:
25 | uv = f0 == 0
26 | f0 = np.log2(f0 + uv) # avoid arithmetic error
27 | f0[uv] = -np.inf
28 | return f0
29 |
30 |
31 | def interp_f0(f0, uv=None):
32 | if uv is None:
33 | uv = f0 == 0
34 | f0 = norm_f0(f0, uv)
35 | if uv.any() and not uv.all():
36 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
37 | return denorm_f0(f0, uv=None), uv
38 |
39 |
40 | def denorm_f0(f0, uv, pitch_padding=None):
41 | f0 = 2 ** f0
42 | if uv is not None:
43 | f0[uv > 0] = 0
44 | if pitch_padding is not None:
45 | f0[pitch_padding] = 0
46 | return f0
47 |
48 |
49 | def resample_align_curve(points: np.ndarray, original_timestep: float, target_timestep: float, align_length: int):
50 | t_max = (len(points) - 1) * original_timestep
51 | curve_interp = np.interp(
52 | np.arange(0, t_max, target_timestep),
53 | original_timestep * np.arange(len(points)),
54 | points
55 | ).astype(points.dtype)
56 | delta_l = align_length - len(curve_interp)
57 | if delta_l < 0:
58 | curve_interp = curve_interp[:align_length]
59 | elif delta_l > 0:
60 | curve_interp = np.concatenate((curve_interp, np.full(delta_l, fill_value=curve_interp[-1])), axis=0)
61 | return curve_interp
62 |
--------------------------------------------------------------------------------
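A toy illustration of `interp_f0` above: unvoiced frames (f0 == 0) are filled by interpolating the voiced frames in the log2-frequency domain, so gaps become geometric (musically linear) glides:

```python
import numpy as np
from utils.pitch_utils import interp_f0

f0 = np.array([0.0, 220.0, 0.0, 0.0, 440.0, 0.0])
f0_filled, uv = interp_f0(f0)
print(uv)         # [ True False  True  True False  True]
print(f0_filled)  # edges are held; the interior gap becomes ~277 Hz and ~349 Hz (log2-domain interpolation)
```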
/utils/slicer2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | # This function is obtained from librosa.
5 | def get_rms(
6 | y,
7 | *,
8 | frame_length=2048,
9 | hop_length=512,
10 | pad_mode="constant",
11 | ):
12 | padding = (int(frame_length // 2), int(frame_length // 2))
13 | y = np.pad(y, padding, mode=pad_mode)
14 |
15 | axis = -1
16 | # put our new within-frame axis at the end for now
17 | out_strides = y.strides + tuple([y.strides[axis]])
18 | # Reduce the shape on the framing axis
19 | x_shape_trimmed = list(y.shape)
20 | x_shape_trimmed[axis] -= frame_length - 1
21 | out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
22 | xw = np.lib.stride_tricks.as_strided(
23 | y, shape=out_shape, strides=out_strides
24 | )
25 | if axis < 0:
26 | target_axis = axis - 1
27 | else:
28 | target_axis = axis + 1
29 | xw = np.moveaxis(xw, -1, target_axis)
30 | # Downsample along the target axis
31 | slices = [slice(None)] * xw.ndim
32 | slices[axis] = slice(0, None, hop_length)
33 | x = xw[tuple(slices)]
34 |
35 | # Calculate power
36 | power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
37 |
38 | return np.sqrt(power)
39 |
40 |
41 | class Slicer:
42 | def __init__(self,
43 | sr: int,
44 | threshold: float = -40.,
45 | min_length: int = 5000,
46 | min_interval: int = 300,
47 | hop_size: int = 20,
48 | max_sil_kept: int = 5000):
49 | if not min_length >= min_interval >= hop_size:
50 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size')
51 | if not max_sil_kept >= hop_size:
52 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size')
53 | min_interval = sr * min_interval / 1000
54 | self.sr = sr
55 | self.threshold = 10 ** (threshold / 20.)
56 | self.hop_size = round(sr * hop_size / 1000)
57 | self.win_size = min(round(min_interval), 4 * self.hop_size)
58 | self.min_length = round(sr * min_length / 1000 / self.hop_size)
59 | self.min_interval = round(min_interval / self.hop_size)
60 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
61 |
62 | def _apply_slice(self, waveform, begin, end):
63 | chunk = {
64 | 'offset': begin * self.hop_size / self.sr
65 | }
66 | if len(waveform.shape) > 1:
67 | chunk['waveform'] = waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)]
68 | else:
69 | chunk['waveform'] = waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)]
70 | return chunk
71 |
72 | # @timeit
73 | def slice(self, waveform):
74 | if len(waveform.shape) > 1:
75 | samples = waveform.mean(axis=0)
76 | else:
77 | samples = waveform
78 | if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length:
79 | return [{'offset': 0, 'waveform': waveform}]
80 | rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
81 | sil_tags = []
82 | silence_start = None
83 | clip_start = 0
84 | for i, rms in enumerate(rms_list):
85 | # Keep looping while frame is silent.
86 | if rms < self.threshold:
87 | # Record start of silent frames.
88 | if silence_start is None:
89 | silence_start = i
90 | continue
91 | # Keep looping while frame is not silent and silence start has not been recorded.
92 | if silence_start is None:
93 | continue
94 | # Clear recorded silence start if interval is not enough or clip is too short
95 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept
96 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
97 | if not is_leading_silence and not need_slice_middle:
98 | silence_start = None
99 | continue
100 | # Need slicing. Record the range of silent frames to be removed.
101 | if i - silence_start <= self.max_sil_kept:
102 | pos = rms_list[silence_start: i + 1].argmin() + silence_start
103 | if silence_start == 0:
104 | sil_tags.append((0, pos))
105 | else:
106 | sil_tags.append((pos, pos))
107 | clip_start = pos
108 | elif i - silence_start <= self.max_sil_kept * 2:
109 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin()
110 | pos += i - self.max_sil_kept
111 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
112 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
113 | if silence_start == 0:
114 | sil_tags.append((0, pos_r))
115 | clip_start = pos_r
116 | else:
117 | sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
118 | clip_start = max(pos_r, pos)
119 | else:
120 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start
121 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept
122 | if silence_start == 0:
123 | sil_tags.append((0, pos_r))
124 | else:
125 | sil_tags.append((pos_l, pos_r))
126 | clip_start = pos_r
127 | silence_start = None
128 | # Deal with trailing silence.
129 | total_frames = rms_list.shape[0]
130 | if silence_start is not None and total_frames - silence_start >= self.min_interval:
131 | silence_end = min(total_frames, silence_start + self.max_sil_kept)
132 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start
133 | sil_tags.append((pos, total_frames + 1))
134 | # Apply and return slices.
135 | if len(sil_tags) == 0:
136 | return [{'offset': 0, 'waveform': waveform}]
137 | else:
138 | chunks = []
139 | if sil_tags[0][0] > 0:
140 | chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
141 | for i in range(len(sil_tags) - 1):
142 | chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
143 | if sil_tags[-1][1] < total_frames:
144 | chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
145 | return chunks
146 |
--------------------------------------------------------------------------------
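A minimal sketch of using the `Slicer` above to split a recording at silences. `librosa` and `soundfile` are assumed here purely for loading and writing audio; any loader that yields a float waveform and its sample rate works:

```python
import librosa
import soundfile as sf
from utils.slicer2 import Slicer

waveform, sr = librosa.load('input.wav', sr=None, mono=False)  # hypothetical input file
slicer = Slicer(sr=sr, threshold=-40.0, min_length=5000, min_interval=300,
                hop_size=20, max_sil_kept=5000)
for i, chunk in enumerate(slicer.slice(waveform)):
    # Each chunk carries its start offset (seconds) and the sliced waveform.
    print(f"chunk {i} starts at {chunk['offset']:.2f}s")
    data = chunk['waveform']
    sf.write(f'chunk_{i}.wav', data.T if data.ndim > 1 else data, sr)
```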
/webui.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from infer import wav2svp
3 |
4 | def inference(input, bpm, extract_pitch, extract_tension, extract_breathiness):
5 | model_path = "weights/model_steps_64000_simplified.ckpt"
6 | return wav2svp(input, model_path, bpm, extract_pitch, extract_tension, extract_breathiness)
7 |
8 | def webui():
9 | with gr.Blocks() as webui:
10 |         gr.Markdown('''wav2svp - Waveform to Synthesizer V Project''')
11 |         gr.Markdown("Upload an audio file and download the generated svp and MIDI files with the selected parameter data.")
12 | with gr.Row():
13 | with gr.Column():
14 | input = gr.File(label="Input Audio File", type="filepath")
15 | bpm = gr.Number(label='BPM Value', minimum=20, maximum=200, value=120, step=0.01, interactive=True)
16 | extract_pitch = gr.Checkbox(label="Extract Pitch Data", value=True)
17 | extract_tension = gr.Checkbox(label="Extract Tension Data (Experimental)", value=False)
18 | extract_breathiness = gr.Checkbox(label="Extract Breathiness Data (Experimental)", value=False)
19 | run = gr.Button(value="Generate svp File", variant="primary")
20 | with gr.Column():
21 | output_svp = gr.File(label="Output svp File", type="filepath", interactive=False)
22 | output_midi = gr.File(label="Output midi File", type="filepath", interactive=False)
23 | run.click(inference, [input, bpm, extract_pitch, extract_tension, extract_breathiness], [output_svp, output_midi])
24 | webui.launch(inbrowser=True)
25 |
26 | if __name__ == '__main__':
27 | webui()
--------------------------------------------------------------------------------
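The WebUI above is a thin wrapper: it forwards the form fields to `wav2svp` from `infer.py`. A headless sketch mirroring that wiring (the checkpoint path is the one hard-coded in `webui.py`; the input path is hypothetical):

```python
from infer import wav2svp

svp_path, midi_path = wav2svp(
    "input/example.wav",                          # audio file to convert (hypothetical path)
    "weights/model_steps_64000_simplified.ckpt",  # checkpoint path, as in webui.py
    120,                                          # project BPM
    True,                                         # extract pitch data
    False,                                        # extract tension data (experimental)
    False,                                        # extract breathiness data (experimental)
)
print(svp_path, midi_path)
```

The two return values are assumed from how `webui.py` maps the result onto its two Gradio file outputs (svp file and MIDI file).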
/weights/config.yaml:
--------------------------------------------------------------------------------
1 | accumulate_grad_batches: 1
2 | audio_sample_rate: 44100
3 | binarization_args:
4 | num_workers: 0
5 | shuffle: true
6 | binarizer_cls: preprocessing.MIDIExtractionBinarizer
7 | binary_data_dir: data/some_ds_fixmel_spk3_aug8/binary
8 | clip_grad_norm: 1
9 | dataloader_prefetch_factor: 2
10 | ddp_backend: nccl
11 | ds_workers: 4
12 | finetune_ckpt_path: null
13 | finetune_enabled: false
14 | finetune_ignored_params: []
15 | finetune_strict_shapes: true
16 | fmax: 8000
17 | fmin: 40
18 | freezing_enabled: false
19 | frozen_params: []
20 | hop_size: 512
21 | log_interval: 100
22 | lr_scheduler_args:
23 | min_lr: 1.0e-05
24 | scheduler_cls: lr_scheduler.scheduler.WarmupLR
25 | warmup_steps: 5000
26 | max_batch_frames: 80000
27 | max_batch_size: 8
28 | max_updates: 10000000
29 | max_val_batch_frames: 10000
30 | max_val_batch_size: 1
31 | midi_extractor_args:
32 | attention_drop: 0.1
33 | attention_heads: 8
34 | attention_heads_dim: 64
35 | conv_drop: 0.1
36 | dim: 512
37 | ffn_latent_drop: 0.1
38 | ffn_out_drop: 0.1
39 | kernel_size: 31
40 | lay: 8
41 | use_lay_skip: true
42 | midi_max: 128
43 | midi_min: 0
44 | midi_num_bins: 256
45 | midi_prob_deviation: 0.5
46 | midi_shift_proportion: 0.0
47 | midi_shift_range:
48 | - -6
49 | - 6
50 | model_cls: modules.model.Gmidi_conform.midi_conforms
51 | num_ckpt_keep: 5
52 | num_sanity_val_steps: 1
53 | num_valid_plots: 300
54 | optimizer_args:
55 | beta1: 0.9
56 | beta2: 0.98
57 | lr: 0.0001
58 | optimizer_cls: torch.optim.AdamW
59 | weight_decay: 0
60 | pe: rmvpe
61 | pe_ckpt: pretrained/rmvpe/model.pt
62 | permanent_ckpt_interval: 40000
63 | permanent_ckpt_start: 200000
64 | pl_trainer_accelerator: auto
65 | pl_trainer_devices: auto
66 | pl_trainer_num_nodes: 1
67 | pl_trainer_precision: 32-true
68 | pl_trainer_strategy: auto
69 | raw_data_dir: []
70 | rest_threshold: 0.1
71 | sampler_frame_count_grid: 6
72 | seed: 114514
73 | sort_by_len: true
74 | task_cls: training.MIDIExtractionTask
75 | test_prefixes: null
76 | train_set_name: train
77 | units_dim: 80
78 | units_encoder: mel
79 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt
80 | use_buond_loss: true
81 | use_midi_loss: true
82 | val_check_interval: 4000
83 | valid_set_name: valid
84 | win_size: 2048
85 |
--------------------------------------------------------------------------------
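For reference, the MIDI binning entries above determine the decoding grid used by `decode_gaussian_blurred_probs` in `utils/infer_utils.py`; the numbers below are read from this config, and the mapping is an interpretation of the two files rather than anything documented:

```python
# Worked numbers linking config.yaml to decode_gaussian_blurred_probs.
midi_min, midi_max = 0, 128
midi_num_bins = 256
midi_prob_deviation = 0.5

interval = (midi_max - midi_min) / (midi_num_bins - 1)  # ~0.502 semitones per probability bin
width = int(3 * midi_prob_deviation / interval)         # 2 bins kept on each side of the argmax (the "3 * sigma" window)
print(interval, width)  # 0.5019..., 2
```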