├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── build_svp.py ├── infer.py ├── inference ├── __init__.py ├── base_infer.py ├── me_infer.py └── me_quant_infer.py ├── modules ├── __init__.py ├── attention │ ├── __init__.py │ └── base_attention.py ├── conform │ ├── Gconform.py │ └── __init__.py ├── conv │ └── base_conv.py ├── model │ ├── Gmidi_conform.py │ └── __init__.py └── rmvpe │ ├── __init__.py │ ├── constants.py │ ├── deepunet.py │ ├── inference.py │ ├── model.py │ ├── seq.py │ ├── spec.py │ └── utils.py ├── requirements.txt ├── template.json ├── utils ├── __init__.py ├── infer_utils.py ├── pitch_utils.py └── slicer2.py ├── webui.py └── weights └── config.yaml /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ckpt 2 | *.pt 3 | __pycache__ 4 | input/ 5 | results/ 6 | workenv/ 7 | temp.txt 8 | *.svp 9 | *.bat 10 | *.7z -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Sucial 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # wav2svp: Waveform to Synthesizer V Project 4 | 5 |
6 | 7 | ### Description 8 | 9 | wav2svp converts a waveform into a Synthesizer V Project (SVP) file. It is based on [SOME](https://github.com/openvpi/SOME) and [RMVPE](https://github.com/Dream-High/RMVPE). In addition to extracting MIDI notes automatically, it can extract **pitch data**, tension data (experimental), and breathiness data (experimental) in the same run. It cannot extract lyrics at present. 10 | 11 | ### Usage 12 | 13 | Download the **one-click startup package** from [releases](https://github.com/SUC-DriverOld/wav2svp/releases), unzip it, and double-click `go-webui.bat` to start the WebUI. 14 | 15 | ### Run from Code 16 | 17 | 1. Clone this repository and install the dependencies. Python 3.10 is recommended. 18 | 19 | ```shell 20 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | 2. Download the pre-trained models: 25 | 26 | - [0119_continuous128_5spk](https://github.com/openvpi/SOME/releases/download/v1.0.0-baseline/0119_continuous128_5spk.zip): unzip it to `weights`. 27 | - [rmvpe](https://github.com/yxlllc/RMVPE/releases/download/230917/rmvpe.zip): unzip it to `weights` and rename the extracted model to `rmvpe.pt`. 28 | - Arrange the `weights` folder as follows: 29 | 30 | ```shell 31 | weights 32 | ├-config.yaml 33 | ├-model_steps_64000_simplified.ckpt 34 | └-rmvpe.pt 35 | ``` 36 | 37 | 3. Run the following command to start the WebUI: 38 | 39 | ```shell 40 | python webui.py 41 | ``` 42 | 43 | 4. Download the inference results from the WebUI or from the `results` folder.
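`infer.py` asserts that `weights/rmvpe.pt` and the SOME checkpoint exist before it runs. A pre-flight check of the `weights` layout listed above can be sketched as follows; `missing_weights` and `EXPECTED_WEIGHTS` are hypothetical names, not part of the repository, and the file names are copied from the expected layout.

```python
from pathlib import Path

# File names copied from the expected `weights` layout in this README.
EXPECTED_WEIGHTS = (
    "config.yaml",
    "model_steps_64000_simplified.ckpt",
    "rmvpe.pt",
)


def missing_weights(weights_dir="weights"):
    """Return the expected weight files that are absent from weights_dir (hypothetical helper)."""
    root = Path(weights_dir)
    return [name for name in EXPECTED_WEIGHTS if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_weights()
    if missing:
        print("Missing from weights/:", ", ".join(missing))
    else:
        print("weights/ contains all expected files")
```

An empty list means the folder matches the layout that `infer.py` and `webui.py` expect by default.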
44 | 45 | ### Command Line Usage 46 | 47 | Use `infer.py`: 48 | 49 | ```shell 50 | usage: infer.py [-h] [--model_path MODEL_PATH] [--tempo TEMPO] [--extract_pitch] [--extract_tension] [--extract_breathiness] audio_path 51 | 52 | Inference for wav2svp 53 | 54 | positional arguments: 55 | audio_path Path to the input audio file 56 | 57 | options: 58 | -h, --help show this help message and exit 59 | --model_path MODEL_PATH 60 | Path to the model file, default: weights/model_steps_64000_simplified.ckpt 61 | --tempo TEMPO Tempo value for the midi file, default: 120 62 | --extract_pitch Whether to extract pitch from the audio file, default: False 63 | --extract_tension Whether to extract tension from the audio file, default: False 64 | --extract_breathiness 65 | Whether to extract breathiness from the audio file, default: False 66 | ``` 67 | 68 | The generated `.svp` and `.mid` files are written to the `results` folder. 69 | 70 | ### Thanks 71 | 72 | - [openvpi/SOME](https://github.com/openvpi/SOME) 73 | - [Dream-High/RMVPE](https://github.com/Dream-High/RMVPE) 74 | - [yxlllc/RMVPE](https://github.com/yxlllc/RMVPE) -------------------------------------------------------------------------------- /build_svp.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | import os 4 | import math 5 | 6 | 7 | per_dur = 705600000 # Synthesizer V time units (blicks) per beat 8 | time_per_frame = 0.02 # seconds per feature frame: 160-sample hop at 16 kHz, downsampled by 2 in infer.py 9 | 10 | 11 | def build_svp(template, midis, arguments, tempo, basename, extract_pitch, extract_tension, extract_breathiness) -> str: 12 | notes = [] # note entries written to the project 13 | datas = [] # note time ranges and pitches, used later by build_pitch 14 | new_uuid = str(uuid.uuid4()).lower() 15 | 16 | per_time = 60 / tempo # seconds per beat 17 | template["time"]["tempo"] = [{"position": 0, "bpm": tempo}] 18 | 19 | index = 0 20 | for midi in midis: 21 | offset = int(arguments[index]["offset"] / per_time * per_dur) # start of this chunk in SV units 22 | 23 | dur = midi["note_dur"] # note durations in seconds 24 |
pitch = midi["note_midi"] # note pitches (MIDI note numbers) 25 | rest = midi["note_rest"] # whether each note is a rest 26 | midi_duration = 0 # accumulated length of the notes in this chunk (SV units) 27 | 28 | for i in range(len(pitch)): 29 | current_duration = dur[i] / per_time * per_dur # length of the current note in SV units 30 | onset = midi_duration + offset # start of the current note in SV units 31 | midi_duration += int(current_duration) 32 | if rest[i]: # skip rests 33 | continue 34 | current_pitch = round(pitch[i]) 35 | 36 | note = { 37 | "musicalType": "singing", 38 | "onset": int(onset), 39 | "duration": int(current_duration), 40 | "lyrics": "la", 41 | "phonemes": "", 42 | "accent": "", 43 | "pitch": int(current_pitch), 44 | "detune": 0, 45 | "instantMode": False, 46 | "attributes": {"evenSyllableDuration": True}, 47 | "systemAttributes": {"evenSyllableDuration": True}, 48 | "pitchTakes": {"activeTakeId": 0,"takes": [{"id": 0,"expr": 0,"liked": False}]}, 49 | "timbreTakes": {"activeTakeId": 0,"takes": [{"id": 0,"expr": 0,"liked": False}]} 50 | } 51 | notes.append(note) 52 | 53 | data = {"start": int(onset),"finish": int(current_duration + onset),"pitch": int(current_pitch)} 54 | datas.append(data) 55 | index += 1 56 | 57 | template["tracks"][0]["mainGroup"]["notes"] = notes 58 | template["tracks"][0]["mainGroup"]["uuid"] = new_uuid 59 | template["tracks"][0]["mainRef"]["groupID"] = new_uuid 60 | 61 | pitch, tension, breathiness = [], [], [] 62 | 63 | if extract_pitch: 64 | pitch = build_pitch(datas, arguments, tempo) 65 | template["tracks"][0]["mainGroup"]["parameters"]["pitchDelta"]["points"] = pitch 66 | 67 | if extract_tension: 68 | tension = build_arguments(arguments, "tension", tempo) 69 | template["tracks"][0]["mainGroup"]["parameters"]["tension"]["points"] = tension 70 | 71 | if extract_breathiness: 72 | breathiness = build_arguments(arguments, "breathiness", tempo) 73 | template["tracks"][0]["mainGroup"]["parameters"]["breathiness"]["points"] = breathiness 74 | 75 | if extract_pitch and not extract_tension and not extract_breathiness: 76 | template["tracks"][0]["mainRef"]["voice"]["dF0Vbr"] = 0 77 |
78 | file_path = os.path.join("results", f"{basename}.svp") 79 | with open(file_path, "w", encoding="utf-8") as f: 80 | json.dump(template, f) 81 | 82 | return file_path 83 | 84 | 85 | def build_pitch(datas: list, arguments: list, tempo: int) -> list: 86 | pitch = [] # pitchDelta points to save, flattened as [onset, cents, onset, cents, ...] 87 | per_time = 60 / tempo # seconds per beat 88 | 89 | for f0 in arguments: 90 | offset = f0["offset"] 91 | f0_data = f0["f0"] 92 | for i in range(len(f0_data)): 93 | pitch_onset, pitch_cents = None, None 94 | f0_value = f0_data[i] # f0 of the current frame (Hz) 95 | if f0_value == 0.0: 96 | continue 97 | onset_time = offset + i * time_per_frame # start of the current frame (seconds) 98 | onset = (onset_time / per_time) * per_dur # start of the current frame in SV units 99 | pitch_onset = int(onset) 100 | for data in datas: 101 | if data["start"] <= onset and onset < data["finish"]: # the frame falls inside this note 102 | pitch_cents = calculate_cents_difference(data["pitch"], f0_value) 103 | break 104 | if pitch_onset is None or pitch_cents is None: 105 | continue 106 | pitch.append(pitch_onset) 107 | pitch.append(pitch_cents) 108 | return pitch 109 | 110 | 111 | def build_arguments(arguments: list, data: str, tempo: int, rate=1.0, argu_offset=0.0) -> list: 112 | args = [] 113 | per_time = 60 / tempo # seconds per beat 114 | 115 | for arg in arguments: 116 | offset = arg["offset"] 117 | datas = arg[data] 118 | for i in range(len(datas)): 119 | onset_time = offset + i * time_per_frame # start of the current frame (seconds) 120 | onset = (onset_time / per_time) * per_dur # start of the current frame in SV units 121 | args.append(onset) 122 | args.append(datas[i] * rate + argu_offset) 123 | return args 124 | 125 | 126 | def calculate_cents_difference(midi_note, f0): 127 | def midi_to_freq(midi_note): 128 | A4 = 440.0 129 | return A4 * (2 ** ((midi_note - 69) / 12)) 130 | 131 | def cents_difference(f0, midi_note): 132 | midi_freq = midi_to_freq(midi_note) 133 | return 1200 * math.log2(f0 / midi_freq) 134 | 135 | cents_diff = cents_difference(f0, midi_note) 136 | return round(cents_diff, 5) 137 |
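The conversions in `build_svp.py` reduce to two formulas: a time in seconds becomes `seconds / (60 / tempo) * per_dur` Synthesizer V units, and a frame's pitch deviation is `1200 * log2(f0 / f_note)` cents from the equal-tempered note frequency (A4 = 440 Hz at MIDI 69). The sketch below restates them as standalone functions and builds a flattened `[onset, cents, ...]` list in the same shape `build_pitch` produces. It simplifies by assuming every voiced frame belongs to one given note; the function names are illustrative, not part of the repository.

```python
import math

PER_DUR = 705_600_000    # SV time units per beat, same value as per_dur in build_svp.py
TIME_PER_FRAME = 0.02    # seconds per feature frame, same value as time_per_frame


def seconds_to_sv(seconds, tempo):
    """Convert seconds to Synthesizer V time units at a constant tempo (BPM)."""
    seconds_per_beat = 60 / tempo
    return int(seconds / seconds_per_beat * PER_DUR)


def cents_from_note(midi_note, f0):
    """Signed distance in cents from the equal-tempered pitch of midi_note to f0 (Hz)."""
    note_freq = 440.0 * 2 ** ((midi_note - 69) / 12)
    return 1200 * math.log2(f0 / note_freq)


def pitch_delta_points(offset, f0_frames, midi_note, tempo):
    """Flattened [onset, cents, ...] points for one note; unvoiced (0 Hz) frames are skipped."""
    points = []
    for i, f0 in enumerate(f0_frames):
        if f0 == 0.0:
            continue
        points.append(seconds_to_sv(offset + i * TIME_PER_FRAME, tempo))
        points.append(round(cents_from_note(midi_note, f0), 5))
    return points


print(seconds_to_sv(0.5, 120))                          # 705600000 (one beat at 120 BPM)
print(cents_from_note(69, 880.0))                       # 1200.0 (one octave above A4)
print(pitch_delta_points(0.5, [440.0, 0.0], 69, 120))   # [705600000, 0.0]
```

Like `build_pitch`, the onset is truncated with `int()` rather than rounded, so values can sit one unit below the exact product when floating-point error lands just under an integer.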
-------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import os 3 | import yaml 4 | import json 5 | import inference 6 | import importlib 7 | import numpy as np 8 | from tqdm import tqdm 9 | from scipy.ndimage import gaussian_filter 10 | 11 | from modules.rmvpe.inference import RMVPE 12 | from utils.slicer2 import Slicer 13 | from utils.infer_utils import build_midi_file 14 | from build_svp import build_svp 15 | 16 | 17 | def load_config(config_path: str) -> dict: 18 | if config_path.endswith('.yaml'): 19 | with open(config_path, 'r', encoding='utf8') as f: 20 | config = yaml.safe_load(f) 21 | elif config_path.endswith('.json'): 22 | with open(config_path, 'r', encoding='utf8') as f: 23 | config = json.load(f) 24 | else: 25 | raise ValueError(f'Unsupported config file format: {config_path}') 26 | return config 27 | 28 | 29 | config = load_config('weights/config.yaml') 30 | sr = config['audio_sample_rate'] 31 | 32 | 33 | def audio_slicer(audio_path: str) -> list: 34 | """ 35 | Returns: 36 | list of dict: [{ 37 | "offset": np.float64, 38 | "waveform": array of float, dtype=float32, 39 | }, ...] 40 | """ 41 | waveform, _ = librosa.load(audio_path, sr=sr, mono=True) 42 | slicer = Slicer(sr=sr, max_sil_kept=1000) 43 | chunks = slicer.slice(waveform) 44 | for c in chunks: 45 | c['waveform_16k'] = librosa.resample(y=c['waveform'], orig_sr=sr, target_sr=16000) 46 | return chunks 47 | 48 | 49 | def get_midi(chunks: list, model_path: str) -> list: 50 | """ 51 | Args: 52 | chunks (list): results from audio_slicer 53 | 54 | Returns: 55 | list of dict: [{ 56 | "note_midi": array of float, dtype=float32, 57 | "note_dur": array of float, 58 | "note_rest": array of bool, 59 | }, ...] 
60 | """ 61 | infer_cls = inference.task_inference_mapping[config['task_cls']] 62 | pkg = ".".join(infer_cls.split(".")[:-1]) 63 | cls_name = infer_cls.split(".")[-1] 64 | infer_cls = getattr(importlib.import_module(pkg), cls_name) 65 | assert issubclass(infer_cls, inference.BaseInference), \ 66 | f'Inference class {infer_cls} is not a subclass of {inference.BaseInference}.' 67 | infer_ins = infer_cls(config=config, model_path=model_path) 68 | midis = infer_ins.infer([c['waveform'] for c in chunks]) 69 | return midis 70 | 71 | 72 | def save_midi(midis: list, tempo: int, chunks: list, midi_path: str) -> None: 73 | midi_file = build_midi_file([c['offset'] for c in chunks], midis, tempo=tempo) 74 | midi_file.save(midi_path) 75 | 76 | 77 | def get_f0(chunks: list): 78 | rmvpe = RMVPE(model_path='weights/rmvpe.pt') # hop_size=160 79 | for chunk in tqdm(chunks, desc='Extracting F0'): 80 | chunk['f0'] = rmvpe.infer_from_audio(chunk['waveform_16k'], sample_rate=16000)[::2].astype(float) 81 | return chunks 82 | 83 | 84 | def get_energy_librosa(waveform, hop_size, win_size): 85 | energy = librosa.feature.rms(y=waveform, frame_length=win_size, hop_length=hop_size)[0] 86 | return energy 87 | 88 | 89 | def get_breathiness(chunks, hop_size, win_size, sigma=1.0): 90 | for chunk in tqdm(chunks, desc='Extracting Breathiness'): 91 | waveform = chunk['waveform_16k'] 92 | waveform_ap = librosa.effects.percussive(waveform) 93 | breathiness = get_energy_librosa(waveform_ap, hop_size, win_size) 94 | breathiness = (2 / max(abs(np.max(breathiness)), abs(np.min(breathiness)))) * breathiness 95 | breathiness = np.tanh(breathiness - np.mean(breathiness)) 96 | breathiness_smoothed = gaussian_filter(breathiness, sigma=sigma) 97 | chunk['breathiness'] = breathiness_smoothed[::2].astype(float) 98 | return chunks 99 | 100 | 101 | def get_tension(chunks, hop_size, win_size, sigma=1.0): 102 | for chunk in tqdm(chunks, desc='Extracting Tension'): 103 | waveform = chunk['waveform_16k'] 104 | 
waveform_h = librosa.effects.harmonic(waveform) 105 | waveform_base_h = librosa.effects.harmonic(waveform, power=0.5) 106 | energy_base_h = get_energy_librosa(waveform_base_h, hop_size, win_size) 107 | energy_h = get_energy_librosa(waveform_h, hop_size, win_size) 108 | tension = np.sqrt(np.clip(energy_h ** 2 - energy_base_h ** 2, 0, None)) / (energy_h + 1e-5) 109 | tension = (2 / max(abs(np.max(tension)), abs(np.min(tension)))) * tension 110 | tension = np.tanh(tension - np.mean(tension)) 111 | tension_smoothed = gaussian_filter(tension, sigma=sigma) 112 | chunk['tension'] = tension_smoothed[::2].astype(float) 113 | return chunks 114 | 115 | 116 | def get_arguments(chunks, hop_size, win_size, extract_pitch=False, extract_tension=False, extract_breathiness=False): 117 | if extract_pitch: 118 | chunks = get_f0(chunks) 119 | if extract_tension: 120 | chunks = get_tension(chunks, hop_size, win_size) 121 | if extract_breathiness: 122 | chunks = get_breathiness(chunks, hop_size, win_size) 123 | return chunks 124 | 125 | 126 | def wav2svp(audio_path, model_path, tempo=120, extract_pitch=False, extract_tension=False, extract_breathiness=False): 127 | os.makedirs('results', exist_ok=True) 128 | basename = os.path.splitext(os.path.basename(audio_path))[0] 129 | 130 | chunks = audio_slicer(audio_path) 131 | midis = get_midi(chunks, model_path) 132 | arguments = get_arguments( 133 | chunks, hop_size=160, win_size=1024, 134 | extract_pitch=extract_pitch, extract_tension=extract_tension, extract_breathiness=extract_breathiness 135 | ) 136 | 137 | template = load_config('template.json') 138 | 139 | print("building svp file") 140 | svp_path = build_svp(template, midis, arguments, tempo, basename, extract_pitch, extract_tension, extract_breathiness) 141 | 142 | print("building midi file") 143 | midi_path = os.path.join('results', f'{basename}.mid') 144 | save_midi(midis, tempo, chunks, midi_path) 145 | 146 | print("Success") 147 | return svp_path, midi_path 148 | 149 | 150 | if __name__
== '__main__': 151 | import argparse 152 | parser = argparse.ArgumentParser(description='Inference for wav2svp') 153 | parser.add_argument('audio_path', type=str, help='Path to the input audio file') 154 | parser.add_argument('--model_path', type=str, default="weights/model_steps_64000_simplified.ckpt", help='Path to the model file, default: weights/model_steps_64000_simplified.ckpt') 155 | parser.add_argument('--tempo', type=float, default=120.0, help='Tempo value for the midi file, default: 120') 156 | parser.add_argument('--extract_pitch', action='store_true', help='Whether to extract pitch from the audio file, default: False') 157 | parser.add_argument('--extract_tension', action='store_true', help='Whether to extract tension from the audio file, default: False') 158 | parser.add_argument('--extract_breathiness', action='store_true', help='Whether to extract breathiness from the audio file, default: False') 159 | args = parser.parse_args() 160 | 161 | assert os.path.isfile("weights/rmvpe.pt"), "RMVPE model not found" 162 | assert os.path.isfile(args.model_path), "SOME Model not found" 163 | 164 | wav2svp(args.audio_path, args.model_path, args.tempo, args.extract_pitch, args.extract_tension, args.extract_breathiness) -------------------------------------------------------------------------------- /inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_infer import BaseInference 2 | from .me_infer import MIDIExtractionInference 3 | from .me_quant_infer import QuantizedMIDIExtractionInference 4 | 5 | task_inference_mapping = { 6 | 'training.MIDIExtractionTask': 'inference.MIDIExtractionInference', 7 | 'training.QuantizedMIDIExtractionTask': 'inference.QuantizedMIDIExtractionInference', 8 | } 9 | -------------------------------------------------------------------------------- /inference/base_infer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 
| from collections import OrderedDict 3 | from typing import Dict, List 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | from torch import nn 9 | 10 | from utils import build_object_from_class_name 11 | 12 | 13 | class BaseInference: 14 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 15 | if device is None: 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | self.config = config 18 | self.model_path = model_path 19 | self.device = device 20 | self.timestep = self.config['hop_size'] / self.config['audio_sample_rate'] 21 | self.model: torch.nn.Module = self.build_model() 22 | 23 | def build_model(self) -> nn.Module: 24 | model: nn.Module = build_object_from_class_name( 25 | self.config['model_cls'], nn.Module, config=self.config 26 | ).eval().to(self.device) 27 | state_dict = torch.load(self.model_path, map_location=self.device, weights_only=True)['state_dict'] 28 | prefix_in_ckpt = 'model' 29 | state_dict = OrderedDict({ 30 | k[len(prefix_in_ckpt) + 1:]: v 31 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') 32 | }) 33 | model.load_state_dict(state_dict, strict=True) 34 | print(f'load \'{prefix_in_ckpt}\' from \'{self.model_path}\'.') 35 | return model 36 | 37 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]: 38 | raise NotImplementedError() 39 | 40 | def forward_model(self, sample: Dict[str, torch.Tensor]): 41 | raise NotImplementedError() 42 | 43 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]: 44 | raise NotImplementedError() 45 | 46 | def infer(self, waveforms: List[np.ndarray]) -> List[Dict[str, np.ndarray]]: 47 | results = [] 48 | for w in tqdm.tqdm(waveforms, desc='SOME Inference'): 49 | model_in = self.preprocess(w) 50 | model_out = self.forward_model(model_in) 51 | res = self.postprocess(model_out) 52 | results.append(res) 53 | return results 54 | 
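`BaseInference.build_model` keeps only checkpoint keys that start with `model.`, strips that prefix, and then loads the result with `strict=True`. `infer` runs each sliced chunk through `preprocess`, `forward_model`, and `postprocess`, which subclasses such as `MIDIExtractionInference` override. The sketch below mirrors both ideas in plain Python so they can be tried without torch or a checkpoint. `strip_prefix`, `ToyInference`, and `PeakInference` are hypothetical names, and the toy "model" only reports each chunk's peak amplitude.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="model"):
    """Keep keys under `prefix.` and drop that prefix, as build_model does for 'model'."""
    cut = len(prefix) + 1  # length of "model." including the dot
    return OrderedDict(
        (key[cut:], value)
        for key, value in state_dict.items()
        if key.startswith(f"{prefix}.")
    )


class ToyInference:
    """Hypothetical stand-in mirroring BaseInference.infer's control flow (no torch)."""

    def preprocess(self, waveform):
        raise NotImplementedError()

    def forward_model(self, sample):
        raise NotImplementedError()

    def postprocess(self, results):
        raise NotImplementedError()

    def infer(self, waveforms):
        # Same per-chunk pipeline as BaseInference.infer, minus the tqdm progress bar.
        return [self.postprocess(self.forward_model(self.preprocess(w))) for w in waveforms]


class PeakInference(ToyInference):
    """Toy subclass: the 'model' returns the peak absolute amplitude of each chunk."""

    def preprocess(self, waveform):
        return {"x": [abs(v) for v in waveform]}

    def forward_model(self, sample):
        return {"peak": max(sample["x"])}

    def postprocess(self, results):
        return {"peak": round(results["peak"], 3)}


state = {"model.inln.weight": 1, "model.outln.bias": 2, "optimizer.step": 3}
print(dict(strip_prefix(state)))                    # {'inln.weight': 1, 'outln.bias': 2}
print(PeakInference().infer([[0.1, -0.5], [0.2]]))  # [{'peak': 0.5}, {'peak': 0.2}]
```

Keys outside the prefix (for example optimizer state in a training checkpoint) are dropped, which is why the real loader can use `strict=True` against the bare model.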
-------------------------------------------------------------------------------- /inference/me_infer.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import modules.rmvpe 8 | from utils.infer_utils import decode_bounds_to_alignment, decode_gaussian_blurred_probs, decode_note_sequence 9 | from .base_infer import BaseInference 10 | 11 | 12 | class MIDIExtractionInference(BaseInference): 13 | def __init__(self, config: dict, model_path: pathlib.Path, device=None): 14 | super().__init__(config, model_path, device=device) 15 | self.mel_spec = modules.rmvpe.MelSpectrogram( 16 | n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'], 17 | win_length=self.config['win_size'], hop_length=self.config['hop_size'], 18 | mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax'] 19 | ).to(self.device) 20 | self.rmvpe = None 21 | self.midi_min = self.config['midi_min'] 22 | self.midi_max = self.config['midi_max'] 23 | self.midi_deviation = self.config['midi_prob_deviation'] 24 | self.rest_threshold = self.config['rest_threshold'] 25 | 26 | def preprocess(self, waveform: np.ndarray) -> Dict[str, torch.Tensor]: 27 | wav_tensor = torch.from_numpy(waveform).unsqueeze(0).to(self.device) 28 | units = self.mel_spec(wav_tensor).transpose(1, 2) 29 | length = units.shape[1] 30 | 31 | pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device) 32 | return { 33 | 'units': units, 34 | 'pitch': pitch, 35 | 'masks': torch.ones_like(pitch, dtype=torch.bool) 36 | } 37 | 38 | @torch.no_grad() 39 | def forward_model(self, sample: Dict[str, torch.Tensor]): 40 | 41 | probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'],sig=True) 42 | 43 | return { 44 | 'probs': probs, 45 | 'bounds': bounds, 46 | 'masks': sample['masks'], 47 | } 48 | 49 | def postprocess(self, results: Dict[str, 
torch.Tensor]) -> List[Dict[str, np.ndarray]]: 50 | probs = results['probs'] 51 | bounds = results['bounds'] 52 | masks = results['masks'] 53 | probs *= masks[..., None] 54 | bounds *= masks 55 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 56 | midi_pred, rest_pred = decode_gaussian_blurred_probs( 57 | probs, vmin=self.midi_min, vmax=self.midi_max, 58 | deviation=self.midi_deviation, threshold=self.rest_threshold 59 | ) 60 | note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 61 | unit2note_pred, midi_pred, ~rest_pred & masks 62 | ) 63 | note_rest_pred = ~note_mask_pred 64 | return { 65 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(), 66 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep, 67 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy() 68 | } 69 | -------------------------------------------------------------------------------- /inference/me_quant_infer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence 7 | from .me_infer import MIDIExtractionInference 8 | 9 | 10 | class QuantizedMIDIExtractionInference(MIDIExtractionInference): 11 | @torch.no_grad() 12 | def forward_model(self, sample: Dict[str, torch.Tensor]): 13 | probs, bounds = self.model(x=sample['units'], f0=sample['pitch'], mask=sample['masks'], softmax=True) 14 | 15 | return { 16 | 'probs': probs, 17 | 'bounds': bounds, 18 | 'masks': sample['masks'], 19 | } 20 | 21 | def postprocess(self, results: Dict[str, torch.Tensor]) -> List[Dict[str, np.ndarray]]: 22 | probs = results['probs'] 23 | bounds = results['bounds'] 24 | masks = results['masks'] 25 | probs *= masks[..., None] 26 | bounds *= masks 27 | unit2note_pred = decode_bounds_to_alignment(bounds) * masks 28 | midi_pred = probs.argmax(dim=-1) 29 | rest_pred = midi_pred == 128 30 | 
note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence( 31 | unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks 32 | ) 33 | note_rest_pred = ~note_mask_pred 34 | return { 35 | 'note_midi': note_midi_pred.squeeze(0).cpu().numpy(), 36 | 'note_dur': note_dur_pred.squeeze(0).cpu().numpy() * self.timestep, 37 | 'note_rest': note_rest_pred.squeeze(0).cpu().numpy() 38 | } 39 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/__init__.py -------------------------------------------------------------------------------- /modules/attention/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/attention/__init__.py -------------------------------------------------------------------------------- /modules/attention/base_attention.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | 8 | class Attention(nn.Module): 9 | def __init__(self, dim, heads=4, dim_head=32, conditiondim=None): 10 | super().__init__() 11 | if conditiondim is None: 12 | conditiondim = dim 13 | 14 | self.scale = dim_head ** -0.5 15 | self.heads = heads 16 | hidden_dim = dim_head * heads 17 | self.to_q = nn.Linear(dim, hidden_dim, bias=False) 18 | self.to_kv = nn.Linear(conditiondim, hidden_dim * 2, bias=False) 19 | 20 | self.to_out = nn.Sequential(nn.Linear(hidden_dim, dim, ), 21 | ) 22 | 23 | def forward(self, q, kv=None, mask=None): 24 | # b, c, h, w = x.shape 25 | if kv is None: 26 | kv = q 27 | # q, kv = map( 28 | # lambda t: rearrange(t, "b c t 
-> b t c", ), (q, kv) 29 | # ) 30 | 31 | q = self.to_q(q) 32 | k, v = self.to_kv(kv).chunk(2, dim=2) 33 | 34 | q, k, v = map( 35 | lambda t: rearrange(t, "b t (h c) -> b h t c", h=self.heads), (q, k, v) 36 | ) 37 | 38 | if mask is not None: 39 | mask = mask.unsqueeze(1).unsqueeze(1) 40 | 41 | with torch.backends.cuda.sdp_kernel(enable_math=False 42 | ): 43 | out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) 44 | 45 | out = rearrange(out, "b h t c -> b t (h c) ", h=self.heads, ) 46 | return self.to_out(out) 47 | -------------------------------------------------------------------------------- /modules/conform/Gconform.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from einops import rearrange 7 | 8 | from modules.attention.base_attention import Attention 9 | from modules.conv.base_conv import conform_conv 10 | class GLU(nn.Module): 11 | def __init__(self, dim): 12 | super().__init__() 13 | self.dim = dim 14 | 15 | def forward(self, x): 16 | out, gate = x.chunk(2, dim=self.dim) 17 | 18 | return out * gate.sigmoid() 19 | 20 | class conform_ffn(nn.Module): 21 | def __init__(self, dim, DropoutL1: float = 0.1, DropoutL2: float = 0.1): 22 | super().__init__() 23 | self.ln1 = nn.Linear(dim, dim * 4) 24 | self.ln2 = nn.Linear(dim * 4, dim) 25 | self.drop1 = nn.Dropout(DropoutL1) if DropoutL1 > 0. else nn.Identity() 26 | self.drop2 = nn.Dropout(DropoutL2) if DropoutL2 > 0. 
else nn.Identity() 27 | self.act = nn.SiLU() 28 | 29 | def forward(self, x): 30 | x = self.ln1(x) 31 | x = self.act(x) 32 | x = self.drop1(x) 33 | x = self.ln2(x) 34 | return self.drop2(x) 35 | 36 | 37 | class conform_blocke(nn.Module): 38 | def __init__(self, dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1, 39 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 40 | attention_heads_dim: int = 64): 41 | super().__init__() 42 | self.ffn1 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop) 43 | self.ffn2 = conform_ffn(dim, ffn_latent_drop, ffn_out_drop) 44 | self.att = Attention(dim, heads=attention_heads, dim_head=attention_heads_dim) 45 | self.attdrop = nn.Dropout(attention_drop) if attention_drop > 0. else nn.Identity() 46 | self.conv = conform_conv(dim, kernel_size=kernel_size, 47 | 48 | DropoutL=conv_drop, ) 49 | self.norm1 = nn.LayerNorm(dim) 50 | self.norm2 = nn.LayerNorm(dim) 51 | self.norm3 = nn.LayerNorm(dim) 52 | self.norm4 = nn.LayerNorm(dim) 53 | self.norm5 = nn.LayerNorm(dim) 54 | 55 | 56 | def forward(self, x, mask=None,): 57 | x = self.ffn1(self.norm1(x)) * 0.5 + x 58 | 59 | 60 | x = self.attdrop(self.att(self.norm2(x), mask=mask)) + x 61 | x = self.conv(self.norm3(x)) + x 62 | x = self.ffn2(self.norm4(x)) * 0.5 + x 63 | return self.norm5(x) 64 | 65 | # return x 66 | 67 | 68 | class Gcf(nn.Module): 69 | def __init__(self,dim: int, kernel_size: int = 31, conv_drop: float = 0.1, ffn_latent_drop: float = 0.1, 70 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 71 | attention_heads_dim: int = 64): 72 | super().__init__() 73 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 74 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 75 | attention_heads_dim=attention_heads_dim) 76 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, 
conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 77 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 78 | attention_heads_dim=attention_heads_dim) 79 | self.glu1=nn.Sequential(nn.Linear(dim, dim*2),GLU(2) ) 80 | self.glu2 = nn.Sequential(nn.Linear(dim, dim * 2), GLU(2)) 81 | 82 | def forward(self, midi,bound): 83 | midi=self.att1(midi) 84 | bound=self.att2(bound) 85 | midis=self.glu1(midi) 86 | bounds=self.glu2(bound) 87 | return midi+bounds,bound+midis 88 | 89 | 90 | 91 | 92 | class Gmidi_conform(nn.Module): 93 | def __init__(self, lay: int, dim: int, indim: int, outdim: int, use_lay_skip: bool, kernel_size: int = 31, 94 | conv_drop: float = 0.1, 95 | ffn_latent_drop: float = 0.1, 96 | ffn_out_drop: float = 0.1, attention_drop: float = 0.1, attention_heads: int = 4, 97 | attention_heads_dim: int = 64): 98 | super().__init__() 99 | 100 | self.inln = nn.Linear(indim, dim) 101 | self.inln1 = nn.Linear(indim, dim) 102 | self.outln = nn.Linear(dim, outdim) 103 | self.cutheard = nn.Linear(dim, 1) 104 | # self.cutheard = nn.Linear(dim, outdim) 105 | self.lay = lay 106 | self.use_lay_skip = use_lay_skip 107 | self.cf_lay = nn.ModuleList( 108 | [Gcf(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 109 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 110 | attention_heads_dim=attention_heads_dim) for _ in range(lay)]) 111 | self.att1=conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 112 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 113 | attention_heads_dim=attention_heads_dim) 114 | self.att2 = conform_blocke(dim=dim, kernel_size=kernel_size, conv_drop=conv_drop, ffn_latent_drop=ffn_latent_drop, 115 | ffn_out_drop=ffn_out_drop, attention_drop=attention_drop, attention_heads=attention_heads, 116 | attention_heads_dim=attention_heads_dim) 117 | 
118 | 119 | def forward(self, x, pitch, mask=None): 120 | 121 | # torch.masked_fill() 122 | x1=x.clone() 123 | 124 | x = self.inln(x ) 125 | x1=self.inln1(x1) 126 | if mask is not None: 127 | x = x.masked_fill(~mask.unsqueeze(-1), 0) 128 | for idx, i in enumerate(self.cf_lay): 129 | x,x1 = i(x,x1) 130 | 131 | if mask is not None: 132 | x = x.masked_fill(~mask.unsqueeze(-1), 0) 133 | x,x1=self.att1(x),self.att2(x1) 134 | 135 | cutprp = self.cutheard(x1) 136 | midiout = self.outln(x) 137 | cutprp = torch.sigmoid(cutprp) 138 | cutprp = torch.squeeze(cutprp, -1) 139 | 140 | return midiout, cutprp 141 | -------------------------------------------------------------------------------- /modules/conform/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/conform/__init__.py -------------------------------------------------------------------------------- /modules/conv/base_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | 6 | 7 | class GLU(nn.Module): 8 | def __init__(self, dim): 9 | super().__init__() 10 | self.dim = dim 11 | 12 | def forward(self, x): 13 | out, gate = x.chunk(2, dim=self.dim) 14 | 15 | return out * gate.sigmoid() 16 | 17 | 18 | class conform_conv(nn.Module): 19 | def __init__(self, channels: int, 20 | kernel_size: int = 31, 21 | 22 | DropoutL=0.1, 23 | 24 | bias: bool = True): 25 | super().__init__() 26 | self.act2 = nn.SiLU() 27 | self.act1 = GLU(1) 28 | 29 | self.pointwise_conv1 = nn.Conv1d( 30 | channels, 31 | 2 * channels, 32 | kernel_size=1, 33 | stride=1, 34 | padding=0, 35 | bias=bias) 36 | 37 | # self.lorder is used to distinguish if it's a causal convolution, 38 | # if self.lorder > 0: 39 | # it's a causal convolution, the input will be padded with 
40 | # `self.lorder` frames on the left in forward (causal conv impl). 41 | # else: it's a symmetrical convolution 42 | 43 | assert (kernel_size - 1) % 2 == 0 44 | padding = (kernel_size - 1) // 2 45 | 46 | self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, 47 | stride=1, 48 | padding=padding, 49 | groups=channels, 50 | bias=bias) 51 | 52 | 53 | self.norm = nn.BatchNorm1d(channels) 54 | 55 | 56 | self.pointwise_conv2 = nn.Conv1d(channels, 57 | channels, 58 | kernel_size=1, 59 | stride=1, 60 | padding=0, 61 | bias=bias) 62 | self.drop=nn.Dropout(DropoutL) if DropoutL>0. else nn.Identity() 63 | def forward(self,x): 64 | x=x.transpose(1,2) 65 | x=self.act1(self.pointwise_conv1(x)) 66 | x=self.depthwise_conv (x) 67 | x=self.norm(x) 68 | x=self.act2(x) 69 | x=self.pointwise_conv2(x) 70 | return self.drop(x).transpose(1,2) 71 | 72 | -------------------------------------------------------------------------------- /modules/model/Gmidi_conform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.conform.Gconform import Gmidi_conform 5 | 6 | 7 | 8 | class midi_loss(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.loss = nn.BCELoss() 12 | 13 | def forward(self, x, target): 14 | midiout, cutp = x 15 | midi_target, cutp_target = target 16 | 17 | cutploss = self.loss(cutp, cutp_target) 18 | midiloss = self.loss(midiout, midi_target) 19 | return midiloss, cutploss 20 | 21 | 22 | class midi_conforms(nn.Module): 23 | def __init__(self, config): 24 | super().__init__() 25 | 26 | cfg = config['midi_extractor_args'] 27 | cfg.update({'indim': config['units_dim'], 'outdim': config['midi_num_bins']}) 28 | self.model = Gmidi_conform(**cfg) 29 | 30 | def forward(self, x, f0, mask=None,softmax=False,sig=False): 31 | 32 | midi,bound=self.model(x, f0, mask) 33 | if sig: 34 | midi = torch.sigmoid(midi) 35 | 36 | if softmax: 37 | 
midi=F.softmax(midi,dim=2) 38 | 39 | 40 | return midi,bound 41 | 42 | def get_loss(self): 43 | return midi_loss() 44 | -------------------------------------------------------------------------------- /modules/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SUC-DriverOld/wav2svp/60396b79431f672b6f92d8bb43912d1c861595b9/modules/model/__init__.py -------------------------------------------------------------------------------- /modules/rmvpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import * 2 | from .model import E2E0 3 | from .utils import to_local_average_f0, to_viterbi_f0 4 | from .spec import MelSpectrogram 5 | -------------------------------------------------------------------------------- /modules/rmvpe/constants.py: -------------------------------------------------------------------------------- 1 | SAMPLE_RATE = 16000 2 | 3 | N_CLASS = 360 4 | 5 | N_MELS = 128 6 | MEL_FMIN = 30 7 | MEL_FMAX = 8000 8 | WINDOW_LENGTH = 1024 9 | CONST = 1997.3794084376191 10 | -------------------------------------------------------------------------------- /modules/rmvpe/deepunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .constants import N_MELS 4 | 5 | 6 | class ConvBlockRes(nn.Module): 7 | def __init__(self, in_channels, out_channels, momentum=0.01): 8 | super(ConvBlockRes, self).__init__() 9 | self.conv = nn.Sequential( 10 | nn.Conv2d(in_channels=in_channels, 11 | out_channels=out_channels, 12 | kernel_size=(3, 3), 13 | stride=(1, 1), 14 | padding=(1, 1), 15 | bias=False), 16 | nn.BatchNorm2d(out_channels, momentum=momentum), 17 | nn.ReLU(), 18 | 19 | nn.Conv2d(in_channels=out_channels, 20 | out_channels=out_channels, 21 | kernel_size=(3, 3), 22 | stride=(1, 1), 23 | padding=(1, 1), 24 | bias=False), 25 | 
nn.BatchNorm2d(out_channels, momentum=momentum), 26 | nn.ReLU(), 27 | ) 28 | if in_channels != out_channels: 29 | self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) 30 | self.is_shortcut = True 31 | else: 32 | self.is_shortcut = False 33 | 34 | def forward(self, x): 35 | if self.is_shortcut: 36 | return self.conv(x) + self.shortcut(x) 37 | else: 38 | return self.conv(x) + x 39 | 40 | 41 | class ResEncoderBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01): 43 | super(ResEncoderBlock, self).__init__() 44 | self.n_blocks = n_blocks 45 | self.conv = nn.ModuleList() 46 | self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) 47 | for i in range(n_blocks - 1): 48 | self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) 49 | self.kernel_size = kernel_size 50 | if self.kernel_size is not None: 51 | self.pool = nn.AvgPool2d(kernel_size=kernel_size) 52 | 53 | def forward(self, x): 54 | for i in range(self.n_blocks): 55 | x = self.conv[i](x) 56 | if self.kernel_size is not None: 57 | return x, self.pool(x) 58 | else: 59 | return x 60 | 61 | 62 | class ResDecoderBlock(nn.Module): 63 | def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01): 64 | super(ResDecoderBlock, self).__init__() 65 | out_padding = (0, 1) if stride == (1, 2) else (1, 1) 66 | self.n_blocks = n_blocks 67 | self.conv1 = nn.Sequential( 68 | nn.ConvTranspose2d(in_channels=in_channels, 69 | out_channels=out_channels, 70 | kernel_size=(3, 3), 71 | stride=stride, 72 | padding=(1, 1), 73 | output_padding=out_padding, 74 | bias=False), 75 | nn.BatchNorm2d(out_channels, momentum=momentum), 76 | nn.ReLU(), 77 | ) 78 | self.conv2 = nn.ModuleList() 79 | self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) 80 | for i in range(n_blocks-1): 81 | self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) 82 | 83 | def forward(self, x, concat_tensor): 84 | x = self.conv1(x) 
85 | x = torch.cat((x, concat_tensor), dim=1) 86 | for i in range(self.n_blocks): 87 | x = self.conv2[i](x) 88 | return x 89 | 90 | 91 | class Encoder(nn.Module): 92 | def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01): 93 | super(Encoder, self).__init__() 94 | self.n_encoders = n_encoders 95 | self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) 96 | self.layers = nn.ModuleList() 97 | self.latent_channels = [] 98 | for i in range(self.n_encoders): 99 | self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum)) 100 | self.latent_channels.append([out_channels, in_size]) 101 | in_channels = out_channels 102 | out_channels *= 2 103 | in_size //= 2 104 | self.out_size = in_size 105 | self.out_channel = out_channels 106 | 107 | def forward(self, x): 108 | concat_tensors = [] 109 | x = self.bn(x) 110 | for i in range(self.n_encoders): 111 | _, x = self.layers[i](x) 112 | concat_tensors.append(_) 113 | return x, concat_tensors 114 | 115 | 116 | class Intermediate(nn.Module): 117 | def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01): 118 | super(Intermediate, self).__init__() 119 | self.n_inters = n_inters 120 | self.layers = nn.ModuleList() 121 | self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)) 122 | for i in range(self.n_inters-1): 123 | self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)) 124 | 125 | def forward(self, x): 126 | for i in range(self.n_inters): 127 | x = self.layers[i](x) 128 | return x 129 | 130 | 131 | class Decoder(nn.Module): 132 | def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01): 133 | super(Decoder, self).__init__() 134 | self.layers = nn.ModuleList() 135 | self.n_decoders = n_decoders 136 | for i in range(self.n_decoders): 137 | out_channels = in_channels // 2 138 | self.layers.append(ResDecoderBlock(in_channels, 
out_channels, stride, n_blocks, momentum)) 139 | in_channels = out_channels 140 | 141 | def forward(self, x, concat_tensors): 142 | for i in range(self.n_decoders): 143 | x = self.layers[i](x, concat_tensors[-1-i]) 144 | return x 145 | 146 | 147 | class TimbreFilter(nn.Module): 148 | def __init__(self, latent_rep_channels): 149 | super(TimbreFilter, self).__init__() 150 | self.layers = nn.ModuleList() 151 | for latent_rep in latent_rep_channels: 152 | self.layers.append(ConvBlockRes(latent_rep[0], latent_rep[0])) 153 | 154 | def forward(self, x_tensors): 155 | out_tensors = [] 156 | for i, layer in enumerate(self.layers): 157 | out_tensors.append(layer(x_tensors[i])) 158 | return out_tensors 159 | 160 | 161 | class DeepUnet0(nn.Module): 162 | def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16): 163 | super(DeepUnet0, self).__init__() 164 | self.encoder = Encoder(in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels) 165 | self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks) 166 | self.tf = TimbreFilter(self.encoder.latent_channels) 167 | self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks) 168 | 169 | def forward(self, x): 170 | x, concat_tensors = self.encoder(x) 171 | x = self.intermediate(x) 172 | x = self.decoder(x, concat_tensors) 173 | return x 174 | -------------------------------------------------------------------------------- /modules/rmvpe/inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torchaudio.transforms import Resample 5 | 6 | from utils.pitch_utils import interp_f0, resample_align_curve 7 | from .constants import * 8 | from .model import E2E0 9 | from .spec import MelSpectrogram 10 | from .utils import to_local_average_f0, to_viterbi_f0 11 | 12 | 13 | 
class RMVPE: 14 | def __init__(self, model_path, hop_length=160, device=None): 15 | self.resample_kernel = {} 16 | if device is None: 17 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 18 | else: 19 | self.device = device 20 | self.model = E2E0(4, 1, (2, 2)).eval().to(self.device) 21 | ckpt = torch.load(model_path, map_location=self.device, weights_only=True) 22 | self.model.load_state_dict(ckpt['model'], strict=False) 23 | self.mel_extractor = MelSpectrogram( 24 | N_MELS, SAMPLE_RATE, WINDOW_LENGTH, hop_length, None, MEL_FMIN, MEL_FMAX 25 | ).to(self.device) 26 | 27 | @torch.no_grad() 28 | def mel2hidden(self, mel): 29 | n_frames = mel.shape[-1] 30 | mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode='constant') 31 | hidden = self.model(mel) 32 | return hidden[:, :n_frames] 33 | 34 | def decode(self, hidden, thred=0.03, use_viterbi=False): 35 | if use_viterbi: 36 | f0 = to_viterbi_f0(hidden, thred=thred) 37 | else: 38 | f0 = to_local_average_f0(hidden, thred=thred) 39 | return f0 40 | 41 | def infer_from_audio(self, audio, sample_rate=16000, thred=0.03, use_viterbi=False): 42 | audio = torch.from_numpy(audio).float().unsqueeze(0).to(self.device) 43 | if sample_rate == 16000: 44 | audio_res = audio 45 | else: 46 | key_str = str(sample_rate) 47 | if key_str not in self.resample_kernel: 48 | self.resample_kernel[key_str] = Resample(sample_rate, 16000, lowpass_filter_width=128) 49 | self.resample_kernel[key_str] = self.resample_kernel[key_str].to(self.device) 50 | audio_res = self.resample_kernel[key_str](audio) 51 | mel = self.mel_extractor(audio_res, center=True) 52 | hidden = self.mel2hidden(mel) 53 | f0 = self.decode(hidden, thred=thred, use_viterbi=use_viterbi) 54 | return f0 55 | 56 | def get_pitch(self, waveform, sample_rate, hop_size, length, interp_uv=False): 57 | f0 = self.infer_from_audio(waveform, sample_rate=sample_rate) 58 | uv = f0 == 0 59 | f0, uv = interp_f0(f0, uv) 60 | 61 | time_step = hop_size / sample_rate 62 | 
f0_res = resample_align_curve(f0, 0.01, time_step, length) 63 | uv_res = resample_align_curve(uv.astype(np.float32), 0.01, time_step, length) > 0.5 64 | if not interp_uv: 65 | f0_res[uv_res] = 0 66 | return f0_res, uv_res 67 | -------------------------------------------------------------------------------- /modules/rmvpe/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .constants import * 4 | from .deepunet import DeepUnet0 5 | from .seq import BiGRU 6 | 7 | 8 | class E2E0(nn.Module): 9 | def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, 10 | en_out_channels=16): 11 | super(E2E0, self).__init__() 12 | self.unet = DeepUnet0(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels) 13 | self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) 14 | if n_gru: 15 | self.fc = nn.Sequential( 16 | BiGRU(3 * N_MELS, 256, n_gru), 17 | nn.Linear(512, N_CLASS), 18 | nn.Dropout(0.25), 19 | nn.Sigmoid() 20 | ) 21 | else: 22 | self.fc = nn.Sequential( 23 | nn.Linear(3 * N_MELS, N_CLASS), 24 | nn.Dropout(0.25), 25 | nn.Sigmoid() 26 | ) 27 | 28 | def forward(self, mel): 29 | mel = mel.transpose(-1, -2).unsqueeze(1) 30 | x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) 31 | x = self.fc(x) 32 | return x 33 | -------------------------------------------------------------------------------- /modules/rmvpe/seq.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BiGRU(nn.Module): 5 | def __init__(self, input_features, hidden_features, num_layers): 6 | super(BiGRU, self).__init__() 7 | self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True) 8 | 9 | def forward(self, x): 10 | return self.gru(x)[0] 11 | -------------------------------------------------------------------------------- 
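This is a quick sanity check on the padding in `RMVPE.mel2hidden` above. `E2E0(4, 1, (2, 2))` uses the default `en_de_layers=5`, so the U-Net halves the time axis five times with `AvgPool2d((2, 2))`. It then doubles it back five times with stride-2 `ConvTranspose2d` layers before each skip concatenation. That appears to be why the frame count is padded up to a multiple of 2^5 = 32 and cropped back with `hidden[:, :n_frames]` afterwards.

The sketch below traces only the time-axis lengths, with no torch required. The helper names `pad_amount`, `skip_lengths` and `decoder_shapes_match` are hypothetical, and the layer settings are assumed from the defaults shown in `model.py` and `deepunet.py`.

```python
from typing import List


def pad_amount(n_frames: int, multiple: int = 32) -> int:
    """Zero frames appended so the time axis becomes a multiple of `multiple`.

    Same expression as in RMVPE.mel2hidden:
    32 * ((n_frames - 1) // 32 + 1) - n_frames
    """
    if n_frames <= 0:
        raise ValueError("n_frames must be positive")
    return multiple * ((n_frames - 1) // multiple + 1) - n_frames


def skip_lengths(n_frames: int, n_layers: int = 5) -> List[int]:
    """Time-axis lengths of the encoder skip tensors, followed by the bottleneck length."""
    lengths = [n_frames]
    for _ in range(n_layers):
        # AvgPool2d(kernel_size=(2, 2)) floors odd lengths by default (ceil_mode=False)
        lengths.append(lengths[-1] // 2)
    return lengths


def decoder_shapes_match(n_frames: int, n_layers: int = 5) -> bool:
    """True if every decoder upsampling step lines up with its skip tensor."""
    lengths = skip_lengths(n_frames, n_layers)
    x = lengths[-1]  # Intermediate blocks keep the bottleneck length unchanged
    for skip in reversed(lengths[:-1]):
        # ConvTranspose2d(kernel 3, stride 2, padding 1, output_padding 1) yields 2 * x
        x *= 2
        if x != skip:
            return False  # torch.cat along the channel dim would fail here
    return True


if __name__ == "__main__":
    n = 100
    print(pad_amount(n))                            # 28
    print(skip_lengths(n))                          # [100, 50, 25, 12, 6, 3]
    print(decoder_shapes_match(n))                  # False
    print(decoder_shapes_match(n + pad_amount(n)))  # True
```

With 100 unpadded frames, the third decoder step produces 24 frames against a 25-frame skip tensor. Padding to 128 frames avoids that mismatch for any input length.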
/modules/rmvpe/spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from librosa.filters import mel 5 | 6 | 7 | class MelSpectrogram(torch.nn.Module): 8 | def __init__( 9 | self, 10 | n_mel_channels, 11 | sampling_rate, 12 | win_length, 13 | hop_length, 14 | n_fft=None, 15 | mel_fmin=0, 16 | mel_fmax=None, 17 | clamp=1e-5 18 | ): 19 | super().__init__() 20 | n_fft = win_length if n_fft is None else n_fft 21 | self.hann_window = {} 22 | mel_basis = mel( 23 | sr=sampling_rate, 24 | n_fft=n_fft, 25 | n_mels=n_mel_channels, 26 | fmin=mel_fmin, 27 | fmax=mel_fmax, 28 | htk=True) 29 | mel_basis = torch.from_numpy(mel_basis).float() 30 | self.register_buffer("mel_basis", mel_basis) 31 | self.n_fft = win_length if n_fft is None else n_fft 32 | self.hop_length = hop_length 33 | self.win_length = win_length 34 | self.sampling_rate = sampling_rate 35 | self.n_mel_channels = n_mel_channels 36 | self.clamp = clamp 37 | 38 | def forward(self, audio, keyshift=0, speed=1, center=True): 39 | factor = 2 ** (keyshift / 12) 40 | n_fft_new = int(np.round(self.n_fft * factor)) 41 | win_length_new = int(np.round(self.win_length * factor)) 42 | hop_length_new = int(np.round(self.hop_length * speed)) 43 | 44 | keyshift_key = str(keyshift) + '_' + str(audio.device) 45 | if keyshift_key not in self.hann_window: 46 | self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device) 47 | if center: 48 | pad_left = win_length_new // 2 49 | pad_right = (win_length_new + 1) // 2 50 | audio = F.pad(audio, (pad_left, pad_right)) 51 | 52 | fft = torch.stft( 53 | audio, 54 | n_fft=n_fft_new, 55 | hop_length=hop_length_new, 56 | win_length=win_length_new, 57 | window=self.hann_window[keyshift_key], 58 | center=False, 59 | return_complex=True 60 | ) 61 | magnitude = fft.abs() 62 | 63 | if keyshift != 0: 64 | size = self.n_fft // 2 + 1 65 | resize = magnitude.size(1) 66 
| if resize < size: 67 | magnitude = F.pad(magnitude, (0, 0, 0, size - resize)) 68 | magnitude = magnitude[:, :size, :] * self.win_length / win_length_new 69 | 70 | mel_output = torch.matmul(self.mel_basis, magnitude) 71 | log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) 72 | return log_mel_spec 73 | -------------------------------------------------------------------------------- /modules/rmvpe/utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | import torch 4 | 5 | from .constants import * 6 | 7 | 8 | def to_local_average_f0(hidden, center=None, thred=0.03): 9 | idx = torch.arange(N_CLASS, device=hidden.device)[None, None, :] # [B=1, T=1, N] 10 | idx_cents = idx * 20 + CONST # [B=1, N] 11 | if center is None: 12 | center = torch.argmax(hidden, dim=2, keepdim=True) # [B, T, 1] 13 | start = torch.clip(center - 4, min=0) # [B, T, 1] 14 | end = torch.clip(center + 5, max=N_CLASS) # [B, T, 1] 15 | idx_mask = (idx >= start) & (idx < end) # [B, T, N] 16 | weights = hidden * idx_mask # [B, T, N] 17 | product_sum = torch.sum(weights * idx_cents, dim=2) # [B, T] 18 | weight_sum = torch.sum(weights, dim=2) # [B, T] 19 | cents = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 20 | f0 = 10 * 2 ** (cents / 1200) 21 | uv = hidden.max(dim=2)[0] < thred # [B, T] 22 | f0 = f0 * ~uv 23 | return f0.squeeze(0).cpu().numpy() 24 | 25 | 26 | def to_viterbi_f0(hidden, thred=0.03): 27 | # Create viterbi transition matrix 28 | if not hasattr(to_viterbi_f0, 'transition'): 29 | xx, yy = np.meshgrid(range(N_CLASS), range(N_CLASS)) 30 | transition = np.maximum(30 - abs(xx - yy), 0) 31 | transition = transition / transition.sum(axis=1, keepdims=True) 32 | to_viterbi_f0.transition = transition 33 | 34 | # Convert to probability 35 | prob = hidden.squeeze(0).cpu().numpy() 36 | prob = prob.T 37 | prob = prob / prob.sum(axis=0) 38 | 39 | # Perform viterbi decoding 40 
| path = librosa.sequence.viterbi(prob, to_viterbi_f0.transition).astype(np.int64) 41 | center = torch.from_numpy(path).unsqueeze(0).unsqueeze(-1).to(hidden.device) 42 | 43 | return to_local_average_f0(hidden, center=center, thred=thred) 44 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # It's recommended to download and install torch manually. 2 | # pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 3 | 4 | gradio 5 | librosa 6 | mido 7 | einops -------------------------------------------------------------------------------- /template.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": 153, 3 | "time": { 4 | "meter": [ 5 | { 6 | "index": 0, 7 | "numerator": 4, 8 | "denominator": 4 9 | } 10 | ], 11 | "tempo": [ 12 | { 13 | "position": 0, 14 | "bpm": 120 15 | } 16 | ] 17 | }, 18 | "library": [], 19 | "tracks": [ 20 | { 21 | "name": "未命名音轨", 22 | "dispColor": "ff7db235", 23 | "dispOrder": 0, 24 | "renderEnabled": false, 25 | "mixer": { 26 | "gainDecibel": 0, 27 | "pan": 0, 28 | "mute": false, 29 | "solo": false, 30 | "display": true 31 | }, 32 | "mainGroup": { 33 | "name": "main", 34 | "uuid": "a6cbc361-ef84-4518-8da4-4ac7640aae8c", 35 | "parameters": { 36 | "pitchDelta": { 37 | "mode": "cubic", 38 | "points": [] 39 | }, 40 | "vibratoEnv": { 41 | "mode": "cubic", 42 | "points": [] 43 | }, 44 | "loudness": { 45 | "mode": "cubic", 46 | "points": [] 47 | }, 48 | "tension": { 49 | "mode": "cubic", 50 | "points": [] 51 | }, 52 | "breathiness": { 53 | "mode": "cubic", 54 | "points": [] 55 | }, 56 | "voicing": { 57 | "mode": "cubic", 58 | "points": [] 59 | }, 60 | "gender": { 61 | "mode": "cubic", 62 | "points": [] 63 | }, 64 | "toneShift": { 65 | "mode": "cubic", 66 | "points": [] 67 | } 68 | }, 69 | 
"vocalModes": {}, 70 | "notes": [ 71 | {
72 | "musicalType": "singing", 73 | "onset": 0, 74 | "duration": 705600000, 75 | "lyrics": "la", 76 | "phonemes": "", 77 | "accent": "", 78 | "pitch": 60, 79 | "detune": 0, 80 | "instantMode": false, 81 | "attributes": { 82 | "evenSyllableDuration": true 83 | }, 84 | "systemAttributes": { 85 | "evenSyllableDuration": true 86 | }, 87 | "pitchTakes": { 88 | "activeTakeId": 0, 89 | "takes": [ 90 | { 91 | "id": 0, 92 | "expr": 0, 93 | "liked": false 94 | } 95 | ] 96 | }, 97 | "timbreTakes": { 98 | "activeTakeId": 0, 99 | "takes": [ 100 | { 101 | "id": 0, 102 | "expr": 0, 103 | "liked": false 104 | } 105 | ] 106 | } 107 | } 108 | ] 109 | }, 110 | "mainRef": { 111 | "groupID": "a6cbc361-ef84-4518-8da4-4ac7640aae8c", 112 | "blickAbsoluteBegin": 0, 113 | "blickAbsoluteEnd": -1, 114 | "blickOffset": 0, 115 | "pitchOffset": 0, 116 | "isInstrumental": false, 117 | "systemPitchDelta": { 118 | "mode": "cubic", 119 | "points": [] 120 | }, 121 | "database": { 122 | "name": "", 123 | "language": "", 124 | "phoneset": "", 125 | "languageOverride": "", 126 | "phonesetOverride": "", 127 | "backendType": "", 128 | "version": "-2" 129 | }, 130 | "dictionary": "", 131 | "voice": { 132 | "vocalModeInherited": true, 133 | "vocalModePreset": "", 134 | "vocalModeParams": {} 135 | }, 136 | "pitchTakes": { 137 | "activeTakeId": 0, 138 | "takes": [ 139 | { 140 | "id": 0, 141 | "expr": 0, 142 | "liked": false 143 | } 144 | ] 145 | }, 146 | "timbreTakes": { 147 | "activeTakeId": 0, 148 | "takes": [ 149 | { 150 | "id": 0, 151 | "expr": 0, 152 | "liked": false 153 | } 154 | ] 155 | } 156 | }, 157 | "groups": [] 158 | } 159 | ], 160 | "renderConfig": { 161 | "destination": "", 162 | "filename": "未命名", 163 | "numChannels": 1, 164 | "aspirationFormat": "noAspiration", 165 | "bitDepth": 16, 166 | "sampleRate": 44100, 167 | "exportMixDown": true, 168 | "exportPitch": false 169 | } 170 | } -------------------------------------------------------------------------------- /utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import pathlib 4 | import re 5 | import types 6 | from collections import OrderedDict 7 | 8 | import numpy as np 9 | import torch 10 | 11 | def tensors_to_scalars(metrics): 12 | new_metrics = {} 13 | for k, v in metrics.items(): 14 | if isinstance(v, torch.Tensor): 15 | v = v.item() 16 | if type(v) is dict: 17 | v = tensors_to_scalars(v) 18 | new_metrics[k] = v 19 | return new_metrics 20 | 21 | 22 | def collate_nd(values, pad_value=0, max_len=None): 23 | """ 24 | Pad a list of Nd tensors on their first dimension and stack them into a (N+1)d tensor. 25 | """ 26 | size = ((max(v.size(0) for v in values) if max_len is None else max_len), *values[0].shape[1:]) 27 | res = torch.full((len(values), *size), fill_value=pad_value, dtype=values[0].dtype, device=values[0].device) 28 | 29 | for i, v in enumerate(values): 30 | res[i, :len(v), ...] = v 31 | return res 32 | 33 | 34 | def random_continuous_masks(*shape: int, dim: int, device: str | torch.device = 'cpu'): 35 | start, end = torch.sort( 36 | torch.randint( 37 | low=0, high=shape[dim] + 1, size=(*shape[:dim], 2, *((1,) * (len(shape) - dim - 1))), device=device 38 | ).expand(*((-1,) * (dim + 1)), *shape[dim + 1:]), dim=dim 39 | )[0].split(1, dim=dim) 40 | idx = torch.arange( 41 | 0, shape[dim], dtype=torch.long, device=device 42 | ).reshape(*((1,) * dim), shape[dim], *((1,) * (len(shape) - dim - 1))) 43 | masks = (idx >= start) & (idx < end) 44 | return masks 45 | 46 | 47 | def _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size): 48 | if len(batch) == 0: 49 | return 0 50 | if len(batch) == max_batch_size: 51 | return 1 52 | if num_frames > max_batch_frames: 53 | return 1 54 | return 0 55 | 56 | 57 | def batch_by_size( 58 | indices, num_frames_fn, max_batch_frames=80000, max_batch_size=48, 59 | required_batch_size_multiple=1 60 | ): 61 | """ 62 | Yield mini-batches of indices 
bucketed by size. Batches may contain 63 | sequences of different lengths. 64 | 65 | Args: 66 | indices (List[int]): ordered list of dataset indices 67 | num_frames_fn (callable): function that returns the number of frames at 68 | a given index 69 | max_batch_frames (int, optional): max number of frames in each batch 70 | (default: 80000). 71 | max_batch_size (int, optional): max number of samples in each 72 | batch (default: 48). 73 | required_batch_size_multiple: require the batch size to be a multiple 74 | of a given number 75 | """ 76 | bsz_mult = required_batch_size_multiple 77 | 78 | if isinstance(indices, types.GeneratorType): 79 | indices = np.fromiter(indices, dtype=np.int64, count=-1) 80 | 81 | sample_len = 0 82 | sample_lens = [] 83 | batch = [] 84 | batches = [] 85 | for i in range(len(indices)): 86 | idx = indices[i] 87 | num_frames = num_frames_fn(idx) 88 | sample_lens.append(num_frames) 89 | sample_len = max(sample_len, num_frames) 90 | assert sample_len <= max_batch_frames, ( 91 | "sample at index {} of size {} exceeds max_batch_frames " 92 | "limit of {}!".format(idx, sample_len, max_batch_frames) 93 | ) 94 | num_frames = (len(batch) + 1) * sample_len 95 | 96 | if _is_batch_full(batch, num_frames, max_batch_frames, max_batch_size): 97 | mod_len = max( 98 | bsz_mult * (len(batch) // bsz_mult), 99 | len(batch) % bsz_mult, 100 | ) 101 | batches.append(batch[:mod_len]) 102 | batch = batch[mod_len:] 103 | sample_lens = sample_lens[mod_len:] 104 | sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 105 | batch.append(idx) 106 | if len(batch) > 0: 107 | batches.append(batch) 108 | return batches 109 | 110 | 111 | def unpack_dict_to_list(samples): 112 | samples_ = [] 113 | bsz = samples.get('outputs').size(0) 114 | for i in range(bsz): 115 | res = {} 116 | for k, v in samples.items(): 117 | try: 118 | res[k] = v[i] 119 | except Exception: # skip entries that cannot be indexed per sample 120 | pass 121 | 
samples_.append(res) 122 | return samples_ 123 | 124 | 125 | def filter_kwargs(dict_to_filter,
kwarg_obj): 126 | import inspect 127 | 128 | sig = inspect.signature(kwarg_obj) 129 | filter_keys = [param.name for param in sig.parameters.values() if param.kind == param.POSITIONAL_OR_KEYWORD] 130 | filtered_dict = {filter_key: dict_to_filter[filter_key] for filter_key in filter_keys if 131 | filter_key in dict_to_filter} 132 | return filtered_dict 133 | 134 | 135 | def load_ckpt( 136 | cur_model, ckpt_base_dir, ckpt_steps=None, 137 | prefix_in_ckpt='model', key_in_ckpt='state_dict', 138 | strict=True, device='cpu' 139 | ): 140 | if not isinstance(ckpt_base_dir, pathlib.Path): 141 | ckpt_base_dir = pathlib.Path(ckpt_base_dir) 142 | if ckpt_base_dir.is_file(): 143 | checkpoint_path = [ckpt_base_dir] 144 | elif ckpt_steps is not None: 145 | checkpoint_path = [ckpt_base_dir / f'model_ckpt_steps_{int(ckpt_steps)}.ckpt'] 146 | else: 147 | base_dir = ckpt_base_dir 148 | checkpoint_path = sorted( 149 | [ 150 | ckpt_file 151 | for ckpt_file in base_dir.iterdir() 152 | if ckpt_file.is_file() and re.fullmatch(r'model_ckpt_steps_\d+\.ckpt', ckpt_file.name) 153 | ], 154 | key=lambda x: int(re.search(r'\d+', x.name).group(0)) 155 | ) 156 | assert len(checkpoint_path) > 0, f'ckpt not found in {ckpt_base_dir}.' 
157 | checkpoint_path = checkpoint_path[-1] 158 | ckpt_loaded = torch.load(checkpoint_path, map_location=device) 159 | if key_in_ckpt is None: 160 | state_dict = ckpt_loaded 161 | else: 162 | state_dict = ckpt_loaded[key_in_ckpt] 163 | if prefix_in_ckpt is not None: 164 | state_dict = OrderedDict({ 165 | k[len(prefix_in_ckpt) + 1:]: v 166 | for k, v in state_dict.items() if k.startswith(f'{prefix_in_ckpt}.') 167 | }) 168 | if not strict: 169 | cur_model_state_dict = cur_model.state_dict() 170 | unmatched_keys = [] 171 | for key, param in state_dict.items(): 172 | if key in cur_model_state_dict: 173 | new_param = cur_model_state_dict[key] 174 | if new_param.shape != param.shape: 175 | unmatched_keys.append(key) 176 | print('Unmatched keys: ', key, new_param.shape, param.shape) 177 | for key in unmatched_keys: 178 | del state_dict[key] 179 | cur_model.load_state_dict(state_dict, strict=strict) 180 | shown_model_name = 'state dict' 181 | if prefix_in_ckpt is not None: 182 | shown_model_name = f'\'{prefix_in_ckpt}\'' 183 | elif key_in_ckpt is not None: 184 | shown_model_name = f'\'{key_in_ckpt}\'' 185 | print(f'load {shown_model_name} from \'{checkpoint_path}\'.') 186 | 187 | 188 | def remove_padding(x, padding_idx=0): 189 | if x is None: 190 | return None 191 | assert len(x.shape) in [1, 2] 192 | if len(x.shape) == 2: # [T, H] 193 | return x[np.abs(x).sum(-1) != padding_idx] 194 | elif len(x.shape) == 1: # [T] 195 | return x[x != padding_idx] 196 | 197 | 198 | def print_arch(model, model_name='model'): 199 | print(f"{model_name} Arch: ", model) 200 | # num_params(model, model_name=model_name) 201 | 202 | 203 | def num_params(model, print_out=True, model_name="model"): 204 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 205 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 206 | if print_out: 207 | print(f'{model_name} Trainable Parameters: %.3fM' % parameters) 208 | return parameters 209 | 210 | 211 | def 
build_object_from_class_name(cls_str, parent_cls, *args, **kwargs): 212 | import importlib 213 | 214 | pkg = ".".join(cls_str.split(".")[:-1]) 215 | cls_name = cls_str.split(".")[-1] 216 | cls_type = getattr(importlib.import_module(pkg), cls_name) 217 | if parent_cls is not None: 218 | assert issubclass(cls_type, parent_cls), f'{cls_type} is not subclass of {parent_cls}.' 219 | 220 | return cls_type(*args, **filter_kwargs(kwargs, cls_type)) 221 | 222 | 223 | def build_lr_scheduler_from_config(optimizer, scheduler_args): 224 | try: 225 | # PyTorch 2.0+ 226 | from torch.optim.lr_scheduler import LRScheduler as LRScheduler 227 | except ImportError: 228 | # PyTorch 1.X 229 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler 230 | 231 | def helper(params): 232 | if isinstance(params, list): 233 | return [helper(s) for s in params] 234 | elif isinstance(params, dict): 235 | resolved = {k: helper(v) for k, v in params.items()} 236 | if 'cls' in resolved: 237 | if ( 238 | resolved["cls"] == "torch.optim.lr_scheduler.ChainedScheduler" 239 | and scheduler_args["scheduler_cls"] == "torch.optim.lr_scheduler.SequentialLR" 240 | ): 241 | raise ValueError(f"ChainedScheduler cannot be part of a SequentialLR.") 242 | resolved['optimizer'] = optimizer 243 | obj = build_object_from_class_name( 244 | resolved['cls'], 245 | LRScheduler, 246 | **resolved 247 | ) 248 | return obj 249 | return resolved 250 | else: 251 | return params 252 | 253 | resolved = helper(scheduler_args) 254 | resolved['optimizer'] = optimizer 255 | return build_object_from_class_name( 256 | scheduler_args['scheduler_cls'], 257 | LRScheduler, 258 | **resolved 259 | ) 260 | 261 | 262 | def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1): 263 | optimizer = build_object_from_class_name( 264 | optimizer_args['optimizer_cls'], 265 | torch.optim.Optimizer, 266 | [{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)], 
267 | **optimizer_args 268 | ) 269 | scheduler = build_lr_scheduler_from_config(optimizer, scheduler_args) 270 | scheduler.optimizer._step_count = 1 271 | for _ in range(step_count): 272 | scheduler.step() 273 | return scheduler.state_dict() 274 | 275 | 276 | def remove_suffix(string: str, suffix: str): 277 | # For Python 3.8 compatibility: the `str.removesuffix()` API is only available since Python 3.9 278 | if string.endswith(suffix): 279 | string = string[:-len(suffix)] 280 | return string 281 | -------------------------------------------------------------------------------- /utils/infer_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import mido 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold): 10 | num_bins = probs.shape[-1] 11 | interval = (vmax - vmin) / (num_bins - 1) 12 | width = int(3 * deviation / interval) # 3 * sigma 13 | idx = torch.arange(num_bins, device=probs.device)[None, None, :] # [1, 1, N] 14 | idx_values = idx * interval + vmin 15 | center = torch.argmax(probs, dim=-1, keepdim=True) # [B, T, 1] 16 | start = torch.clip(center - width, min=0) # [B, T, 1] 17 | end = torch.clip(center + width + 1, max=num_bins) # [B, T, 1] 18 | idx_masks = (idx >= start) & (idx < end) # [B, T, N] 19 | weights = probs * idx_masks # [B, T, N] 20 | product_sum = torch.sum(weights * idx_values, dim=2) # [B, T] 21 | weight_sum = torch.sum(weights, dim=2) # [B, T] 22 | values = product_sum / (weight_sum + (weight_sum == 0)) # avoid dividing by zero, [B, T] 23 | rest = probs.max(dim=-1)[0] < threshold # [B, T] 24 | return values, rest 25 | 26 | 27 | def decode_bounds_to_alignment(bounds): 28 | bounds_step = bounds.cumsum(dim=1).round().long() 29 | bounds_inc = torch.diff( 30 | bounds_step, dim=1, prepend=torch.full( 31 | (bounds.shape[0], 1), 
fill_value=-1, 32 |
dtype=bounds_step.dtype, device=bounds_step.device 33 | ) 34 | ) > 0 35 | frame2item = bounds_inc.long().cumsum(dim=1) 36 | return frame2item 37 | 38 | 39 | def decode_note_sequence(frame2item, values, masks, threshold=0.5): 40 | """ 41 | Aggregate frame-level values into item-level (note-level) values. 42 | :param frame2item: frame-to-item index map, [B, T], e.g. [1, 1, 1, 1, 2, 2, 3, 3, 3] 43 | :param values: frame-level MIDI pitch values, [B, T] 44 | :param masks: boolean frame mask, [B, T]; only True frames contribute to item values 45 | :param threshold: minimum ratio of unmasked frames required to be regarded as an unmasked item 46 | :return: item_values, item_dur, item_masks, each [B, N] with N = frame2item.max() 47 | """ 48 | b = frame2item.shape[0] 49 | space = frame2item.max() + 1 50 | 51 | item_dur = frame2item.new_zeros(b, space).scatter_add( 52 | 1, frame2item, torch.ones_like(frame2item) 53 | )[:, 1:] 54 | item_unmasked_dur = frame2item.new_zeros(b, space).scatter_add( 55 | 1, frame2item, masks.long() 56 | )[:, 1:] 57 | item_masks = item_unmasked_dur / item_dur >= threshold 58 | 59 | values_quant = values.round().long() 60 | histogram = frame2item.new_zeros(b, space * 128).scatter_add( 61 | 1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks 62 | ).unflatten(1, [space, 128])[:, 1:, :] 63 | item_values_center = histogram.argmax(dim=2).to(dtype=values.dtype) 64 | values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item) 65 | values_near_center = masks & (values >= values_center - 0.5) & (values <= values_center + 0.5) 66 | item_valid_dur = frame2item.new_zeros(b, space).scatter_add( 67 | 1, frame2item, values_near_center.long() 68 | )[:, 1:] 69 | item_values = values.new_zeros(b, space).scatter_add( 70 | 1, frame2item, values * values_near_center 71 | )[:, 1:] / (item_valid_dur + (item_valid_dur == 0)) 72 | 73 | return item_values, item_dur, item_masks 74 | 75 | 76 | def build_midi_file(offsets: List[float], segments: List[Dict[str, np.ndarray]], tempo=120) -> mido.MidiFile: 77 | midi_file = mido.MidiFile(charset='utf8') 78 | midi_track = mido.MidiTrack() 79 | midi_track.append(mido.MetaMessage('set_tempo', tempo=mido.bpm2tempo(tempo), time=0)) 80 | last_time
= 0 81 | offsets = [round(o * tempo * 8) for o in offsets] 82 | for i, (offset, segment) in enumerate(zip(offsets, segments)): 83 | note_midi = np.round(segment['note_midi']).astype(np.int64).tolist() 84 | note_tick = np.diff(np.round(np.cumsum(segment['note_dur']) * tempo * 8).astype(np.int64), prepend=0).tolist() 85 | note_rest = segment['note_rest'].tolist() 86 | start = offset 87 | for j in range(len(note_midi)): 88 | end = start + note_tick[j] 89 | if i < len(offsets) - 1 and end > offsets[i + 1]: 90 | end = offsets[i + 1] 91 | if start < end and not note_rest[j]: 92 | midi_track.append(mido.Message('note_on', note=note_midi[j], time=start - last_time)) 93 | midi_track.append(mido.Message('note_off', note=note_midi[j], time=end - start)) 94 | last_time = end 95 | start = end 96 | midi_file.tracks.append(midi_track) 97 | return midi_file 98 | -------------------------------------------------------------------------------- /utils/pitch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | f0_bin = 256 5 | f0_max = 1100.0 6 | f0_min = 50.0 7 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 8 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 9 | 10 | 11 | def f0_to_coarse(f0): 12 | is_torch = isinstance(f0, torch.Tensor) 13 | f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) 14 | f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 15 | 16 | f0_mel[f0_mel <= 1] = 1 17 | f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 18 | f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int64) 19 | assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) 20 | return f0_coarse 21 | 22 | 23 | def norm_f0(f0, uv=None): 24 | if uv is None: 25 | uv = f0 == 0 26 | f0 = np.log2(f0 + uv) # avoid arithmetic error 27 | f0[uv] = -np.inf 28 | return f0 29 | 30 | 31 | def
interp_f0(f0, uv=None): 32 | if uv is None: 33 | uv = f0 == 0 34 | f0 = norm_f0(f0, uv) 35 | if uv.any() and not uv.all(): 36 | f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) 37 | return denorm_f0(f0, uv=None), uv 38 | 39 | 40 | def denorm_f0(f0, uv, pitch_padding=None): 41 | f0 = 2 ** f0 42 | if uv is not None: 43 | f0[uv > 0] = 0 44 | if pitch_padding is not None: 45 | f0[pitch_padding] = 0 46 | return f0 47 | 48 | 49 | def resample_align_curve(points: np.ndarray, original_timestep: float, target_timestep: float, align_length: int): 50 | t_max = (len(points) - 1) * original_timestep 51 | curve_interp = np.interp( 52 | np.arange(0, t_max, target_timestep), 53 | original_timestep * np.arange(len(points)), 54 | points 55 | ).astype(points.dtype) 56 | delta_l = align_length - len(curve_interp) 57 | if delta_l < 0: 58 | curve_interp = curve_interp[:align_length] 59 | elif delta_l > 0: 60 | curve_interp = np.concatenate((curve_interp, np.full(delta_l, fill_value=curve_interp[-1])), axis=0) 61 | return curve_interp 62 | -------------------------------------------------------------------------------- /utils/slicer2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # This function is obtained from librosa. 
5 | def get_rms( 6 | y, 7 | *, 8 | frame_length=2048, 9 | hop_length=512, 10 | pad_mode="constant", 11 | ): 12 | padding = (int(frame_length // 2), int(frame_length // 2)) 13 | y = np.pad(y, padding, mode=pad_mode) 14 | 15 | axis = -1 16 | # put our new within-frame axis at the end for now 17 | out_strides = y.strides + tuple([y.strides[axis]]) 18 | # Reduce the shape on the framing axis 19 | x_shape_trimmed = list(y.shape) 20 | x_shape_trimmed[axis] -= frame_length - 1 21 | out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) 22 | xw = np.lib.stride_tricks.as_strided( 23 | y, shape=out_shape, strides=out_strides 24 | ) 25 | if axis < 0: 26 | target_axis = axis - 1 27 | else: 28 | target_axis = axis + 1 29 | xw = np.moveaxis(xw, -1, target_axis) 30 | # Downsample along the target axis 31 | slices = [slice(None)] * xw.ndim 32 | slices[axis] = slice(0, None, hop_length) 33 | x = xw[tuple(slices)] 34 | 35 | # Calculate power 36 | power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) 37 | 38 | return np.sqrt(power) 39 | 40 | 41 | class Slicer: 42 | def __init__(self, 43 | sr: int, 44 | threshold: float = -40., 45 | min_length: int = 5000, 46 | min_interval: int = 300, 47 | hop_size: int = 20, 48 | max_sil_kept: int = 5000): 49 | if not min_length >= min_interval >= hop_size: 50 | raise ValueError('The following condition must be satisfied: min_length >= min_interval >= hop_size') 51 | if not max_sil_kept >= hop_size: 52 | raise ValueError('The following condition must be satisfied: max_sil_kept >= hop_size') 53 | min_interval = sr * min_interval / 1000 54 | self.sr = sr 55 | self.threshold = 10 ** (threshold / 20.) 
56 | self.hop_size = round(sr * hop_size / 1000) 57 | self.win_size = min(round(min_interval), 4 * self.hop_size) 58 | self.min_length = round(sr * min_length / 1000 / self.hop_size) 59 | self.min_interval = round(min_interval / self.hop_size) 60 | self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) 61 | 62 | def _apply_slice(self, waveform, begin, end): 63 | chunk = { 64 | 'offset': begin * self.hop_size / self.sr 65 | } 66 | if len(waveform.shape) > 1: 67 | chunk['waveform'] = waveform[:, begin * self.hop_size: min(waveform.shape[1], end * self.hop_size)] 68 | else: 69 | chunk['waveform'] = waveform[begin * self.hop_size: min(waveform.shape[0], end * self.hop_size)] 70 | return chunk 71 | 72 | # @timeit 73 | def slice(self, waveform): 74 | if len(waveform.shape) > 1: 75 | samples = waveform.mean(axis=0) 76 | else: 77 | samples = waveform 78 | if (samples.shape[0] + self.hop_size - 1) // self.hop_size <= self.min_length: 79 | return [{'offset': 0, 'waveform': waveform}] 80 | rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0) 81 | sil_tags = [] 82 | silence_start = None 83 | clip_start = 0 84 | for i, rms in enumerate(rms_list): 85 | # Keep looping while frame is silent. 86 | if rms < self.threshold: 87 | # Record start of silent frames. 88 | if silence_start is None: 89 | silence_start = i 90 | continue 91 | # Keep looping while frame is not silent and silence start has not been recorded. 92 | if silence_start is None: 93 | continue 94 | # Clear recorded silence start if interval is not enough or clip is too short 95 | is_leading_silence = silence_start == 0 and i > self.max_sil_kept 96 | need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length 97 | if not is_leading_silence and not need_slice_middle: 98 | silence_start = None 99 | continue 100 | # Need slicing. Record the range of silent frames to be removed. 
101 | if i - silence_start <= self.max_sil_kept: 102 | pos = rms_list[silence_start: i + 1].argmin() + silence_start 103 | if silence_start == 0: 104 | sil_tags.append((0, pos)) 105 | else: 106 | sil_tags.append((pos, pos)) 107 | clip_start = pos 108 | elif i - silence_start <= self.max_sil_kept * 2: 109 | pos = rms_list[i - self.max_sil_kept: silence_start + self.max_sil_kept + 1].argmin() 110 | pos += i - self.max_sil_kept 111 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 112 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 113 | if silence_start == 0: 114 | sil_tags.append((0, pos_r)) 115 | clip_start = pos_r 116 | else: 117 | sil_tags.append((min(pos_l, pos), max(pos_r, pos))) 118 | clip_start = max(pos_r, pos) 119 | else: 120 | pos_l = rms_list[silence_start: silence_start + self.max_sil_kept + 1].argmin() + silence_start 121 | pos_r = rms_list[i - self.max_sil_kept: i + 1].argmin() + i - self.max_sil_kept 122 | if silence_start == 0: 123 | sil_tags.append((0, pos_r)) 124 | else: 125 | sil_tags.append((pos_l, pos_r)) 126 | clip_start = pos_r 127 | silence_start = None 128 | # Deal with trailing silence. 129 | total_frames = rms_list.shape[0] 130 | if silence_start is not None and total_frames - silence_start >= self.min_interval: 131 | silence_end = min(total_frames, silence_start + self.max_sil_kept) 132 | pos = rms_list[silence_start: silence_end + 1].argmin() + silence_start 133 | sil_tags.append((pos, total_frames + 1)) 134 | # Apply and return slices. 
135 | if len(sil_tags) == 0: 136 | return [{'offset': 0, 'waveform': waveform}] 137 | else: 138 | chunks = [] 139 | if sil_tags[0][0] > 0: 140 | chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) 141 | for i in range(len(sil_tags) - 1): 142 | chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0])) 143 | if sil_tags[-1][1] < total_frames: 144 | chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames)) 145 | return chunks 146 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from infer import wav2svp 3 | 4 | def inference(input, bpm, extract_pitch, extract_tension, extract_breathiness): 5 | model_path = "weights/model_steps_64000_simplified.ckpt" 6 | return wav2svp(input, model_path, bpm, extract_pitch, extract_tension, extract_breathiness) 7 | 8 | def webui(): 9 | with gr.Blocks() as webui: 10 | gr.Markdown('''
wav2svp - Waveform to Synthesizer V Project
''') 11 | gr.Markdown("Upload an audio file and download the svp file with MIDI notes and the selected parameter data.") 12 | with gr.Row(): 13 | with gr.Column(): 14 | input = gr.File(label="Input Audio File", type="filepath") 15 | bpm = gr.Number(label='BPM Value', minimum=20, maximum=200, value=120, step=0.01, interactive=True) 16 | extract_pitch = gr.Checkbox(label="Extract Pitch Data", value=True) 17 | extract_tension = gr.Checkbox(label="Extract Tension Data (Experimental)", value=False) 18 | extract_breathiness = gr.Checkbox(label="Extract Breathiness Data (Experimental)", value=False) 19 | run = gr.Button(value="Generate svp File", variant="primary") 20 | with gr.Column(): 21 | output_svp = gr.File(label="Output svp File", type="filepath", interactive=False) 22 | output_midi = gr.File(label="Output midi File", type="filepath", interactive=False) 23 | run.click(inference, [input, bpm, extract_pitch, extract_tension, extract_breathiness], [output_svp, output_midi]) 24 | webui.launch(inbrowser=True) 25 | 26 | if __name__ == '__main__': 27 | webui() -------------------------------------------------------------------------------- /weights/config.yaml: -------------------------------------------------------------------------------- 1 | accumulate_grad_batches: 1 2 | audio_sample_rate: 44100 3 | binarization_args: 4 | num_workers: 0 5 | shuffle: true 6 | binarizer_cls: preprocessing.MIDIExtractionBinarizer 7 | binary_data_dir: data/some_ds_fixmel_spk3_aug8/binary 8 | clip_grad_norm: 1 9 | dataloader_prefetch_factor: 2 10 | ddp_backend: nccl 11 | ds_workers: 4 12 | finetune_ckpt_path: null 13 | finetune_enabled: false 14 | finetune_ignored_params: [] 15 | finetune_strict_shapes: true 16 | fmax: 8000 17 | fmin: 40 18 | freezing_enabled: false 19 | frozen_params: [] 20 | hop_size: 512 21 | log_interval: 100 22 | lr_scheduler_args: 23 | min_lr: 1.0e-05 24 | scheduler_cls: lr_scheduler.scheduler.WarmupLR 25 | warmup_steps: 5000 26 | max_batch_frames: 80000 27 | max_batch_size: 8 28 |
max_updates: 10000000 29 | max_val_batch_frames: 10000 30 | max_val_batch_size: 1 31 | midi_extractor_args: 32 | attention_drop: 0.1 33 | attention_heads: 8 34 | attention_heads_dim: 64 35 | conv_drop: 0.1 36 | dim: 512 37 | ffn_latent_drop: 0.1 38 | ffn_out_drop: 0.1 39 | kernel_size: 31 40 | lay: 8 41 | use_lay_skip: true 42 | midi_max: 128 43 | midi_min: 0 44 | midi_num_bins: 256 45 | midi_prob_deviation: 0.5 46 | midi_shift_proportion: 0.0 47 | midi_shift_range: 48 | - -6 49 | - 6 50 | model_cls: modules.model.Gmidi_conform.midi_conforms 51 | num_ckpt_keep: 5 52 | num_sanity_val_steps: 1 53 | num_valid_plots: 300 54 | optimizer_args: 55 | beta1: 0.9 56 | beta2: 0.98 57 | lr: 0.0001 58 | optimizer_cls: torch.optim.AdamW 59 | weight_decay: 0 60 | pe: rmvpe 61 | pe_ckpt: pretrained/rmvpe/model.pt 62 | permanent_ckpt_interval: 40000 63 | permanent_ckpt_start: 200000 64 | pl_trainer_accelerator: auto 65 | pl_trainer_devices: auto 66 | pl_trainer_num_nodes: 1 67 | pl_trainer_precision: 32-true 68 | pl_trainer_strategy: auto 69 | raw_data_dir: [] 70 | rest_threshold: 0.1 71 | sampler_frame_count_grid: 6 72 | seed: 114514 73 | sort_by_len: true 74 | task_cls: training.MIDIExtractionTask 75 | test_prefixes: null 76 | train_set_name: train 77 | units_dim: 80 78 | units_encoder: mel 79 | units_encoder_ckpt: pretrained/contentvec/checkpoint_best_legacy_500.pt 80 | use_buond_loss: true 81 | use_midi_loss: true 82 | val_check_interval: 4000 83 | valid_set_name: valid 84 | win_size: 2048 85 | --------------------------------------------------------------------------------
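`Slicer.__init__` in `utils/slicer2.py` converts its millisecond arguments into hop-frame counts and its dB threshold into a linear RMS amplitude. Below is a minimal plain-Python sketch of that arithmetic only. `slicer_frame_params` is a hypothetical helper and is not part of the repository. It assumes the same defaults as the constructor.

```python
def slicer_frame_params(sr, threshold=-40.0, min_length=5000,
                        min_interval=300, hop_size=20, max_sil_kept=5000):
    """Repeat the unit conversions of Slicer.__init__ (hypothetical helper)."""
    # Same argument checks as the constructor.
    if not min_length >= min_interval >= hop_size:
        raise ValueError('min_length >= min_interval >= hop_size must hold')
    if not max_sil_kept >= hop_size:
        raise ValueError('max_sil_kept >= hop_size must hold')
    min_interval_samples = sr * min_interval / 1000
    hop = round(sr * hop_size / 1000)  # hop length in samples
    return {
        'threshold': 10 ** (threshold / 20.0),  # dB -> linear RMS amplitude
        'hop_size': hop,
        # RMS window in samples, capped at four hops
        'win_size': min(round(min_interval_samples), 4 * hop),
        # The remaining values are counted in hop frames
        'min_length': round(sr * min_length / 1000 / hop),
        'min_interval': round(min_interval_samples / hop),
        'max_sil_kept': round(sr * max_sil_kept / 1000 / hop),
    }


if __name__ == '__main__':
    print(slicer_frame_params(44100))
```

With the defaults at 44.1 kHz, each hop is 882 samples. The 300 ms `min_interval` therefore becomes 15 frames, the 5000 ms limits become 250 frames, and the RMS window is capped at four hops (3528 samples).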
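`f0_to_coarse` in `utils/pitch_utils.py` maps F0 in Hz onto the mel scale (`1127 * ln(1 + f / 700)`) and linearly rescales the 50 to 1100 Hz range onto bins 1 to 255. It then clamps every other value into that range, so unvoiced frames (F0 = 0) land in bin 1. Below is a scalar sketch that uses only `math`. It follows the NumPy branch, whose `np.rint` rounds half to even like Python's `round`. `f0_to_coarse_scalar` is a hypothetical name used for illustration.

```python
import math

F0_BIN, F0_MIN, F0_MAX = 256, 50.0, 1100.0
MEL_MIN = 1127 * math.log(1 + F0_MIN / 700)
MEL_MAX = 1127 * math.log(1 + F0_MAX / 700)


def f0_to_coarse_scalar(f0_hz: float) -> int:
    """Scalar restatement of f0_to_coarse (hypothetical helper)."""
    mel = 1127 * math.log(1 + f0_hz / 700)
    if mel > 0:
        # Map [MEL_MIN, MEL_MAX] linearly onto [1, F0_BIN - 1]
        mel = (mel - MEL_MIN) * (F0_BIN - 2) / (MEL_MAX - MEL_MIN) + 1
    # Clamp to the valid bin range; 0 Hz and anything below 50 Hz fall to bin 1
    mel = min(max(mel, 1), F0_BIN - 1)
    return round(mel)  # round half to even, like np.rint


if __name__ == '__main__':
    for hz in (0.0, 50.0, 220.0, 440.0, 1100.0):
        print(hz, f0_to_coarse_scalar(hz))
```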
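`build_midi_file` in `utils/infer_utils.py` converts seconds to ticks with `round(t * tempo * 8)`. This relies on mido's default resolution of 480 ticks per beat, because `t * (bpm / 60) * 480 = t * bpm * 8`. Note durations are converted by rounding the cumulative end times and then taking differences, which keeps rounding errors from accumulating across a segment. Below is a plain-Python sketch of both steps. The function names are hypothetical, and `ticks_per_beat=480` is an assumption chosen to match mido's default.

```python
from itertools import accumulate


def seconds_to_ticks(seconds: float, bpm: float, ticks_per_beat: int = 480) -> int:
    # beats = seconds * bpm / 60; at 480 ticks per beat this equals seconds * bpm * 8
    return round(seconds * bpm / 60 * ticks_per_beat)


def durations_to_ticks(durations, bpm, ticks_per_beat=480):
    # Round the cumulative end time of each note, then take differences,
    # so per-note rounding errors do not add up over a long segment.
    ends = [seconds_to_ticks(t, bpm, ticks_per_beat) for t in accumulate(durations)]
    return [end - start for start, end in zip([0] + ends[:-1], ends)]


if __name__ == '__main__':
    ticks = durations_to_ticks([0.01] * 100, 120)
    print(sum(ticks))
```

Consider one hundred 10 ms notes at 120 BPM. Rounding each note on its own gives 10 ticks per note, or 1000 ticks in total, which is 40 ticks (about 42 ms) too long. The cumulative approach alternates between 9 and 10 ticks and ends at exactly 960 ticks (one second).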
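`decode_bounds_to_alignment` in `utils/infer_utils.py` turns per-frame boundary scores into the frame-to-item map that `decode_note_sequence` consumes. It rounds the running sum of the boundaries and starts a new item whenever that rounded count increases. The first frame always opens item 1 because `-1` is prepended before taking the difference. Below is a single-sequence sketch without PyTorch. `bounds_to_alignment` is a hypothetical name, and batching is left out for brevity.

```python
from itertools import accumulate


def bounds_to_alignment(bounds):
    """Single-sequence restatement of decode_bounds_to_alignment (hypothetical helper)."""
    steps = [round(s) for s in accumulate(bounds)]  # rounded running boundary count
    frame2item, item, prev = [], 0, -1  # prev = -1 mirrors the prepended fill value
    for step in steps:
        if step > prev:  # rounded count increased: a new item starts at this frame
            item += 1
        frame2item.append(item)
        prev = step
    return frame2item


if __name__ == '__main__':
    print(bounds_to_alignment([0, 0, 0, 0, 1, 0, 1, 0, 0]))
```

For hard 0/1 boundaries, this reproduces the `[1, 1, 1, 1, 2, 2, 3, 3, 3]` example from the `decode_note_sequence` docstring. Fractional boundary scores are resolved by rounding the cumulative sum.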