├── .gitignore ├── LICENSE ├── README.md ├── setup.py ├── torchfcpe ├── __init__.py ├── assets │ └── fcpe_c_v001.pt ├── f02midi │ ├── MIDI.py │ ├── featureExtraction.py │ ├── quantization.py │ ├── transpose.py │ └── utils.py ├── mel_extractor.py ├── mel_fn_librosa.py ├── model_conformer_naive.py ├── model_convnext.py ├── models.py ├── models_infer.py ├── tools.py └── torch_interp.py └── train ├── configs └── config.yaml ├── data_loaders_wav.py ├── draw.py ├── pre_data.py ├── redis_coder.py ├── savertools ├── __init__.py ├── saver.py └── utils.py ├── solver_wav.py ├── train_wav.py ├── utils_1.py └── utils_all.py /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | __pycache__/ 3 | torchfcpe.egg-info/ 4 | 5 | test** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 CN_ChiTu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

TorchFCPE

2 | 3 | ## Overview 4 | 5 | TorchFCPE (Fast Context-based Pitch Estimation) is a PyTorch-based library designed for audio pitch extraction and MIDI conversion. This README provides a quick guide on how to use the library for audio pitch inference and MIDI extraction. 6 | 7 | Note: the MIDI extractor of FCPE quantizes f0 into notes using signal-processing (non-neural-network) methods. 8 | 9 | Note: I won't be updating FCPE (or the benchmark) soon, but I will release a version with cleaned-up code no later than next year. 10 | 11 | ## Installation 12 | 13 | Before using the library, make sure you have the necessary dependencies installed: 14 | 15 | ```bash 16 | pip install torchfcpe 17 | ``` 18 | 19 | ## Usage 20 | 21 | ### 1. Audio Pitch Inference 22 | 23 | ```python 24 | from torchfcpe import spawn_bundled_infer_model 25 | import torch 26 | import librosa 27 | 28 | # Configure device and target hop size 29 | device = 'cpu' # or 'cuda' if using a GPU 30 | sr = 16000 # Sample rate 31 | hop_size = 160 # Hop size for processing 32 | 33 | # Load and preprocess audio 34 | audio, sr = librosa.load('test.wav', sr=sr) 35 | audio = librosa.to_mono(audio) 36 | audio_length = len(audio) 37 | f0_target_length = (audio_length // hop_size) + 1 38 | audio = torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(-1).to(device) 39 | 40 | # Load the model 41 | model = spawn_bundled_infer_model(device=device) 42 | 43 | # Perform pitch inference 44 | f0 = model.infer( 45 | audio, 46 | sr=sr, 47 | decoder_mode='local_argmax', # Recommended mode 48 | threshold=0.006, # Threshold for V/UV decision 49 | f0_min=80, # Minimum pitch 50 | f0_max=880, # Maximum pitch 51 | interp_uv=False, # Interpolate unvoiced frames 52 | output_interp_target_length=f0_target_length, # Interpolate to target length 53 | ) 54 | 55 | print(f0) 56 | ``` 57 | 58 | ### 2.
MIDI Extraction 59 | 60 | ```python 61 | # Extract MIDI from audio 62 | midi = model.extact_midi( 63 | audio, 64 | sr=sr, 65 | decoder_mode='local_argmax', # Recommended mode 66 | threshold=0.006, # Threshold for V/UV decision 67 | f0_min=80, # Minimum pitch 68 | f0_max=880, # Maximum pitch 69 | output_path="test.mid", # Save MIDI to file 70 | ) 71 | 72 | print(midi) 73 | ``` 74 | 75 | ### Notes 76 | 77 | - **Inference Parameters:** 78 | 79 | - `audio`: Input audio as a `torch.Tensor`. 80 | - `sr`: Sample rate of the audio. 81 | - `decoder_mode` (Optional): Mode for decoding; 'local_argmax' is recommended. 82 | - `threshold` (Optional): Threshold for the voiced/unvoiced decision; default is 0.006. 83 | - `f0_min` (Optional): Minimum pitch value; default is 80 Hz. 84 | - `f0_max` (Optional): Maximum pitch value; default is 880 Hz. 85 | - `interp_uv` (Optional): Whether to interpolate unvoiced frames; default is False. 86 | - `output_interp_target_length` (Optional): Length to which the output pitch should be interpolated. 87 | 88 | - **MIDI Extraction Parameters:** 89 | - `audio`: Input audio as a `torch.Tensor`. 90 | - `sr`: Sample rate of the audio. 91 | - `decoder_mode` (Optional): Mode for decoding; 'local_argmax' is recommended. 92 | - `threshold` (Optional): Threshold for the voiced/unvoiced decision; default is 0.006. 93 | - `f0_min` (Optional): Minimum pitch value; default is 80 Hz. 94 | - `f0_max` (Optional): Maximum pitch value; default is 880 Hz. 95 | - `output_path` (Optional): File path to save the MIDI file to. If not provided, the MIDI structure is only returned, not saved. 96 | - `tempo` (Optional): BPM for the MIDI file. If None, the BPM is estimated automatically. 97 | 98 | ## Additional Features 99 | 100 | - **Model as a PyTorch Module:** 101 | You can use the model as a standard PyTorch module.
For example: 102 | 103 | ```python 104 | # Change device 105 | model = model.to(device) 106 | 107 | # Compile model 108 | model = torch.compile(model) 109 | ``` 110 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | 4 | with open('README.md', encoding='utf8') as file: 5 | long_description = file.read() 6 | 7 | 8 | setup( 9 | name='torchfcpe', 10 | description='The official Pytorch implementation of Fast Context-based Pitch Estimation (FCPE)', 11 | version='0.0.4', 12 | author='CNChTu', 13 | author_email='2921046558@qq.com', 14 | url='https://github.com/CNChTu/FCPE', 15 | install_requires=['einops', 'local_attention', 'torch', 'torchaudio', 'numpy', 'scipy', 'librosa', 'pydub', 'pretty_midi'], 16 | packages=['torchfcpe'], 17 | package_data={'torchfcpe': ['assets/*']}, 18 | long_description=long_description, 19 | long_description_content_type='text/markdown', 20 | keywords=['pitch', 'audio', 'speech', 'music', 'pytorch', 'fcpe'], 21 | classifiers=['License :: OSI Approved :: MIT License'], 22 | license='MIT') 23 | -------------------------------------------------------------------------------- /torchfcpe/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import ( 2 | spawn_wav2mel, 3 | ) 4 | from .models_infer import ( 5 | spawn_model, 6 | spawn_infer_model_from_pt, 7 | spawn_infer_model_from_onnx, 8 | spawn_bundled_infer_model, 9 | bundled_infer_model_unit_test 10 | ) 11 | -------------------------------------------------------------------------------- /torchfcpe/assets/fcpe_c_v001.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CNChTu/FCPE/d55c1b636dc3564dc35be19675998688931e870c/torchfcpe/assets/fcpe_c_v001.pt 
# --------------------------------------------------------------------------------
# /torchfcpe/f02midi/MIDI.py
# --------------------------------------------------------------------------------
#%%
import pretty_midi
import numpy as np
import librosa.display


#%%
def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """Plot the piano roll of a MIDI object with ``librosa.display.specshow``.

    Parameters:
        pm: ``pretty_midi.PrettyMIDI`` object to draw (e.g. a transcription of
            an RWC, MDB, iKala or DSD100 track).
        start_pitch / end_pitch: lowest / highest MIDI note number to show;
            used as row-slice bounds of the piano roll, so they must be ints.
        fs: piano-roll sampling frequency in frames per second (int).
    """
    librosa.display.specshow(
        pm.get_piano_roll(fs)[start_pitch:end_pitch],
        hop_length=1,
        sr=fs,
        x_axis="time",
        y_axis="cqt_note",
        fmin=pretty_midi.note_number_to_hz(start_pitch),
    )


def midi_to_note(file_name, pitch_shift, fs=100, start_note=40, end_note=95):
    """Convert a .mid file into a frame-level note sequence.

    For every piano-roll frame, the loudest (highest-velocity) note inside
    ``[start_note, end_note)`` is kept. Frames where no note in that range is
    sounding are 0.

    Parameters:
        file_name: path of the '.mid' file (str).
        pitch_shift: semitone offset added to every voiced frame (int).
        fs: piano-roll sampling frequency in frames per second (int). The
            default 100 gives one value per 10 ms.
        start_note / end_note: lowest (inclusive) / highest (exclusive) MIDI
            note number considered (int).

    Returns:
        notes: float array with one MIDI note number per frame (0 = silence).
    """
    pm = pretty_midi.PrettyMIDI(file_name)
    frame_note = pm.get_piano_roll(fs)[start_note:end_note]
    notes = np.zeros(frame_note.shape[1])
    if frame_note.shape[0] == 0:
        # Empty pitch range: no frame can be voiced.
        return notes

    # Decide voicing from the frame content itself. Testing ``argmax > 0``
    # cannot tell "silent frame" apart from "loudest note is start_note",
    # because both give argmax 0, so frames at the lowest pitch were dropped.
    voiced = np.any(frame_note > 0, axis=0)
    loudest = np.argmax(frame_note, axis=0)
    notes[voiced] = (loudest[voiced] + start_note) + pitch_shift
    return notes


def midi_to_segment(filename):
    """Read the notes of the first instrument of a .mid file.

    Parameters:
        filename: path of the '.mid' file (str).

    Returns:
        segments: list of ``[start (s), end (s), pitch]`` lists, in the order
            the notes are stored in the file.
    """
    pm = pretty_midi.PrettyMIDI(filename)
    return [[note.start, note.end, note.pitch] for note in pm.instruments[0].notes]


def segment_to_midi(segments, path_output, tempo=120):
    """Write note segments to a single-track (Acoustic Grand Piano) .mid file.

    Parameters:
        segments: iterable of ``(start (s), end (s), pitch, ...)`` items. Only
            the first three fields are used.
        path_output: destination path (str or path-like).
        tempo: initial tempo in BPM, truncated to an int. May be a scalar or a
            size-1 array, such as librosa's tempo estimate.
    """
    # Reduce a size-1 array to a scalar before int(). Calling int() directly
    # on an ndim > 0 array is deprecated in NumPy.
    pm = pretty_midi.PrettyMIDI(initial_tempo=int(np.asarray(tempo).item()))
    inst_program = pretty_midi.instrument_name_to_program("Acoustic Grand Piano")
    inst = pretty_midi.Instrument(program=inst_program)
    for segment in segments:
        inst.notes.append(
            pretty_midi.Note(
                velocity=100,
                start=segment[0],
                end=segment[1],
                pitch=np.int32(segment[2]),
            )
        )
    pm.instruments.append(inst)
    pm.write(f"{path_output}")


def note_to_segment(note, frame_duration=0.01):
    """Convert a frame-level note sequence into note segments.

    Consecutive frames holding the same positive note number are merged into
    one segment. Frames with a value <= 0 are treated as silence.

    Parameters:
        note: sequence of MIDI note numbers, one per frame (0 = silence).
        frame_duration: length of one frame in seconds. The default 0.01
            matches 10 ms frames (100 frames per second).

    Returns:
        segments: list of ``(start (s), end (s), pitch)`` tuples. ``end`` is
            the time at which the segment's last frame finishes, so a
            one-frame note lasts ``frame_duration`` and back-to-back notes
            are contiguous. Empty input gives an empty list.
    """
    segments = []
    n_frames = len(note)
    run_start = 0
    for i in range(1, n_frames + 1):
        # A run ends at a value change or at the end of the sequence.
        # Closing the run at the end is what keeps the final note when the
        # sequence is still voiced on its last frame.
        if i < n_frames and note[i] == note[run_start]:
            continue
        if note[run_start] > 0:
            segments.append(
                (
                    frame_duration * run_start,
                    frame_duration * i,
                    np.int32(note[run_start]),
                )
            )
        run_start = i
    return segments


def note2Midi(frame_level_pitchscroe, path_output, tempo):
    """Write a frame-level (10 ms per frame) note sequence to a .mid file.

    Parameters:
        frame_level_pitchscroe: MIDI note number per frame (0 = silence).
        path_output: destination .mid path.
        tempo: tempo in BPM, forwarded to ``segment_to_midi``.
    """
    segment = note_to_segment(frame_level_pitchscroe)
    segment_to_midi(segment, path_output=path_output, tempo=tempo)


# --------------------------------------------------------------------------------
# /torchfcpe/f02midi/featureExtraction.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import librosa
from pydub import AudioSegment
import pathlib

# from pydub.playback import play
import numpy as np
import os

PATH_PROJECT = os.path.dirname(os.path.realpath(__file__))


def read_audio(filepath, sr=None):
    """Load an audio file as a mono float32 array using pydub.

    Parameters:
        filepath: path of the audio file (str or path-like). Files whose
            suffix is exactly '.mp3' are opened with ``AudioSegment.from_mp3``.
            Everything else is opened with ``AudioSegment.from_file``.
        sr: target sample rate in Hz. ``None`` keeps the file's own rate.

    Returns:
        (y, sr): ``y`` holds the raw integer PCM sample values cast to
        float32. They are not rescaled to [-1, 1]. ``sr`` is the sample rate
        of ``y``.
    """
    extension = pathlib.Path(filepath).suffix.replace(".", "")
    if extension == "mp3":
        sound = AudioSegment.from_mp3(filepath)
    else:
        sound = AudioSegment.from_file(filepath)
    sound = sound.set_channels(1)
    if sr is None:
        sr = sound.frame_rate
    sound = sound.set_frame_rate(sr)
    samples = sound.get_array_of_samples()
    y = np.array(samples).astype(np.float32)
    return y, sr


def spec_extraction(file_name, win_size):
    """Build standardized, fixed-length spectrogram windows from an audio file.

    Processing steps:
    - Load the file at 8 kHz.
    - Take a 1024-point STFT with a hop of 80 samples (10 ms).
    - Convert magnitudes with ``librosa.power_to_db`` relative to the maximum.
    - Zero-pad the frames to a multiple of ``win_size`` and cut them into
      windows.

    Parameters:
        file_name: path of the audio file.
        win_size: number of frames per window (int).

    Returns:
        x_test: array of shape (n_windows, win_size, n_bins, 1), standardized
            with ``<package>/data/x_train_mean.npy`` and ``x_train_std.npy``.
        x_spec: padded dB spectrogram of shape (n_bins, n_frames).

    NOTE(review): the visible repository tree has no ``torchfcpe/data``
    directory, so this function looks unused/unusable as shipped. Confirm.
    """
    y, _ = read_audio(file_name, sr=8000)

    S = librosa.core.stft(y, n_fft=1024, hop_length=80, win_length=1024)
    x_spec = np.abs(S)
    x_spec = librosa.core.power_to_db(x_spec, ref=np.max)
    x_spec = x_spec.astype(np.float32)
    num_frames = x_spec.shape[1]

    # Pad the time axis up to a multiple of win_size. The padding uses the
    # spectrogram's own dtype so the result stays float32.
    # NOTE(review): with ref=np.max the loudest bin is 0 dB, so zero padding
    # equals the reference level, not silence. Confirm this is intended.
    remainder = num_frames % win_size
    if remainder != 0:
        len_pad = win_size - remainder
        padding_feature = np.zeros((x_spec.shape[0], len_pad), dtype=x_spec.dtype)
        x_spec = np.concatenate((x_spec, padding_feature), axis=1)
        num_frames += len_pad

    # Windows of shape (win_size, n_bins).
    x_test = np.array(
        [x_spec[:, j : j + win_size].T for j in range(0, num_frames, win_size)]
    )

    # Standardize with the training statistics.
    path_project = pathlib.Path(__file__).parent.parent
    x_train_mean = np.load(f"{path_project}/data/x_train_mean.npy")
    x_train_std = np.load(f"{path_project}/data/x_train_std.npy")
    x_test = (x_test - x_train_mean) / (x_train_std + 0.0001)
    x_test = x_test[:, :, :, np.newaxis]
    return x_test, x_spec


# --------------------------------------------------------------------------------
# /torchfcpe/f02midi/quantization.py
# --------------------------------------------------------------------------------
# %%
import numpy as np
import librosa
import librosa.display

from scipy.signal import medfilt
from matplotlib import pyplot as plt
from .featureExtraction import read_audio
from .utils import *


# %%
def calc_tempo(path_audio):
    """Estimate the tempo of an audio file.

    The file is loaded at 22050 Hz, and its onset-strength envelope is passed
    to ``librosa.beat.tempo``.

    Parameters:
        path_audio: path of the audio file (str).

    Returns:
        tempo: tempo estimate in BPM, exactly as returned by librosa (by
            default a length-1 array).
    """
    target_sr = 22050
    y, _ = read_audio(path_audio, sr=target_sr)
    # librosa >= 0.10 makes every onset_strength argument keyword-only, so
    # the signal must be passed as ``y=``.
    onset_strength = librosa.onset.onset_strength(y=y, sr=target_sr)
    tempo = librosa.beat.tempo(onset_envelope=onset_strength, sr=target_sr)
    return tempo


def one_beat_frame_size(tempo, frame_rate=100):
    """Return the number of analysis frames spanned by one beat.

    Parameters:
        tempo: tempo in BPM. May be a scalar or a size-1 array (e.g. the value
            returned by ``calc_tempo``).
        frame_rate: analysis frames per second. The default 100 matches the
            10 ms frames used throughout f02midi.

    Returns:
        frames: ``round(60 / tempo * frame_rate)`` as an ``np.int32`` scalar.

    Raises:
        ValueError: if ``tempo`` holds more than one value.
    """
    # Reduce size-1 arrays to a Python float so every quantity derived from
    # this value (median-filter sizes, minimum note lengths) stays scalar.
    tempo = float(np.asarray(tempo).item())
    return np.int32(np.round(60 / tempo * frame_rate))


def median_filter_pitch(pitch, medfilt_size, weight):
    """Smooth a pitch/note track with a median filter.

    The kernel length is ``int(medfilt_size * weight)``. It is bumped to the
    next odd number because ``scipy.signal.medfilt`` needs an odd kernel.

    Parameters:
        pitch: frame-level pitch or note track (array).
        medfilt_size: base kernel size in frames, e.g. one beat.
        weight: fraction of ``medfilt_size`` to use (float).

    Returns:
        The filtered track, rounded to whole numbers.
    """
    kernel = np.int32(medfilt_size * weight)
    if kernel % 2 == 0:
        kernel += 1
    return np.round(medfilt(pitch, kernel))


def clean_note_frames(note, min_note_len=5):
    """Zero every run of identical frames shorter than ``min_note_len``.

    Silent runs are unaffected because they are already 0. The final run is
    checked like all the others.

    Parameters:
        note: frame-level note sequence (array).
        min_note_len: minimum run length in frames (int).

    Returns:
        output: cleaned copy of ``note``.
    """
    output = np.copy(note)
    n_frames = len(note)
    run_start = 0
    for i in range(1, n_frames + 1):
        if i < n_frames and note[i] == note[run_start]:
            continue
        if i - run_start < min_note_len:
            output[run_start:i] = 0
        run_start = i
    return output


def makeSegments(note):
    """Locate the voiced segments of a frame-level note sequence.

    A segment starts where the sequence leaves silence or switches to a
    different note. It ends on the frame before the next change.

    Parameters:
        note: frame-level note sequence (array), 0 = silence.

    Returns:
        (startSeg, endSeg): lists of first / last frame indices (inclusive).
        If the sequence is still voiced at its last frame, that final
        segment has a start but no end, so ``startSeg`` can be one element
        longer than ``endSeg``. Empty input gives two empty lists.
    """
    startSeg = []
    endSeg = []
    if len(note) == 0:
        return startSeg, endSeg

    voiced = note[0] > 0
    if voiced:
        startSeg.append(0)
    for i in range(len(note) - 1):
        if note[i] == note[i + 1]:
            continue
        if not voiced:
            startSeg.append(i + 1)
            voiced = True
        elif note[i + 1] == 0:
            endSeg.append(i)
            voiced = False
        else:
            endSeg.append(i)
            startSeg.append(i + 1)
    return startSeg, endSeg


def remove_short_segment(idx, note_cleaned, start, end, minLength):
    """Silence segment ``idx`` if it is a short, isolated blip.

    The segment is zeroed when all three hold:
    - it spans fewer than ``minLength`` frames;
    - the distance from its end to the next segment's start exceeds
      ``minLength``;
    - the distance from the previous segment's end to its start exceeds
      ``minLength``.

    Parameters:
        idx: segment index. Needs ``1 <= idx`` and ``idx + 1 < len(start)``.
        note_cleaned: note sequence, modified in place (array).
        start / end: segment boundaries as returned by ``makeSegments``.
        minLength: length threshold in frames (int).

    Returns:
        note_cleaned: the same (possibly modified) sequence.
    """
    len_seg = end[idx] - start[idx]
    if len_seg < minLength:
        gap_after = start[idx + 1] - end[idx]
        gap_before = start[idx] - end[idx - 1]
        if gap_after > minLength and gap_before > minLength:
            note_cleaned[start[idx] : end[idx] + 1] = [0] * (len_seg + 1)
    return note_cleaned


def remove_octave_error(idx, note_cleaned, start, end):
    """Replace an octave-jump segment with its surrounding note.

    If the segments on both sides of ``idx`` share one note, and segment
    ``idx`` differs from it by a whole number of octaves, then segment
    ``idx`` (plus the frame just before it) is overwritten with that note.

    Parameters:
        idx: segment index. Needs ``1 <= idx`` and ``idx + 1 < len(start)``.
        note_cleaned: note sequence, modified in place (array).
        start / end: segment boundaries as returned by ``makeSegments``.

    Returns:
        note_cleaned: the same (possibly modified) sequence.
    """
    current = note_cleaned[start[idx]]
    if current == 0:
        # The segment was already silenced (e.g. by remove_short_segment).
        # Without this guard, |0 - neighbour| % 12 == 0 holds whenever the
        # neighbour is a multiple of 12 (any C), which refilled the blip.
        return note_cleaned

    len_seg = end[idx] - start[idx]
    prev_note = note_cleaned[start[idx - 1]]
    next_note = note_cleaned[start[idx + 1]]
    if prev_note == next_note and current != next_note:
        if np.abs(current - next_note) % 12 == 0:
            note_cleaned[start[idx] - 1 : end[idx] + 1] = [next_note] * (len_seg + 2)
    return note_cleaned


def clean_segment(note, minLength):
    """Clean a note sequence segment by segment.

    Each inner segment (not the first or the last) is passed through
    ``remove_short_segment`` and then ``remove_octave_error``.

    Parameters:
        note: frame-level note sequence (array).
        minLength: length threshold in frames (int).

    Returns:
        note_cleaned: cleaned copy of ``note``.
    """
    note_cleaned = np.copy(note)
    start, end = makeSegments(note_cleaned)

    for i in range(1, len(start) - 1):
        note_cleaned = remove_short_segment(i, note_cleaned, start, end, minLength)
        note_cleaned = remove_octave_error(i, note_cleaned, start, end)
    return note_cleaned


def refine_note(est_note, tempo):
    """Quantize a raw frame-level note track into stable notes.

    Processing steps:
    - Apply three cascaded median filters with kernels of 1/6, 1/3 and 1/2
      of a beat.
    - Take voicing from the first filtered track and pitch from the third.
    - Remove runs and isolated segments shorter than a quarter beat, and fix
      octave jumps.

    Parameters:
        est_note: frame-level MIDI note numbers, 0 = unvoiced (array).
        tempo: tempo in BPM. May be a scalar or a size-1 array.

    Returns:
        est_pitch_mf3_v: refined frame-level note sequence (array).
    """
    one_beat_size = one_beat_frame_size(tempo)
    est_note_mf1 = median_filter_pitch(est_note, one_beat_size, 1 / 6)
    est_note_mf2 = median_filter_pitch(est_note_mf1, one_beat_size, 1 / 3)
    est_note_mf3 = median_filter_pitch(est_note_mf2, one_beat_size, 1 / 2)

    voicing = est_note_mf1 > 0
    est_pitch_mf3_v = voicing * est_note_mf3
    quarter_beat = int(one_beat_size * 1 / 4)
    est_pitch_mf3_v = clean_note_frames(est_pitch_mf3_v, quarter_beat)
    est_pitch_mf3_v = clean_segment(est_pitch_mf3_v, quarter_beat)
    return est_pitch_mf3_v


# --------------------------------------------------------------------------------
# /torchfcpe/f02midi/transpose.py
# --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# %%
import argparse
import numpy as np
from pathlib import Path
from .featureExtraction import *
from .quantization import *
from .utils import *
from .MIDI import *
import librosa


def f0_to_note(f0):
    """Convert frame-level f0 in Hz to integer MIDI note numbers.

    Uses ``69 + 12 * log2(f0 / 440 + 1e-4)``. The small offset keeps 0 Hz
    finite; it maps far below note 0. Results are rounded and clipped to
    [0, 127], so unvoiced (0 Hz) frames become note 0.

    Parameters:
        f0: f0 values in Hz (array-like, e.g. ndarray or list).

    Returns:
        note: int array of MIDI note numbers with the same shape as ``f0``.
    """
    f0 = np.asarray(f0)
    note = 69 + 12 * np.log2(f0 / 440 + 1e-4)
    note = np.round(note).astype(int)
    return np.clip(note, 0, 127)


def f02midi(f0, tempo=None, y=None, sr=None, output_path=None):
    """Quantize a frame-level f0 track into MIDI note segments.

    Parameters:
        f0: f0 in Hz, shape (n_frames,), one value per 10 ms frame. Values
            that map below MIDI note 0 (e.g. 0 Hz) become silence.
        tempo: tempo in BPM. If None, it is estimated from ``y`` (resampled to
            22050 Hz) with librosa. Without ``y`` it defaults to 120.
        y: audio samples, used only for tempo estimation.
        sr: sample rate of ``y``. Required when ``y`` is used.
        output_path: if given, the segments are also written to this .mid
            path.

    Returns:
        segment: list of ``(start (s), end (s), pitch)`` tuples.
    """
    if tempo is None:
        if y is not None:
            target_sr = 22050
            y = librosa.resample(y=y, orig_sr=sr, target_sr=target_sr)
            onset_strength = librosa.onset.onset_strength(y=y, sr=target_sr)
            tempo = librosa.beat.tempo(onset_envelope=onset_strength, sr=target_sr)
        else:
            tempo = 120
    # librosa returns its estimate as a length-1 array. Reduce it to a scalar
    # so the frame arithmetic and int(tempo) downstream stay scalar.
    tempo = float(np.asarray(tempo).item())

    note = f0_to_note(f0)
    refined_fl_note = refine_note(note, tempo)  # frame-level note track

    # Merge frames into note-level segments on the time axis.
    segment = note_to_segment(refined_fl_note)
    if output_path is not None:
        segment_to_midi(segment, path_output=output_path, tempo=tempo)
    return segment
# --------------------------------------------------------------------------------
/torchfcpe/f02midi/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pydub import AudioSegment 4 | import pathlib 5 | 6 | 7 | def check_and_make_dir(path_dir): 8 | if not os.path.exists(os.path.dirname(path_dir)): 9 | os.makedirs(os.path.dirname(path_dir)) 10 | 11 | 12 | def get_filename_wo_extension(path_dir): 13 | return pathlib.Path(path_dir).stem 14 | 15 | 16 | def note2pitch(pitch): 17 | """ Convert MIDI number to freq. 18 | ---------- 19 | Parameters: 20 | pitch: MIDI note numbers of pitch (array) 21 | 22 | ---------- 23 | Returns: 24 | pitch: freqeuncy of pitch (array) 25 | """ 26 | 27 | pitch = np.array(pitch) 28 | pitch[pitch > 0] = 2 ** ((pitch[pitch > 0] - 69) / 12.0) * 440 29 | return pitch 30 | 31 | 32 | def pitch2note(pitch): 33 | """ Convert freq to MIDI number 34 | ---------- 35 | Parameters: 36 | pitch: freqeuncy of pitch (array) 37 | 38 | ---------- 39 | Returns: 40 | pitch: MIDI note numbers of pitch (array) 41 | """ 42 | pitch = np.array(pitch) 43 | pitch[pitch > 0] = np.round((69.0 + 12.0 * np.log2(pitch[pitch > 0] / 440.0))) 44 | return pitch 45 | 46 | 47 | a = np.array([0, 0, 0, 1, 2, 3, 5, 0, 0, 0, 1, 2, 4, 5]) 48 | b = a[a > 0] * 2 49 | print(b) 50 | -------------------------------------------------------------------------------- /torchfcpe/mel_extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torchaudio.transforms import Resample 6 | 7 | import os 8 | 9 | os.environ["LRU_CACHE_CAPACITY"] = "3" 10 | 11 | try: 12 | from librosa.filters import mel as librosa_mel_fn 13 | except ImportError: 14 | print(' [INF0] torchfcpe.mel_tools.nv_mel_extractor: Librosa not found,' 15 | ' use torchfcpe.mel_tools.mel_fn_librosa instead.') 16 | from .mel_fn_librosa import mel as librosa_mel_fn 17 | 18 | 19 | def 
dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | return torch.log(torch.clamp(x, min=clip_val) * C) 21 | 22 | 23 | class HannWindow(torch.nn.Module): 24 | def __init__(self, win_size): 25 | super().__init__() 26 | self.register_buffer('window', torch.hann_window(win_size), persistent=False) 27 | 28 | def forward(self): 29 | return self.window 30 | 31 | 32 | class MelModule(torch.nn.Module): 33 | """Mel extractor 34 | 35 | Args: 36 | sr (int): Sampling rate. Defaults to 16000. 37 | n_mels (int): Number of mel bins. Defaults to 128. 38 | n_fft (int): FFT size. Defaults to 1024. 39 | win_size (int): Window size. Defaults to 1024. 40 | hop_length (int): Hop length. Defaults to 160. 41 | fmin (float, optional): Minimum frequency. Defaults to 0. 42 | fmax (float, optional): Maximum frequency. Defaults to sr/2. 43 | clip_val (float, optional): Clipping value. Defaults to 1e-5. 44 | """ 45 | 46 | def __init__(self, 47 | sr: [int, float], 48 | n_mels: int, 49 | n_fft: int, 50 | win_size: int, 51 | hop_length: int, 52 | fmin: float = None, 53 | fmax: float = None, 54 | clip_val: float = 1e-5, 55 | out_stft: bool = False, 56 | ): 57 | super().__init__() 58 | if fmin is None: 59 | fmin = 0 60 | if fmax is None: 61 | fmax = sr / 2 62 | self.target_sr = sr 63 | self.n_mels = n_mels 64 | self.n_fft = n_fft 65 | self.win_size = win_size 66 | self.hop_length = hop_length 67 | self.fmin = fmin 68 | self.fmax = fmax 69 | self.clip_val = clip_val 70 | # self.mel_basis = {} 71 | self.register_buffer( 72 | 'mel_basis', 73 | torch.tensor(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)).float(), 74 | persistent=False 75 | ) 76 | self.hann_window = torch.nn.ModuleDict() 77 | self.out_stft = out_stft 78 | 79 | @torch.no_grad() 80 | def __call__(self, 81 | y: torch.Tensor, # (B, T, 1) 82 | key_shift: [int, float] = 0, 83 | speed: [int, float] = 1, 84 | center: bool = False, 85 | no_cache_window: bool = False 86 | ) -> torch.Tensor: # (B, T, n_mels) 87 | 
"""Get mel spectrogram 88 | 89 | Args: 90 | y (torch.Tensor): Input waveform, shape=(B, T, 1). 91 | key_shift (int, optional): Key shift. Defaults to 0. 92 | speed (int, optional): Variable speed enhancement factor. Defaults to 1. 93 | center (bool, optional): center for torch.stft. Defaults to False. 94 | no_cache_window (bool, optional): If True will clear cache. Defaults to False. 95 | return: 96 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels). 97 | """ 98 | 99 | n_fft = self.n_fft 100 | win_size = self.win_size 101 | hop_length = self.hop_length 102 | clip_val = self.clip_val 103 | 104 | factor = 2 ** (key_shift / 12) 105 | n_fft_new = int(np.round(n_fft * factor)) 106 | win_size_new = int(np.round(win_size * factor)) 107 | hop_length_new = int(np.round(hop_length * speed)) 108 | 109 | y = y.squeeze(-1) 110 | 111 | if torch.min(y) < -1.: 112 | print('[error with torchfcpe.mel_extractor.MelModule]min value is ', torch.min(y)) 113 | if torch.max(y) > 1.: 114 | print('[error with torchfcpe.mel_extractor.MelModule]max value is ', torch.max(y)) 115 | 116 | key_shift_key = str(key_shift) 117 | if not no_cache_window: 118 | if key_shift_key in self.hann_window: 119 | hann_window = self.hann_window[key_shift_key] 120 | else: 121 | hann_window = HannWindow(win_size_new).to(self.mel_basis.device) 122 | self.hann_window[key_shift_key] = hann_window 123 | hann_window_tensor = hann_window() 124 | else: 125 | hann_window_tensor = torch.hann_window(win_size_new).to(self.mel_basis.device) 126 | 127 | pad_left = (win_size_new - hop_length_new) // 2 128 | pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left) 129 | if pad_right < y.size(-1): 130 | mode = 'reflect' 131 | else: 132 | mode = 'constant' 133 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode) 134 | y = y.squeeze(1) 135 | 136 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, 137 | 
window=hann_window_tensor, 138 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 139 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9) 140 | if key_shift != 0: 141 | size = n_fft // 2 + 1 142 | resize = spec.size(1) 143 | if resize < size: 144 | spec = F.pad(spec, (0, 0, 0, size - resize)) 145 | spec = spec[:, :size, :] * win_size / win_size_new 146 | if self.out_stft: 147 | spec = spec[:, :512, :] 148 | else: 149 | spec = torch.matmul(self.mel_basis, spec) 150 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 151 | spec = spec.transpose(-1, -2) 152 | return spec # (B, T, n_mels) 153 | 154 | 155 | class Wav2MelModule(torch.nn.Module): 156 | """ 157 | Wav to mel converter 158 | NOTE: This class of code is reserved for training only, please use Wav2MelModule for inference 159 | 160 | Args: 161 | sr (int): Sampling rate. Defaults to 16000. 162 | n_mels (int): Number of mel bins. Defaults to 128. 163 | n_fft (int): FFT size. Defaults to 1024. 164 | win_size (int): Window size. Defaults to 1024. 165 | hop_length (int): Hop length. Defaults to 160. 166 | fmin (float, optional): Minimum frequency. Defaults to 0. 167 | fmax (float, optional): Maximum frequency. Defaults to sr/2. 168 | clip_val (float, optional): Clipping value. Defaults to 1e-5. 169 | device (str, optional): Device. Defaults to 'cpu'. 
170 | """ 171 | 172 | def __init__(self, 173 | sr: [int, float], 174 | n_mels: int, 175 | n_fft: int, 176 | win_size: int, 177 | hop_length: int, 178 | fmin: float = None, 179 | fmax: float = None, 180 | clip_val: float = 1e-5, 181 | mel_type="default", 182 | ): 183 | super().__init__() 184 | # catch None 185 | if fmin is None: 186 | fmin = 0 187 | if fmax is None: 188 | fmax = sr / 2 189 | # init 190 | self.sampling_rate = sr 191 | self.n_mels = n_mels 192 | self.n_fft = n_fft 193 | self.win_size = win_size 194 | self.hop_size = hop_length 195 | self.fmin = fmin 196 | self.fmax = fmax 197 | self.clip_val = clip_val 198 | # self.device = device 199 | self.register_buffer( 200 | 'tensor_device_marker', 201 | torch.tensor(1.0).float(), 202 | persistent=False 203 | ) 204 | self.resample_kernel = torch.nn.ModuleDict() 205 | if mel_type == "default": 206 | self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, 207 | out_stft=False) 208 | elif mel_type == "stft": 209 | self.mel_extractor = MelModule(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, 210 | out_stft=True) 211 | self.mel_type = mel_type 212 | 213 | def device(self): 214 | """Get device""" 215 | return self.tensor_device_marker.device 216 | 217 | @torch.no_grad() 218 | def __call__(self, 219 | audio: torch.Tensor, # (B, T, 1) 220 | sample_rate: [int, float], 221 | keyshift: [int, float] = 0, 222 | no_cache_window: bool = False 223 | ) -> torch.Tensor: # (B, T, n_mels) 224 | """ 225 | Get mel spectrogram 226 | 227 | Args: 228 | audio (torch.Tensor): Input waveform, shape=(B, T, 1). 229 | sample_rate (int): Sampling rate. 230 | keyshift (int, optional): Key shift. Defaults to 0. 231 | no_cache_window (bool, optional): If True will clear cache. Defaults to False. 232 | return: 233 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels). 
234 | """ 235 | 236 | # resample 237 | if sample_rate == self.sampling_rate: 238 | audio_res = audio 239 | else: 240 | key_str = str(sample_rate) 241 | if key_str not in self.resample_kernel: 242 | if len(self.resample_kernel) > 8: 243 | self.resample_kernel.clear() 244 | self.resample_kernel[key_str] = Resample( 245 | sample_rate, 246 | self.sampling_rate, 247 | lowpass_filter_width=128 248 | ).to(self.tensor_device_marker.device) 249 | audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1) 250 | 251 | # extract 252 | mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window) 253 | n_frames = int(audio.shape[1] // self.hop_size) + 1 254 | if n_frames > int(mel.shape[1]): 255 | mel = torch.cat((mel, mel[:, -1:, :]), 1) 256 | if n_frames < int(mel.shape[1]): 257 | mel = mel[:, :n_frames, :] 258 | 259 | return mel # (B, T, n_mels) 260 | 261 | 262 | class MelExtractor: 263 | """Mel extractor 264 | NOTE: This class of code is reserved for training only, please use MelModule for inference 265 | 266 | Args: 267 | sr (int): Sampling rate. Defaults to 16000. 268 | n_mels (int): Number of mel bins. Defaults to 128. 269 | n_fft (int): FFT size. Defaults to 1024. 270 | win_size (int): Window size. Defaults to 1024. 271 | hop_length (int): Hop length. Defaults to 160. 272 | fmin (float, optional): Minimum frequency. Defaults to 0. 273 | fmax (float, optional): Maximum frequency. Defaults to sr/2. 274 | clip_val (float, optional): Clipping value. Defaults to 1e-5. 
275 | """ 276 | 277 | def __init__(self, 278 | sr: [int, float], 279 | n_mels: int, 280 | n_fft: int, 281 | win_size: int, 282 | hop_length: int, 283 | fmin: float = None, 284 | fmax: float = None, 285 | clip_val: float = 1e-5, 286 | out_stft: bool = False, 287 | ): 288 | if fmin is None: 289 | fmin = 0 290 | if fmax is None: 291 | fmax = sr / 2 292 | self.target_sr = sr 293 | self.n_mels = n_mels 294 | self.n_fft = n_fft 295 | self.win_size = win_size 296 | self.hop_length = hop_length 297 | self.fmin = fmin 298 | self.fmax = fmax 299 | self.clip_val = clip_val 300 | self.mel_basis = {} 301 | self.hann_window = {} 302 | self.out_stft = out_stft 303 | 304 | @torch.no_grad() 305 | def __call__(self, 306 | y: torch.Tensor, # (B, T, 1) 307 | key_shift: [int, float] = 0, 308 | speed: [int, float] = 1, 309 | center: bool = False, 310 | no_cache_window: bool = False 311 | ) -> torch.Tensor: # (B, T, n_mels) 312 | """Get mel spectrogram 313 | 314 | Args: 315 | y (torch.Tensor): Input waveform, shape=(B, T, 1). 316 | key_shift (int, optional): Key shift. Defaults to 0. 317 | speed (int, optional): Variable speed enhancement factor. Defaults to 1. 318 | center (bool, optional): center for torch.stft. Defaults to False. 319 | no_cache_window (bool, optional): If True will clear cache. Defaults to False. 320 | return: 321 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels). 
322 | """ 323 | 324 | sampling_rate = self.target_sr 325 | n_mels = self.n_mels 326 | n_fft = self.n_fft 327 | win_size = self.win_size 328 | hop_length = self.hop_length 329 | fmin = self.fmin 330 | fmax = self.fmax 331 | clip_val = self.clip_val 332 | 333 | factor = 2 ** (key_shift / 12) 334 | n_fft_new = int(np.round(n_fft * factor)) 335 | win_size_new = int(np.round(win_size * factor)) 336 | hop_length_new = int(np.round(hop_length * speed)) 337 | if not no_cache_window: 338 | mel_basis = self.mel_basis 339 | hann_window = self.hann_window 340 | else: 341 | mel_basis = {} 342 | hann_window = {} 343 | 344 | y = y.squeeze(-1) 345 | 346 | if torch.min(y) < -1.: 347 | print('min value is ', torch.min(y)) 348 | if torch.max(y) > 1.: 349 | print('max value is ', torch.max(y)) 350 | 351 | mel_basis_key = str(fmax) + '_' + str(y.device) 352 | if (mel_basis_key not in mel_basis) and (not self.out_stft): 353 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) 354 | mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device) 355 | 356 | key_shift_key = str(key_shift) + '_' + str(y.device) 357 | if key_shift_key not in hann_window: 358 | hann_window[key_shift_key] = torch.hann_window(win_size_new).to(y.device) 359 | 360 | pad_left = (win_size_new - hop_length_new) // 2 361 | pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left) 362 | if pad_right < y.size(-1): 363 | mode = 'reflect' 364 | else: 365 | mode = 'constant' 366 | y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode) 367 | y = y.squeeze(1) 368 | 369 | spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, 370 | window=hann_window[key_shift_key], 371 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True) 372 | spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-9) 373 | if key_shift != 0: 374 | size = n_fft // 2 + 1 375 | resize = 
spec.size(1) 376 | if resize < size: 377 | spec = F.pad(spec, (0, 0, 0, size - resize)) 378 | spec = spec[:, :size, :] * win_size / win_size_new 379 | if self.out_stft: 380 | spec = spec[:, :512, :] 381 | else: 382 | spec = torch.matmul(mel_basis[mel_basis_key], spec) 383 | spec = dynamic_range_compression_torch(spec, clip_val=clip_val) 384 | spec = spec.transpose(-1, -2) 385 | return spec # (B, T, n_mels) 386 | 387 | 388 | # init nv_mel_extractor cache 389 | # will remove this when we have a better solution 390 | # mel_extractor = MelExtractor(16000, 128, 1024, 1024, 160, 0, 8000) 391 | 392 | 393 | class Wav2Mel: 394 | """ 395 | Wav to mel converter 396 | NOTE: This class of code is reserved for training only, please use Wav2MelModule for inference 397 | 398 | Args: 399 | sr (int): Sampling rate. Defaults to 16000. 400 | n_mels (int): Number of mel bins. Defaults to 128. 401 | n_fft (int): FFT size. Defaults to 1024. 402 | win_size (int): Window size. Defaults to 1024. 403 | hop_length (int): Hop length. Defaults to 160. 404 | fmin (float, optional): Minimum frequency. Defaults to 0. 405 | fmax (float, optional): Maximum frequency. Defaults to sr/2. 406 | clip_val (float, optional): Clipping value. Defaults to 1e-5. 407 | device (str, optional): Device. Defaults to 'cpu'. 
408 | """ 409 | 410 | def __init__(self, 411 | sr: [int, float], 412 | n_mels: int, 413 | n_fft: int, 414 | win_size: int, 415 | hop_length: int, 416 | fmin: float = None, 417 | fmax: float = None, 418 | clip_val: float = 1e-5, 419 | device='cpu', 420 | mel_type="default", 421 | ): 422 | # catch None 423 | if fmin is None: 424 | fmin = 0 425 | if fmax is None: 426 | fmax = sr / 2 427 | # init 428 | self.sampling_rate = sr 429 | self.n_mels = n_mels 430 | self.n_fft = n_fft 431 | self.win_size = win_size 432 | self.hop_size = hop_length 433 | self.fmin = fmin 434 | self.fmax = fmax 435 | self.clip_val = clip_val 436 | self.device = device # NOTE(review): this instance attribute shadows the device() method below, so obj.device() raises TypeError 437 | self.resample_kernel = {} # Resample modules keyed by str(source sample rate); never evicted, unlike Wav2MelModule's cache 438 | if mel_type == "default": 439 | self.mel_extractor = MelExtractor(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, 440 | out_stft=False) 441 | elif mel_type == "stft": 442 | self.mel_extractor = MelExtractor(sr, n_mels, n_fft, win_size, hop_length, fmin, fmax, clip_val, 443 | out_stft=True) 444 | self.mel_type = mel_type # NOTE(review): any mel_type other than "default"/"stft" leaves self.mel_extractor unset, so __call__ fails with AttributeError 445 | 446 | def device(self): 447 | """Get device. NOTE(review): unreachable through instances - self.device assigned in __init__ shadows this method.""" 448 | return self.device 449 | 450 | @torch.no_grad() 451 | def __call__(self, 452 | audio: torch.Tensor, # (B, T, 1) 453 | sample_rate: [int, float], 454 | keyshift: [int, float] = 0, 455 | no_cache_window: bool = False 456 | ) -> torch.Tensor: # (B, T, n_mels) 457 | """ 458 | Get mel spectrogram 459 | 460 | Args: 461 | audio (torch.Tensor): Input waveform, shape=(B, T, 1). 462 | sample_rate (int): Sampling rate of `audio`; resampled to self.sampling_rate when different. 463 | keyshift (int, optional): Key shift. Defaults to 0. 464 | no_cache_window (bool, optional): If True, build the mel basis and window afresh for this call instead of using MelExtractor's cache (the cache itself is not cleared). Defaults to False. 465 | return: 466 | spec (torch.Tensor): Mel spectrogram, shape=(B, T, n_mels).
467 | """ 468 | 469 | # resample 470 | if sample_rate == self.sampling_rate: 471 | audio_res = audio 472 | else: 473 | key_str = str(sample_rate) 474 | if key_str not in self.resample_kernel: 475 | self.resample_kernel[key_str] = Resample( 476 | sample_rate, 477 | self.sampling_rate, 478 | lowpass_filter_width=128 479 | ).to(self.device) 480 | audio_res = self.resample_kernel[key_str](audio.squeeze(-1)).unsqueeze(-1) 481 | 482 | # extract 483 | mel = self.mel_extractor(audio_res, keyshift, no_cache_window=no_cache_window) 484 | n_frames = int(audio.shape[1] // self.hop_size) + 1 485 | if n_frames > int(mel.shape[1]): 486 | mel = torch.cat((mel, mel[:, -1:, :]), 1) 487 | if n_frames < int(mel.shape[1]): 488 | mel = mel[:, :n_frames, :] 489 | 490 | return mel # (B, T, n_mels) 491 | 492 | 493 | def unit_text(): 494 | """ 495 | Test unit for nv_mel_extractor.py 496 | Should be set path to your test audio file. 497 | Need matplotlib and librosa to plot. 498 | require: pip install matplotlib librosa 499 | """ 500 | import time 501 | 502 | try: 503 | import matplotlib.pyplot as plt 504 | import librosa 505 | import librosa.display 506 | except ImportError: 507 | print(' [UNIT_TEST] torchfcpe.mel_tools.nv_mel_extractor: Matplotlib or Librosa not found,' 508 | ' skip plotting.') 509 | exit(1) 510 | 511 | # spawn mel extractor and wav2mel 512 | mel_extractor_test = MelExtractor(16000, 128, 1024, 1024, 160, 0, 8000) 513 | wav2mel_test = Wav2Mel(16000, 128, 1024, 1024, 160, 0, 8000) 514 | 515 | # load audio 516 | audio_path = r'E:\AUFSe04BPyProgram\AUFSd04BPyProgram\ddsp-svc\20230308\diffusion-svc\samples\GJ2.wav' 517 | audio, sr = librosa.load(audio_path, sr=16000) 518 | audio = torch.from_numpy(audio).unsqueeze(0).unsqueeze(-1) 519 | audio = audio.to('cuda') 520 | print(' [UNIT_TEST] torchfcpe.mel_tools.mel_extractor: Audio shape: {}'.format(audio.shape)) 521 | 522 | # test mel extractor 523 | start_time = time.time() 524 | mel1 = mel_extractor_test(audio, 0, 1, False) 525 | 
print(' [UNIT_TEST] torchfcpe.mel_extractor: Mel extractor time cost: {:.3f}s'.format( 526 | time.time() - start_time)) 527 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Mel extractor output shape: {}'.format(mel1.shape)) 528 | 529 | # test wav2mel 530 | start_time = time.time() 531 | mel2 = wav2mel_test(audio, 16000, 0) 532 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2mel time cost: {:.3f}s'.format( 533 | time.time() - start_time)) 534 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2mel output shape: {}'.format(mel2.shape)) 535 | 536 | # test melModule 537 | mel_module = MelModule(16000, 128, 1024, 1024, 160, 0, 8000).to('cuda') 538 | mel3 = mel_module(audio, 0, 1, False).to('cuda') 539 | print(' [UNIT_TEST] torchfcpe.mel_extractor: MelModule output shape: {}'.format(mel3.shape)) 540 | 541 | # test Wav2MelModule 542 | wav2mel_module = Wav2MelModule(16000, 128, 1024, 1024, 160, 0, 8000).to('cuda') 543 | mel4 = wav2mel_module(audio, 16000, 0).to('cuda') 544 | print(' [UNIT_TEST] torchfcpe.mel_extractor: Wav2MelModule output shape: {}'.format(mel4.shape)) 545 | 546 | # plot 547 | plt.figure(figsize=(12, 4)) 548 | plt.subplot(1, 5, 1) 549 | librosa.display.waveshow(audio.squeeze().cpu().numpy(), sr=16000) 550 | plt.title('Audio') 551 | plt.subplot(1, 5, 2) 552 | librosa.display.specshow(mel1.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel') 553 | plt.title('Mel extractor') 554 | plt.subplot(1, 5, 3) 555 | librosa.display.specshow(mel2.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel') 556 | plt.title('Wav2mel') 557 | 558 | plt.subplot(1, 5, 4) 559 | librosa.display.specshow(mel3.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel') 560 | plt.title('MelModule') 561 | plt.subplot(1, 5, 5) 562 | librosa.display.specshow(mel4.squeeze().cpu().numpy().T, sr=16000, hop_length=160, x_axis='time', y_axis='mel') 563 | plt.title('Wav2MelModule') 564 | 565 | plt.tight_layout() 566 
| plt.show() 567 | 568 | 569 | if __name__ == '__main__': 570 | unit_text() 571 | -------------------------------------------------------------------------------- /torchfcpe/mel_fn_librosa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """ 4 | from librosa.filters 5 | """ 6 | 7 | 8 | def mel( 9 | *, 10 | sr, 11 | n_fft, 12 | n_mels=128, 13 | fmin=0.0, 14 | fmax=None, 15 | htk=False, 16 | norm="slaney", 17 | dtype=np.float32, 18 | ): 19 | """Create a Mel filter-bank. 20 | 21 | This produces a linear transformation matrix to project 22 | FFT bins onto Mel-frequency bins. 23 | 24 | Parameters 25 | ---------- 26 | sr : number > 0 [scalar] 27 | sampling rate of the incoming signal 28 | 29 | n_fft : int > 0 [scalar] 30 | number of FFT components 31 | 32 | n_mels : int > 0 [scalar] 33 | number of Mel bands to generate 34 | 35 | fmin : float >= 0 [scalar] 36 | lowest frequency (in Hz) 37 | 38 | fmax : float >= 0 [scalar] 39 | highest frequency (in Hz). 40 | If `None`, use ``fmax = sr / 2.0`` 41 | 42 | htk : bool [scalar] 43 | use HTK formula instead of Slaney 44 | 45 | norm : {None, 'slaney', or number} [scalar] 46 | If 'slaney', divide the triangular mel weights by the width of the mel band 47 | (area normalization). 48 | 49 | If numeric, use `librosa.util.normalize` to normalize each filter by to unit l_p norm. 50 | See `librosa.util.normalize` for a full description of supported norm values 51 | (including `+-np.inf`). 52 | 53 | Otherwise, leave all the triangles aiming for a peak value of 1.0 54 | 55 | dtype : np.dtype 56 | The data type of the output basis. 57 | By default, uses 32-bit (single-precision) floating point. 58 | 59 | Returns 60 | ------- 61 | M : np.ndarray [shape=(n_mels, 1 + n_fft/2)] 62 | Mel transform matrix 63 | 64 | See Also 65 | -------- 66 | librosa.util.normalize 67 | 68 | Notes 69 | ----- 70 | This function caches at level 10. 
71 | 72 | Examples 73 | -------- 74 | # >>> melfb = librosa.filters.mel(sr=22050, n_fft=2048) 75 | # >>> melfb 76 | array([[ 0. , 0.016, ..., 0. , 0. ], 77 | [ 0. , 0. , ..., 0. , 0. ], 78 | ..., 79 | [ 0. , 0. , ..., 0. , 0. ], 80 | [ 0. , 0. , ..., 0. , 0. ]]) 81 | 82 | Clip the maximum frequency to 8KHz 83 | 84 | # >>> librosa.filters.mel(sr=22050, n_fft=2048, fmax=8000) 85 | array([[ 0. , 0.02, ..., 0. , 0. ], 86 | [ 0. , 0. , ..., 0. , 0. ], 87 | ..., 88 | [ 0. , 0. , ..., 0. , 0. ], 89 | [ 0. , 0. , ..., 0. , 0. ]]) 90 | 91 | # >>> import matplotlib.pyplot as plt 92 | # >>> fig, ax = plt.subplots() 93 | # >>> img = librosa.display.specshow(melfb, x_axis='linear', ax=ax) 94 | # >>> ax.set(ylabel='Mel filter', title='Mel filter bank') 95 | # >>> fig.colorbar(img, ax=ax) 96 | """ 97 | 98 | if fmax is None: 99 | fmax = float(sr) / 2 100 | 101 | # Initialize the weights 102 | n_mels = int(n_mels) 103 | weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype) 104 | 105 | # Center freqs of each FFT bin 106 | fftfreqs = fft_frequencies(sr=sr, n_fft=n_fft) 107 | 108 | # 'Center freqs' of mel bands - uniformly spaced between limits 109 | mel_f = mel_frequencies(n_mels + 2, fmin=fmin, fmax=fmax, htk=htk) 110 | 111 | fdiff = np.diff(mel_f) 112 | ramps = np.subtract.outer(mel_f, fftfreqs) 113 | 114 | for i in range(n_mels): 115 | # lower and upper slopes for all bins 116 | lower = -ramps[i] / fdiff[i] 117 | upper = ramps[i + 2] / fdiff[i + 1] 118 | 119 | # .. 
then intersect them with each other and zero 120 | weights[i] = np.maximum(0, np.minimum(lower, upper)) 121 | 122 | if norm == "slaney": 123 | # Slaney-style mel is scaled to be approx constant energy per channel 124 | enorm = 2.0 / (mel_f[2: n_mels + 2] - mel_f[:n_mels]) 125 | weights *= enorm[:, np.newaxis] 126 | else: 127 | weights = normalize(weights, norm=norm, axis=-1) 128 | 129 | # Only check weights if f_mel[0] is positive 130 | if not np.all((mel_f[:-2] == 0) | (weights.max(axis=1) > 0)): 131 | # This means we have an empty channel somewhere 132 | print( 133 | " [WARN] UserWarning:" 134 | "Empty filters detected in mel frequency basis. " 135 | "Some channels will produce empty responses. " 136 | "Try increasing your sampling rate (and fmax) or " 137 | "reducing n_mels." 138 | ) 139 | 140 | return weights 141 | 142 | 143 | def fft_frequencies(*, sr=22050, n_fft=2048): 144 | """Alternative implementation of `np.fft.fftfreq` 145 | 146 | Parameters 147 | ---------- 148 | sr : number > 0 [scalar] 149 | Audio sampling rate 150 | n_fft : int > 0 [scalar] 151 | FFT window size 152 | 153 | Returns 154 | ------- 155 | freqs : np.ndarray [shape=(1 + n_fft/2,)] 156 | Frequencies ``(0, sr/n_fft, 2*sr/n_fft, ..., sr/2)`` 157 | 158 | Examples 159 | -------- 160 | # >>> librosa.fft_frequencies(sr=22050, n_fft=16) 161 | array([ 0. , 1378.125, 2756.25 , 4134.375, 162 | 5512.5 , 6890.625, 8268.75 , 9646.875, 11025. ]) 163 | 164 | """ 165 | 166 | return np.fft.rfftfreq(n=n_fft, d=1.0 / sr) 167 | 168 | 169 | def mel_frequencies(n_mels=128, *, fmin=0.0, fmax=11025.0, htk=False): 170 | """Compute an array of acoustic frequencies tuned to the mel scale. 171 | 172 | The mel scale is a quasi-logarithmic function of acoustic frequency 173 | designed such that perceptually similar pitch intervals (e.g. octaves) 174 | appear equal in width over the full hearing range. 
175 | 176 | Because the definition of the mel scale is conditioned by a finite number 177 | of subjective psychoaoustical experiments, several implementations coexist 178 | in the audio signal processing literature [#]_. By default, librosa replicates 179 | the behavior of the well-established MATLAB Auditory Toolbox of Slaney [#]_. 180 | According to this default implementation, the conversion from Hertz to mel is 181 | linear below 1 kHz and logarithmic above 1 kHz. Another available implementation 182 | replicates the Hidden Markov Toolkit [#]_ (HTK) according to the following formula:: 183 | 184 | mel = 2595.0 * np.log10(1.0 + f / 700.0). 185 | 186 | The choice of implementation is determined by the ``htk`` keyword argument: setting 187 | ``htk=False`` leads to the Auditory toolbox implementation, whereas setting it ``htk=True`` 188 | leads to the HTK implementation. 189 | 190 | .. [#] Umesh, S., Cohen, L., & Nelson, D. Fitting the mel scale. 191 | In Proc. International Conference on Acoustics, Speech, and Signal Processing 192 | (ICASSP), vol. 1, pp. 217-220, 1998. 193 | 194 | .. [#] Slaney, M. Auditory Toolbox: A MATLAB Toolbox for Auditory 195 | Modeling Work. Technical Report, version 2, Interval Research Corporation, 1998. 196 | 197 | .. [#] Young, S., Evermann, G., Gales, M., Hain, T., Kershaw, D., Liu, X., 198 | Moore, G., Odell, J., Ollason, D., Povey, D., Valtchev, V., & Woodland, P. 199 | The HTK book, version 3.4. Cambridge University, March 2009. 200 | 201 | See Also 202 | -------- 203 | hz_to_mel 204 | mel_to_hz 205 | librosa.feature.melspectrogram 206 | librosa.feature.mfcc 207 | 208 | Parameters 209 | ---------- 210 | n_mels : int > 0 [scalar] 211 | Number of mel bins. 212 | fmin : float >= 0 [scalar] 213 | Minimum frequency (Hz). 214 | fmax : float >= 0 [scalar] 215 | Maximum frequency (Hz). 216 | htk : bool 217 | If True, use HTK formula to convert Hz to mel. 218 | Otherwise (False), use Slaney's Auditory Toolbox. 
219 | 220 | Returns 221 | ------- 222 | bin_frequencies : ndarray [shape=(n_mels,)] 223 | Vector of ``n_mels`` frequencies in Hz which are uniformly spaced on the Mel 224 | axis. 225 | 226 | Examples 227 | -------- 228 | # >>> librosa.mel_frequencies(n_mels=40) 229 | array([ 0. , 85.317, 170.635, 255.952, 230 | 341.269, 426.586, 511.904, 597.221, 231 | 682.538, 767.855, 853.173, 938.49 , 232 | 1024.856, 1119.114, 1222.042, 1334.436, 233 | 1457.167, 1591.187, 1737.532, 1897.337, 234 | 2071.84 , 2262.393, 2470.47 , 2697.686, 235 | 2945.799, 3216.731, 3512.582, 3835.643, 236 | 4188.417, 4573.636, 4994.285, 5453.621, 237 | 5955.205, 6502.92 , 7101.009, 7754.107, 238 | 8467.272, 9246.028, 10096.408, 11025. ]) 239 | 240 | """ 241 | 242 | # 'Center freqs' of mel bands - uniformly spaced between limits 243 | min_mel = hz_to_mel(fmin, htk=htk) 244 | max_mel = hz_to_mel(fmax, htk=htk) 245 | 246 | mels = np.linspace(min_mel, max_mel, n_mels) 247 | 248 | return mel_to_hz(mels, htk=htk) 249 | 250 | 251 | def hz_to_mel(frequencies, *, htk=False): 252 | """Convert Hz to Mels 253 | 254 | Examples 255 | -------- 256 | # >>> librosa.hz_to_mel(60) 257 | 0.9 258 | # >>> librosa.hz_to_mel([110, 220, 440]) 259 | array([ 1.65, 3.3 , 6.6 ]) 260 | 261 | Parameters 262 | ---------- 263 | frequencies : number or np.ndarray [shape=(n,)] , float 264 | scalar or array of frequencies 265 | htk : bool 266 | use HTK formula instead of Slaney 267 | 268 | Returns 269 | ------- 270 | mels : number or np.ndarray [shape=(n,)] 271 | input frequencies in Mels 272 | 273 | See Also 274 | -------- 275 | mel_to_hz 276 | """ 277 | 278 | frequencies = np.asanyarray(frequencies) 279 | 280 | if htk: 281 | return 2595.0 * np.log10(1.0 + frequencies / 700.0) 282 | 283 | # Fill in the linear part 284 | f_min = 0.0 285 | f_sp = 200.0 / 3 286 | 287 | mels = (frequencies - f_min) / f_sp 288 | 289 | # Fill in the log-scale part 290 | 291 | min_log_hz = 1000.0 # beginning of log region (Hz) 292 | min_log_mel = 
(min_log_hz - f_min) / f_sp # same (Mels) 293 | logstep = np.log(6.4) / 27.0 # step size for log region 294 | 295 | if frequencies.ndim: 296 | # If we have array data, vectorize 297 | log_t = frequencies >= min_log_hz 298 | mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep 299 | elif frequencies >= min_log_hz: 300 | # If we have scalar data, check directly 301 | mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep 302 | 303 | return mels 304 | 305 | 306 | def mel_to_hz(mels, *, htk=False): 307 | """Convert mel bin numbers to frequencies 308 | 309 | Examples 310 | -------- 311 | # >>> librosa.mel_to_hz(3) 312 | 200. 313 | 314 | # >>> librosa.mel_to_hz([1,2,3,4,5]) 315 | array([ 66.667, 133.333, 200. , 266.667, 333.333]) 316 | 317 | Parameters 318 | ---------- 319 | mels : np.ndarray [shape=(n,)], float 320 | mel bins to convert 321 | htk : bool 322 | use HTK formula instead of Slaney 323 | 324 | Returns 325 | ------- 326 | frequencies : np.ndarray [shape=(n,)] 327 | input mels in Hz 328 | 329 | See Also 330 | -------- 331 | hz_to_mel 332 | """ 333 | 334 | mels = np.asanyarray(mels) 335 | 336 | if htk: 337 | return 700.0 * (10.0 ** (mels / 2595.0) - 1.0) 338 | 339 | # Fill in the linear scale 340 | f_min = 0.0 341 | f_sp = 200.0 / 3 342 | freqs = f_min + f_sp * mels 343 | 344 | # And now the nonlinear scale 345 | min_log_hz = 1000.0 # beginning of log region (Hz) 346 | min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels) 347 | logstep = np.log(6.4) / 27.0 # step size for log region 348 | 349 | if mels.ndim: 350 | # If we have vector data, vectorize 351 | log_t = mels >= min_log_mel 352 | freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel)) 353 | elif mels >= min_log_mel: 354 | # If we have scalar data, check directly 355 | freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel)) 356 | 357 | return freqs 358 | 359 | 360 | def normalize(S, *, norm=np.inf, axis=0, threshold=None, fill=None): 361 |
"""Normalize an array along a chosen axis. 362 | 363 | Given a norm (described below) and a target axis, the input 364 | array is scaled so that:: 365 | 366 | norm(S, axis=axis) == 1 367 | 368 | For example, ``axis=0`` normalizes each column of a 2-d array 369 | by aggregating over the rows (0-axis). 370 | Similarly, ``axis=1`` normalizes each row of a 2-d array. 371 | 372 | This function also supports thresholding small-norm slices: 373 | any slice (i.e., row or column) with norm below a specified 374 | ``threshold`` can be left un-normalized, set to all-zeros, or 375 | filled with uniform non-zero values that normalize to 1. 376 | 377 | Note: the semantics of this function differ from 378 | `scipy.linalg.norm` in two ways: multi-dimensional arrays 379 | are supported, but matrix-norms are not. 380 | 381 | Parameters 382 | ---------- 383 | S : np.ndarray 384 | The array to normalize 385 | 386 | norm : {np.inf, -np.inf, 0, float > 0, None} 387 | - `np.inf` : maximum absolute value 388 | - `-np.inf` : minimum absolute value 389 | - `0` : number of non-zeros (the support) 390 | - float : corresponding l_p norm 391 | See `scipy.linalg.norm` for details. 392 | - None : no normalization is performed 393 | 394 | axis : int [scalar] 395 | Axis along which to compute the norm. 396 | 397 | threshold : number > 0 [optional] 398 | Only the columns (or rows) with norm at least ``threshold`` are 399 | normalized. 400 | 401 | By default, the threshold is determined from 402 | the numerical precision of ``S.dtype``. 403 | 404 | fill : None or bool 405 | If None, then columns (or rows) with norm below ``threshold`` 406 | are left as is. 407 | 408 | If False, then columns (rows) with norm below ``threshold`` 409 | are set to 0. 410 | 411 | If True, then columns (rows) with norm below ``threshold`` 412 | are filled uniformly such that the corresponding norm is 1. 413 | 414 | .. 
note:: ``fill=True`` is incompatible with ``norm=0`` because 415 | no uniform vector exists with l0 "norm" equal to 1. 416 | 417 | Returns 418 | ------- 419 | S_norm : np.ndarray [shape=S.shape] 420 | Normalized array 421 | 422 | Raises 423 | ------ 424 | ParameterError 425 | If ``norm`` is not among the valid types defined above 426 | 427 | If ``S`` is not finite 428 | 429 | If ``fill=True`` and ``norm=0`` 430 | 431 | See Also 432 | -------- 433 | scipy.linalg.norm 434 | 435 | Notes 436 | ----- 437 | This function caches at level 40. 438 | 439 | Examples 440 | -------- 441 | # >>> # Construct an example matrix 442 | # >>> S = np.vander(np.arange(-2.0, 2.0)) 443 | # >>> S 444 | array([[-8., 4., -2., 1.], 445 | [-1., 1., -1., 1.], 446 | [ 0., 0., 0., 1.], 447 | [ 1., 1., 1., 1.]]) 448 | # >>> # Max (l-infinity)-normalize the columns 449 | # >>> librosa.util.normalize(S) 450 | array([[-1. , 1. , -1. , 1. ], 451 | [-0.125, 0.25 , -0.5 , 1. ], 452 | [ 0. , 0. , 0. , 1. ], 453 | [ 0.125, 0.25 , 0.5 , 1. ]]) 454 | # >>> # Max (l-infinity)-normalize the rows 455 | # >>> librosa.util.normalize(S, axis=1) 456 | array([[-1. , 0.5 , -0.25 , 0.125], 457 | [-1. , 1. , -1. , 1. ], 458 | [ 0. , 0. , 0. , 1. ], 459 | [ 1. , 1. , 1. , 1. ]]) 460 | # >>> # l1-normalize the columns 461 | # >>> librosa.util.normalize(S, norm=1) 462 | array([[-0.8 , 0.667, -0.5 , 0.25 ], 463 | [-0.1 , 0.167, -0.25 , 0.25 ], 464 | [ 0. , 0. , 0. , 0.25 ], 465 | [ 0.1 , 0.167, 0.25 , 0.25 ]]) 466 | # >>> # l2-normalize the columns 467 | # >>> librosa.util.normalize(S, norm=2) 468 | array([[-0.985, 0.943, -0.816, 0.5 ], 469 | [-0.123, 0.236, -0.408, 0.5 ], 470 | [ 0. , 0. , 0. 
, 0.5 ], 471 | [ 0.123, 0.236, 0.408, 0.5 ]]) 472 | 473 | # >>> # Thresholding and filling 474 | # >>> S[:, -1] = 1e-308 475 | # >>> S 476 | array([[ -8.000e+000, 4.000e+000, -2.000e+000, 477 | 1.000e-308], 478 | [ -1.000e+000, 1.000e+000, -1.000e+000, 479 | 1.000e-308], 480 | [ 0.000e+000, 0.000e+000, 0.000e+000, 481 | 1.000e-308], 482 | [ 1.000e+000, 1.000e+000, 1.000e+000, 483 | 1.000e-308]]) 484 | 485 | # >>> # By default, small-norm columns are left untouched 486 | # >>> librosa.util.normalize(S) 487 | array([[ -1.000e+000, 1.000e+000, -1.000e+000, 488 | 1.000e-308], 489 | [ -1.250e-001, 2.500e-001, -5.000e-001, 490 | 1.000e-308], 491 | [ 0.000e+000, 0.000e+000, 0.000e+000, 492 | 1.000e-308], 493 | [ 1.250e-001, 2.500e-001, 5.000e-001, 494 | 1.000e-308]]) 495 | # >>> # Small-norm columns can be zeroed out 496 | # >>> librosa.util.normalize(S, fill=False) 497 | array([[-1. , 1. , -1. , 0. ], 498 | [-0.125, 0.25 , -0.5 , 0. ], 499 | [ 0. , 0. , 0. , 0. ], 500 | [ 0.125, 0.25 , 0.5 , 0. ]]) 501 | # >>> # Or set to constant with unit-norm 502 | # >>> librosa.util.normalize(S, fill=True) 503 | array([[-1. , 1. , -1. , 1. ], 504 | [-0.125, 0.25 , -0.5 , 1. ], 505 | [ 0. , 0. , 0. , 1. ], 506 | [ 0.125, 0.25 , 0.5 , 1. ]]) 507 | # >>> # With an l1 norm instead of max-norm 508 | # >>> librosa.util.normalize(S, norm=1, fill=True) 509 | array([[-0.8 , 0.667, -0.5 , 0.25 ], 510 | [-0.1 , 0.167, -0.25 , 0.25 ], 511 | [ 0. , 0. , 0. 
, 0.25 ], 512 | [ 0.1 , 0.167, 0.25 , 0.25 ]]) 513 | """ 514 | 515 | # Avoid div-by-zero 516 | if threshold is None: 517 | threshold = tiny(S) 518 | 519 | elif threshold <= 0: 520 | raise ValueError( 521 | "threshold={} must be strictly " "positive".format(threshold) 522 | ) 523 | 524 | if fill not in [None, False, True]: 525 | raise ValueError("fill={} must be None or boolean".format(fill)) 526 | 527 | if not np.all(np.isfinite(S)): 528 | raise ValueError("Input must be finite") 529 | 530 | # All norms only depend on magnitude, let's do that first 531 | mag = np.abs(S).astype(float) 532 | 533 | # For max/min norms, filling with 1 works 534 | fill_norm = 1 535 | 536 | if norm == np.inf: 537 | length = np.max(mag, axis=axis, keepdims=True) 538 | 539 | elif norm == -np.inf: 540 | length = np.min(mag, axis=axis, keepdims=True) 541 | 542 | elif norm == 0: 543 | if fill is True: 544 | raise ValueError("Cannot normalize with norm=0 and fill=True") 545 | 546 | length = np.sum(mag > 0, axis=axis, keepdims=True, dtype=mag.dtype) 547 | 548 | elif np.issubdtype(type(norm), np.number) and norm > 0: 549 | length = np.sum(mag ** norm, axis=axis, keepdims=True) ** (1.0 / norm) 550 | 551 | if axis is None: 552 | fill_norm = mag.size ** (-1.0 / norm) 553 | else: 554 | fill_norm = mag.shape[axis] ** (-1.0 / norm) 555 | 556 | elif norm is None: 557 | return S 558 | 559 | else: 560 | raise NotImplementedError("Unsupported norm: {}".format(repr(norm))) 561 | 562 | # indices where norm is below the threshold 563 | small_idx = length < threshold 564 | 565 | S_norm = np.empty_like(S) 566 | if fill is None: 567 | # Leave small indices un-normalized 568 | length[small_idx] = 1.0 569 | S_norm[:] = S / length 570 | 571 | elif fill: 572 | # If we have a non-zero fill value, we locate those entries by 573 | # doing a nan-divide. 
574 | # If S was finite, then length is finite (except for small positions) 575 | length[small_idx] = np.nan 576 | S_norm[:] = S / length 577 | S_norm[np.isnan(S_norm)] = fill_norm 578 | else: 579 | # Set small values to zero by doing an inf-divide. 580 | # This is safe (by IEEE-754) as long as S is finite. 581 | length[small_idx] = np.inf 582 | S_norm[:] = S / length 583 | 584 | return S_norm 585 | 586 | 587 | def tiny(x): 588 | """Compute the tiny-value corresponding to an input's data type. 589 | 590 | This is the smallest "usable" number representable in ``x.dtype`` 591 | (e.g., float32). 592 | 593 | This is primarily useful for determining a threshold for 594 | numerical underflow in division or multiplication operations. 595 | 596 | Parameters 597 | ---------- 598 | x : number or np.ndarray 599 | The array to compute the tiny-value for. 600 | All that matters here is ``x.dtype`` 601 | 602 | Returns 603 | ------- 604 | tiny_value : float 605 | The smallest positive usable number for the type of ``x``. 606 | If ``x`` is integer-typed, then the tiny value for ``np.float32`` 607 | is returned instead. 
608 | 609 | See Also 610 | -------- 611 | numpy.finfo 612 | 613 | Examples 614 | -------- 615 | For a standard double-precision floating point number: 616 | 617 | # >>> librosa.util.tiny(1.0) 618 | 2.2250738585072014e-308 619 | 620 | Or explicitly as double-precision 621 | 622 | # >>> librosa.util.tiny(np.asarray(1e-5, dtype=np.float64)) 623 | 2.2250738585072014e-308 624 | 625 | Or complex numbers 626 | 627 | # >>> librosa.util.tiny(1j) 628 | 2.2250738585072014e-308 629 | 630 | Single-precision floating point: 631 | 632 | # >>> librosa.util.tiny(np.asarray(1e-5, dtype=np.float32)) 633 | 1.1754944e-38 634 | 635 | Integer 636 | 637 | # >>> librosa.util.tiny(5) 638 | 1.1754944e-38 639 | """ 640 | 641 | # Make sure we have an array view 642 | x = np.asarray(x) 643 | 644 | # Only floating types generate a tiny 645 | if np.issubdtype(x.dtype, np.floating) or np.issubdtype( 646 | x.dtype, np.complexfloating 647 | ): 648 | dtype = x.dtype 649 | else: 650 | dtype = np.float32 651 | 652 | return np.finfo(dtype).tiny 653 | 654 | 655 | if __name__ == '__main__': 656 | """ 657 | Test mel_fn_librosa.py, check mel_fn_librosa.py is same as librosa.filters.mel 658 | "is ok:" should be True 659 | Need librosa 660 | require: pip install librosa 661 | """ 662 | try: 663 | import librosa 664 | except ImportError: 665 | print(' [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: librosa not installed,' 666 | ' if you want check this file with librosa, please install it first.') 667 | exit(1) 668 | from librosa.filters import mel as librosa_mel_fn 669 | 670 | raw_fn = librosa_mel_fn(sr=16000, n_fft=1024, n_mels=128, fmin=0, fmax=8000) 671 | self_fn = mel(sr=16000, n_fft=1024, n_mels=128, fmin=0, fmax=8000) 672 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: raw_fn.shape", raw_fn.shape) 673 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: self_fn.shape", self_fn.shape) 674 | check = np.allclose(raw_fn, self_fn) 675 | print(" [UNIT_TEST] torchfcpe.mel_tools.mel_fn_librosa: 
np.allclose(raw_fn, self_fn) is same:", check) 676 | -------------------------------------------------------------------------------- /torchfcpe/model_conformer_naive.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch import nn 4 | import math 5 | from functools import partial 6 | from einops import rearrange, repeat 7 | 8 | from local_attention import LocalAttention 9 | import torch.nn.functional as F 10 | 11 | # From https://github.com/CNChTu/Diffusion-SVC/ by CNChTu 12 | # License: MIT 13 | 14 | 15 | class ConformerNaiveEncoder(nn.Module): 16 | """ 17 | Conformer Naive Encoder 18 | 19 | Args: 20 | dim_model (int): Dimension of model 21 | num_layers (int): Number of layers 22 | num_heads (int): Number of heads 23 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False 24 | conv_only (bool): Whether to use only conv module without attention, default False 25 | conv_dropout (float): Dropout rate of conv module, default 0. 26 | atten_dropout (float): Dropout rate of attention module, default 0. 
27 | use_pre_norm (bool): Whether to use pre-norm, default False 28 | """ 29 | 30 | def __init__(self, 31 | num_layers: int, 32 | num_heads: int, 33 | dim_model: int, 34 | use_norm: bool = False, 35 | conv_only: bool = False, 36 | conv_dropout: float = 0., 37 | atten_dropout: float = 0., 38 | ): 39 | super().__init__() 40 | self.num_layers = num_layers 41 | self.num_heads = num_heads 42 | self.dim_model = dim_model 43 | self.use_norm = use_norm 44 | self.residual_dropout = 0.1 # 废弃代码,仅做兼容性保留 45 | self.attention_dropout = 0.1 # 废弃代码,仅做兼容性保留 46 | 47 | self.encoder_layers = nn.ModuleList( 48 | [ 49 | CFNEncoderLayer(dim_model, num_heads, use_norm, conv_only, conv_dropout, atten_dropout) 50 | for _ in range(num_layers) 51 | ] 52 | ) 53 | 54 | def forward(self, x, mask=None) -> torch.Tensor: 55 | """ 56 | Args: 57 | x (torch.Tensor): Input tensor (#batch, length, dim_model) 58 | mask (torch.Tensor): Mask tensor, default None 59 | return: 60 | torch.Tensor: Output tensor (#batch, length, dim_model) 61 | """ 62 | 63 | for (i, layer) in enumerate(self.encoder_layers): 64 | x = layer(x, mask) 65 | return x # (#batch, length, dim_model) 66 | 67 | 68 | class CFNEncoderLayer(nn.Module): 69 | """ 70 | Conformer Naive Encoder Layer 71 | 72 | Args: 73 | dim_model (int): Dimension of model 74 | num_heads (int): Number of heads 75 | use_norm (bool): Whether to use norm for FastAttention, only True can use bf16/fp16, default False 76 | conv_only (bool): Whether to use only conv module without attention, default False 77 | conv_dropout (float): Dropout rate of conv module, default 0.1 78 | atten_dropout (float): Dropout rate of attention module, default 0.1 79 | use_pre_norm (bool): Whether to use pre-norm, default False 80 | """ 81 | 82 | def __init__(self, 83 | dim_model: int, 84 | num_heads: int = 8, 85 | use_norm: bool = False, 86 | conv_only: bool = False, 87 | conv_dropout: float = 0., 88 | atten_dropout: float = 0., 89 | ): 90 | super().__init__() 91 | 92 | if conv_dropout > 
0.: 93 | self.conformer = nn.Sequential( 94 | ConformerConvModule(dim_model), 95 | nn.Dropout(conv_dropout) 96 | ) 97 | else: 98 | self.conformer = ConformerConvModule(dim_model) 99 | self.norm = nn.LayerNorm(dim_model) 100 | 101 | self.dropout = nn.Dropout(0.1) # 废弃代码,仅做兼容性保留 102 | 103 | # selfatt -> fastatt: performer! 104 | if not conv_only: 105 | self.attn = SelfAttention(dim=dim_model, 106 | heads=num_heads, 107 | causal=False, 108 | use_norm=use_norm, 109 | dropout=atten_dropout, ) 110 | else: 111 | self.attn = None 112 | 113 | def forward(self, x, mask=None) -> torch.Tensor: 114 | """ 115 | Args: 116 | x (torch.Tensor): Input tensor (#batch, length, dim_model) 117 | mask (torch.Tensor): Mask tensor, default None 118 | return: 119 | torch.Tensor: Output tensor (#batch, length, dim_model) 120 | """ 121 | if self.attn is not None: 122 | x = x + (self.attn(self.norm(x), mask=mask)) 123 | 124 | x = x + (self.conformer(x)) 125 | 126 | return x # (#batch, length, dim_model) 127 | 128 | 129 | class ConformerConvModule(nn.Module): 130 | def __init__( 131 | self, 132 | dim, 133 | expansion_factor=2, 134 | kernel_size=31, 135 | dropout=0., 136 | ): 137 | super().__init__() 138 | 139 | inner_dim = dim * expansion_factor 140 | padding = calc_same_padding(kernel_size) 141 | 142 | _norm = nn.LayerNorm(dim) 143 | 144 | self.net = nn.Sequential( 145 | _norm, 146 | Transpose((1, 2)), 147 | nn.Conv1d(dim, inner_dim * 2, 1), 148 | nn.GLU(dim=1), 149 | DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding[0], groups=inner_dim), 150 | nn.SiLU(), 151 | nn.Conv1d(inner_dim, dim, 1), 152 | Transpose((1, 2)), 153 | nn.Dropout(dropout) 154 | ) 155 | 156 | def forward(self, x): 157 | return self.net(x) 158 | 159 | 160 | class DepthWiseConv1d(nn.Module): 161 | def __init__(self, chan_in, chan_out, kernel_size, padding, groups): 162 | super().__init__() 163 | self.conv = nn.Conv1d(chan_in, chan_out, kernel_size=kernel_size, padding=padding, groups=groups) 164 | 
165 | def forward(self, x): 166 | return self.conv(x) 167 | 168 | 169 | def calc_same_padding(kernel_size): 170 | pad = kernel_size // 2 171 | return (pad, pad - (kernel_size + 1) % 2) 172 | 173 | 174 | class Transpose(nn.Module): 175 | def __init__(self, dims): 176 | super().__init__() 177 | assert len(dims) == 2, 'dims must be a tuple of two dimensions' 178 | self.dims = dims 179 | 180 | def forward(self, x): 181 | return x.transpose(*self.dims) 182 | 183 | 184 | class SelfAttention(nn.Module): 185 | def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, 186 | feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, 187 | dropout=0., no_projection=False, use_norm=False): 188 | super().__init__() 189 | assert dim % heads == 0, 'dimension must be divisible by number of heads' 190 | dim_head = default(dim_head, dim // heads) 191 | inner_dim = dim_head * heads 192 | self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, 193 | generalized_attention=generalized_attention, kernel_fn=kernel_fn, 194 | qr_uniform_q=qr_uniform_q, no_projection=no_projection, 195 | use_norm=use_norm) 196 | 197 | self.heads = heads 198 | self.global_heads = heads - local_heads 199 | self.local_attn = LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, 200 | look_forward=int(not causal), 201 | rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None 202 | 203 | # print (heads, nb_features, dim_head) 204 | # name_embedding = torch.zeros(110, heads, dim_head, dim_head) 205 | # self.name_embedding = nn.Parameter(name_embedding, requires_grad=True) 206 | 207 | self.to_q = nn.Linear(dim, inner_dim) 208 | self.to_k = nn.Linear(dim, inner_dim) 209 | self.to_v = nn.Linear(dim, inner_dim) 210 | self.to_out = nn.Linear(inner_dim, dim) 211 | self.dropout = nn.Dropout(dropout) 212 | 213 | @torch.no_grad() 214 | def 
redraw_projection_matrix(self): 215 | self.fast_attention.redraw_projection_matrix() 216 | # torch.nn.init.zeros_(self.name_embedding) 217 | # print (torch.sum(self.name_embedding)) 218 | 219 | def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs): 220 | b, n, _, h, gh = *x.shape, self.heads, self.global_heads 221 | 222 | cross_attend = exists(context) 223 | 224 | context = default(context, x) 225 | context_mask = default(context_mask, mask) if not cross_attend else context_mask 226 | # print (torch.sum(self.name_embedding)) 227 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 228 | 229 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) 230 | (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v)) 231 | 232 | attn_outs = [] 233 | # print (name) 234 | # print (self.name_embedding[name].size()) 235 | if not empty(q): 236 | if exists(context_mask): 237 | global_mask = context_mask[:, None, :, None] 238 | v.masked_fill_(~global_mask, 0.) 
239 | if cross_attend: 240 | pass 241 | # print (torch.sum(self.name_embedding)) 242 | # out = self.fast_attention(q,self.name_embedding[name],None) 243 | # print (torch.sum(self.name_embedding[...,-1:])) 244 | # attn_outs.append(out) 245 | else: 246 | out = self.fast_attention(q, k, v) 247 | attn_outs.append(out) 248 | 249 | if not empty(lq): 250 | assert not cross_attend, 'local attention is not compatible with cross attention' 251 | out = self.local_attn(lq, lk, lv, input_mask=mask) 252 | attn_outs.append(out) 253 | 254 | out = torch.cat(attn_outs, dim=1) 255 | out = rearrange(out, 'b h n d -> b n (h d)') 256 | out = self.to_out(out) 257 | return self.dropout(out) 258 | 259 | 260 | class FastAttention(nn.Module): 261 | def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, 262 | kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False, use_norm=False): 263 | super().__init__() 264 | nb_features = default(nb_features, int(dim_heads * math.log(dim_heads))) 265 | 266 | self.dim_heads = dim_heads 267 | self.nb_features = nb_features 268 | self.ortho_scaling = ortho_scaling 269 | 270 | self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, 271 | nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q) 272 | projection_matrix = self.create_projection() 273 | self.register_buffer('projection_matrix', projection_matrix) 274 | 275 | self.generalized_attention = generalized_attention 276 | self.kernel_fn = kernel_fn 277 | 278 | # if this is turned on, no projection will be used 279 | # queries and keys will be softmax-ed as in the original efficient attention paper 280 | self.no_projection = no_projection 281 | 282 | self.causal = causal 283 | self.use_norm = use_norm 284 | ''' 285 | if causal: 286 | try: 287 | import fast_transformers.causal_product.causal_product_cuda 288 | self.causal_linear_fn = partial(causal_linear_attention) 289 | except ImportError: 290 | 
print( 291 | 'unable to import cuda code for auto-regressive Performer. will default to the memory inefficient non-cuda version') 292 | self.causal_linear_fn = causal_linear_attention_noncuda 293 | ''' 294 | if self.causal or self.generalized_attention: 295 | raise NotImplementedError('Causal and generalized attention not implemented yet') 296 | 297 | @torch.no_grad() 298 | def redraw_projection_matrix(self): 299 | projections = self.create_projection() 300 | self.projection_matrix.copy_(projections) 301 | del projections 302 | 303 | def forward(self, q, k, v): 304 | device = q.device 305 | 306 | if self.use_norm: 307 | q = q / (q.norm(dim=-1, keepdim=True) + 1e-8) 308 | k = k / (k.norm(dim=-1, keepdim=True) + 1e-8) 309 | 310 | if self.no_projection: 311 | q = q.softmax(dim=-1) 312 | k = torch.exp(k) if self.causal else k.softmax(dim=-2) 313 | 314 | elif self.generalized_attention: 315 | ''' 316 | create_kernel = partial(generalized_kernel, kernel_fn=self.kernel_fn, 317 | projection_matrix=self.projection_matrix, device=device) 318 | q, k = map(create_kernel, (q, k)) 319 | ''' 320 | raise NotImplementedError('generalized attention not implemented yet') 321 | 322 | else: 323 | create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device) 324 | 325 | q = create_kernel(q, is_query=True) 326 | k = create_kernel(k, is_query=False) 327 | 328 | attn_fn = linear_attention if not self.causal else self.causal_linear_fn 329 | if v is None: 330 | out = attn_fn(q, k, None) 331 | return out 332 | else: 333 | out = attn_fn(q, k, v) 334 | return out 335 | 336 | 337 | def linear_attention(q, k, v): 338 | if v is None: 339 | # print (k.size(), q.size()) 340 | out = torch.einsum('...ed,...nd->...ne', k, q) 341 | return out 342 | 343 | else: 344 | k_cumsum = k.sum(dim=-2) 345 | # k_cumsum = k.sum(dim = -2) 346 | D_inv = 1. 
/ (torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q)) + 1e-8) 347 | 348 | context = torch.einsum('...nd,...ne->...de', k, v) 349 | # print ("TRUEEE: ", context.size(), q.size(), D_inv.size()) 350 | out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv) 351 | return out 352 | 353 | 354 | def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None): 355 | b, h, *_ = data.shape 356 | # (batch size, head, length, model_dim) 357 | 358 | # normalize model dim 359 | data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1. 360 | 361 | # what is ration?, projection_matrix.shape[0] --> 266 362 | 363 | ratio = (projection_matrix.shape[0] ** -0.5) 364 | 365 | projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h) 366 | projection = projection.type_as(data) 367 | 368 | # data_dash = w^T x 369 | data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection) 370 | 371 | # diag_data = D**2 372 | diag_data = data ** 2 373 | diag_data = torch.sum(diag_data, dim=-1) 374 | diag_data = (diag_data / 2.0) * (data_normalizer ** 2) 375 | diag_data = diag_data.unsqueeze(dim=-1) 376 | 377 | # print () 378 | if is_query: 379 | data_dash = ratio * ( 380 | torch.exp(data_dash - diag_data - 381 | torch.max(data_dash, dim=-1, keepdim=True).values) + eps) 382 | else: 383 | data_dash = ratio * ( 384 | torch.exp(data_dash - diag_data + eps)) # - torch.max(data_dash)) + eps) 385 | 386 | return data_dash.type_as(data) 387 | 388 | 389 | def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None): 390 | nb_full_blocks = int(nb_rows / nb_columns) 391 | # print (nb_full_blocks) 392 | block_list = [] 393 | 394 | for _ in range(nb_full_blocks): 395 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device) 396 | block_list.append(q) 397 | # block_list[n] is a orthogonal matrix ... 
(model_dim * model_dim) 398 | # print (block_list[0].size(), torch.einsum('...nd,...nd->...n', block_list[0], torch.roll(block_list[0],1,1))) 399 | # print (nb_rows, nb_full_blocks, nb_columns) 400 | remaining_rows = nb_rows - nb_full_blocks * nb_columns 401 | # print (remaining_rows) 402 | if remaining_rows > 0: 403 | q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device) 404 | # print (q[:remaining_rows].size()) 405 | block_list.append(q[:remaining_rows]) 406 | 407 | final_matrix = torch.cat(block_list) 408 | 409 | if scaling == 0: 410 | multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1) 411 | elif scaling == 1: 412 | multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device) 413 | else: 414 | raise ValueError(f'Invalid scaling {scaling}') 415 | 416 | return torch.diag(multiplier) @ final_matrix 417 | 418 | 419 | def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None): 420 | unstructured_block = torch.randn((cols, cols), device=device) 421 | q, r = torch.linalg.qr(unstructured_block.cpu(), mode='reduced') 422 | q, r = map(lambda t: t.to(device), (q, r)) 423 | 424 | # proposed by @Parskatt 425 | # to make sure Q is uniform https://arxiv.org/pdf/math-ph/0609050.pdf 426 | if qr_uniform_q: 427 | d = torch.diag(r, 0) 428 | q *= d.sign() 429 | return q.t() 430 | 431 | 432 | def default(val, d): 433 | return val if exists(val) else d 434 | 435 | 436 | def exists(val): 437 | return val is not None 438 | 439 | 440 | def empty(tensor): 441 | return tensor.numel() == 0 442 | -------------------------------------------------------------------------------- /torchfcpe/model_convnext.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class ConvNeXtBlock(nn.Module): 8 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 
9 | 10 | Args: 11 | dim (int): Number of input channels. 12 | intermediate_dim (int): Dimensionality of the intermediate layer. 13 | dilation (int, optional): Dilation factor for the depthwise convolution. Defaults to 1. 14 | kernel_size (int, optional): Kernel size for the depthwise convolution. Defaults to 7. 15 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 16 | Defaults to 1e-6. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | dim: int, 22 | intermediate_dim: int, 23 | dilation: int = 1, 24 | kernel_size: int = 7, 25 | layer_scale_init_value: Optional[float] = 1e-6, 26 | ): 27 | super().__init__() 28 | self.dwconv = nn.Conv1d( 29 | dim, 30 | dim, 31 | kernel_size=kernel_size, 32 | groups=dim, 33 | dilation=dilation, 34 | padding=int(dilation * (kernel_size - 1) / 2), 35 | ) # depthwise conv 36 | self.norm = nn.LayerNorm(dim, eps=1e-6) 37 | self.pwconv1 = nn.Linear( 38 | dim, intermediate_dim 39 | ) # pointwise/1x1 convs, implemented with linear layers 40 | self.act = nn.GELU() 41 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 42 | self.gamma = ( 43 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 44 | if layer_scale_init_value is not None and layer_scale_init_value > 0 45 | else None 46 | ) 47 | 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | residual = x 50 | 51 | x = self.dwconv(x) 52 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 53 | x = self.norm(x) 54 | x = self.pwconv1(x) 55 | x = self.act(x) 56 | x = self.pwconv2(x) 57 | if self.gamma is not None: 58 | x = self.gamma * x 59 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 60 | 61 | x = residual + x 62 | return x 63 | 64 | 65 | class ConvNeXt(nn.Module): 66 | """ConvNeXt layers 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | num_layers (int): Number of ConvNeXt layers. 71 | mlp_factor (int, optional): Factor for the intermediate layer dimensionality. Defaults to 4. 
72 | dilation_cycle (int, optional): Cycle for the dilation factor. Defaults to 4. 73 | kernel_size (int, optional): Kernel size for the depthwise convolution. Defaults to 7. 74 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 75 | Defaults to 1e-6. 76 | """ 77 | 78 | def __init__( 79 | self, 80 | dim: int, 81 | num_layers: int = 20, 82 | mlp_factor: int = 4, 83 | dilation_cycle: int = 4, 84 | kernel_size: int = 7, 85 | layer_scale_init_value: Optional[float] = 1e-6, 86 | ): 87 | super().__init__() 88 | self.dim = dim 89 | self.num_layers = num_layers 90 | self.mlp_factor = mlp_factor 91 | self.dilation_cycle = dilation_cycle 92 | self.kernel_size = kernel_size 93 | self.layer_scale_init_value = layer_scale_init_value 94 | 95 | self.layers = nn.ModuleList( 96 | [ 97 | ConvNeXtBlock( 98 | dim, 99 | dim * mlp_factor, 100 | dilation=(2 ** (i % dilation_cycle)), 101 | kernel_size=kernel_size, 102 | layer_scale_init_value=1e-6, 103 | ) 104 | for i in range(num_layers) 105 | ] 106 | ) 107 | 108 | def forward(self, x: torch.Tensor) -> torch.Tensor: 109 | for layer in self.layers: 110 | x = layer(x) 111 | return x 112 | -------------------------------------------------------------------------------- /torchfcpe/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # import weight_norm from different version of pytorch 6 | try: 7 | from torch.nn.utils.parametrizations import weight_norm 8 | except ImportError: 9 | from torch.nn.utils import weight_norm 10 | 11 | from .model_conformer_naive import ConformerNaiveEncoder 12 | 13 | 14 | class CFNaiveMelPE(nn.Module): 15 | """ 16 | Conformer-based Mel-spectrogram Prediction Encoderc in Fast Context-based Pitch Estimation 17 | 18 | Args: 19 | input_channels (int): Number of input channels, should be same as the number of bins of mel-spectrogram. 
20 | out_dims (int): Number of output dimensions, also class numbers. 21 | hidden_dims (int): Number of hidden dimensions. 22 | n_layers (int): Number of conformer layers. 23 | f0_max (float): Maximum frequency of f0. 24 | f0_min (float): Minimum frequency of f0. 25 | use_fa_norm (bool): Whether to use fast attention norm, default False 26 | conv_only (bool): Whether to use only conv module without attention, default False 27 | conv_dropout (float): Dropout rate of conv module, default 0. 28 | atten_dropout (float): Dropout rate of attention module, default 0. 29 | use_harmonic_emb (bool): Whether to use harmonic embedding, default False 30 | use_pre_norm (bool): Whether to use pre norm, default False 31 | """ 32 | 33 | def __init__(self, 34 | input_channels: int, 35 | out_dims: int, 36 | hidden_dims: int = 512, 37 | n_layers: int = 6, 38 | n_heads: int = 8, 39 | f0_max: float = 1975.5, 40 | f0_min: float = 32.70, 41 | use_fa_norm: bool = False, 42 | conv_only: bool = False, 43 | conv_dropout: float = 0., 44 | atten_dropout: float = 0., 45 | use_harmonic_emb: bool = False, 46 | ): 47 | super().__init__() 48 | self.input_channels = input_channels 49 | self.out_dims = out_dims 50 | self.hidden_dims = hidden_dims 51 | self.n_layers = n_layers 52 | self.n_heads = n_heads 53 | self.f0_max = f0_max 54 | self.f0_min = f0_min 55 | self.use_fa_norm = use_fa_norm 56 | self.residual_dropout = 0.1 # 废弃代码,仅做兼容性保留 57 | self.attention_dropout = 0.1 # 废弃代码,仅做兼容性保留 58 | 59 | # Harmonic embedding 60 | if use_harmonic_emb: 61 | self.harmonic_emb = nn.Embedding(9, hidden_dims) 62 | else: 63 | self.harmonic_emb = None 64 | 65 | # Input stack, convert mel-spectrogram to hidden_dims 66 | self.input_stack = nn.Sequential( 67 | nn.Conv1d(input_channels, hidden_dims, 3, 1, 1), 68 | nn.GroupNorm(4, hidden_dims), 69 | nn.LeakyReLU(), 70 | nn.Conv1d(hidden_dims, hidden_dims, 3, 1, 1) 71 | ) 72 | # Conformer Encoder 73 | self.net = ConformerNaiveEncoder( 74 | num_layers=n_layers, 75 | 
num_heads=n_heads, 76 | dim_model=hidden_dims, 77 | use_norm=use_fa_norm, 78 | conv_only=conv_only, 79 | conv_dropout=conv_dropout, 80 | atten_dropout=atten_dropout, 81 | ) 82 | # LayerNorm 83 | self.norm = nn.LayerNorm(hidden_dims) 84 | # Output stack, convert hidden_dims to out_dims 85 | self.output_proj = weight_norm( 86 | nn.Linear(hidden_dims, out_dims) 87 | ) 88 | # Cent table buffer 89 | """ 90 | self.cent_table_b = torch.Tensor( 91 | np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], 92 | out_dims)) 93 | """ 94 | # use torch have very small difference like 1e-4, up to 1e-3, but it may be better to use numpy? 95 | self.cent_table_b = torch.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], 96 | self.f0_to_cent(torch.Tensor([f0_max]))[0], 97 | out_dims).detach() 98 | self.register_buffer("cent_table", self.cent_table_b) 99 | # gaussian_blurred_cent_mask_b buffer 100 | self.gaussian_blurred_cent_mask_b = (1200. * torch.log2(torch.Tensor([self.f0_max / 10.])))[0].detach() 101 | self.register_buffer("gaussian_blurred_cent_mask", self.gaussian_blurred_cent_mask_b) 102 | 103 | def forward(self, x: torch.Tensor, _h_emb=None) -> torch.Tensor: 104 | """ 105 | Args: 106 | x (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins). 107 | _h_emb (int): Harmonic embedding index, like 0, 1, 2, only use in train. Default: None. 108 | return: 109 | torch.Tensor: Predicted f0 latent, shape (B, T, out_dims). 
110 | """ 111 | x = self.input_stack(x.transpose(-1, -2)).transpose(-1, -2) 112 | if self.harmonic_emb is not None: 113 | if _h_emb is None: 114 | x = x + self.harmonic_emb(torch.LongTensor([0]).to(x.device)) 115 | else: 116 | x = x + self.harmonic_emb(torch.LongTensor([int(_h_emb)]).to(x.device)) 117 | x = self.net(x) 118 | x = self.norm(x) 119 | x = self.output_proj(x) 120 | x = torch.sigmoid(x) 121 | return x # latent (B, T, out_dims) 122 | 123 | @torch.no_grad() 124 | def latent2cents_decoder(self, 125 | y: torch.Tensor, 126 | threshold: float = 0.05, 127 | mask: bool = True 128 | ) -> torch.Tensor: 129 | """ 130 | Convert latent to cents. 131 | Args: 132 | y (torch.Tensor): Latent, shape (B, T, out_dims). 133 | threshold (float): Threshold to mask. Default: 0.05. 134 | mask (bool): Whether to mask. Default: True. 135 | return: 136 | torch.Tensor: Cents, shape (B, T, 1). 137 | """ 138 | B, N, _ = y.size() 139 | ci = self.cent_table[None, None, :].expand(B, N, -1) 140 | rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True) # cents: [B,N,1] 141 | if mask: 142 | confident = torch.max(y, dim=-1, keepdim=True)[0] 143 | confident_mask = torch.ones_like(confident) 144 | confident_mask[confident <= threshold] = float("-INF") 145 | rtn = rtn * confident_mask 146 | return rtn # (B, T, 1) 147 | 148 | @torch.no_grad() 149 | def latent2cents_local_decoder(self, 150 | y: torch.Tensor, 151 | threshold: float = 0.05, 152 | mask: bool = True 153 | ) -> torch.Tensor: 154 | """ 155 | Convert latent to cents. Use local argmax. 156 | Args: 157 | y (torch.Tensor): Latent, shape (B, T, out_dims). 158 | threshold (float): Threshold to mask. Default: 0.05. 159 | mask (bool): Whether to mask. Default: True. 160 | return: 161 | torch.Tensor: Cents, shape (B, T, 1). 
162 | """ 163 | B, N, _ = y.size() 164 | ci = self.cent_table[None, None, :].expand(B, N, -1) 165 | confident, max_index = torch.max(y, dim=-1, keepdim=True) 166 | local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4) 167 | local_argmax_index[local_argmax_index < 0] = 0 168 | local_argmax_index[local_argmax_index >= self.out_dims] = self.out_dims - 1 169 | ci_l = torch.gather(ci, -1, local_argmax_index) 170 | y_l = torch.gather(y, -1, local_argmax_index) 171 | rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True) # cents: [B,N,1] 172 | if mask: 173 | confident_mask = torch.ones_like(confident) 174 | confident_mask[confident <= threshold] = float("-INF") 175 | rtn = rtn * confident_mask 176 | return rtn # (B, T, 1) 177 | 178 | @torch.no_grad() 179 | def gaussian_blurred_cent2latent(self, cents): # cents: [B,N,1] 180 | """ 181 | Convert cents to latent. 182 | Args: 183 | cents (torch.Tensor): Cents, shape (B, T, 1). 184 | return: 185 | torch.Tensor: Latent, shape (B, T, out_dims). 186 | """ 187 | mask = (cents > 0.1) & (cents < self.gaussian_blurred_cent_mask) 188 | # mask = (cents>0.1) & (cents<(1200.*np.log2(self.f0_max/10.))) 189 | B, N, _ = cents.size() 190 | ci = self.cent_table[None, None, :].expand(B, N, -1) 191 | return torch.exp(-torch.square(ci - cents) / 1250) * mask.float() 192 | 193 | @torch.no_grad() 194 | def infer(self, 195 | mel: torch.Tensor, 196 | decoder: str = "local_argmax", # "argmax" or "local_argmax" 197 | threshold: float = 0.05, 198 | ) -> torch.Tensor: 199 | """ 200 | Args: 201 | mel (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins). 202 | decoder (str): Decoder type. Default: "local_argmax". 203 | threshold (float): Threshold to mask. Default: 0.05. 
204 | """ 205 | latent = self.forward(mel) 206 | if decoder == "argmax": 207 | cents = self.latent2cents_decoder(latent, threshold=threshold) 208 | elif decoder == "local_argmax": 209 | cents = self.latent2cents_local_decoder(latent, threshold=threshold) 210 | else: 211 | raise ValueError(f" [x] Unknown decoder type {decoder}.") 212 | f0 = self.cent_to_f0(cents) 213 | return f0 # (B, T, 1) 214 | 215 | def train_and_loss(self, mel, gt_f0, loss_scale=10): 216 | """ 217 | Args: 218 | mel (torch.Tensor): Input mel-spectrogram, shape (B, T, input_channels) or (B, T, mel_bins). 219 | gt_f0 (torch.Tensor): Ground truth f0, shape (B, T, 1). 220 | loss_scale (float): Loss scale. Default: 10. 221 | return: loss 222 | """ 223 | if mel.shape[-2] != gt_f0.shape[-2]: 224 | _len = min(mel.shape[-2], gt_f0.shape[-2]) 225 | mel = mel[:, :_len, :] 226 | gt_f0 = gt_f0[:, :_len, :] 227 | gt_cent_f0 = self.f0_to_cent(gt_f0) # mel f0, [B,N,1] 228 | x_gt = self.gaussian_blurred_cent2latent(gt_cent_f0) # [B,N,out_dim] 229 | if self.harmonic_emb is not None: 230 | x = self.forward(mel, _h_emb=0) 231 | x_half = self.forward(mel, _h_emb=1) 232 | x_gt_half = self.gaussian_blurred_cent2latent(gt_cent_f0 / 2) 233 | x_gt_double = self.gaussian_blurred_cent2latent(gt_cent_f0 * 2) 234 | x_double = self.forward(mel, _h_emb=2) 235 | loss = F.binary_cross_entropy(x, x_gt) 236 | loss_half = F.binary_cross_entropy(x_half, x_gt_half) 237 | loss_double = F.binary_cross_entropy(x_double, x_gt_double) 238 | loss = loss + (loss_half + loss_double) / 2 239 | loss = loss * loss_scale 240 | else: 241 | x = self.forward(mel) # [B,N,out_dim] 242 | loss = F.binary_cross_entropy(x, x_gt) * loss_scale 243 | return loss 244 | 245 | @torch.no_grad() 246 | def cent_to_f0(self, cent: torch.Tensor) -> torch.Tensor: 247 | """ 248 | Convert cent to f0. Args: cent (torch.Tensor): Cent, shape = (B, T, 1). return: torch.Tensor: f0, shape = (B, T, 1). 249 | """ 250 | f0 = 10. * 2 ** (cent / 1200.) 
251 | return f0 # (B, T, 1) 252 | 253 | @torch.no_grad() 254 | def f0_to_cent(self, f0: torch.Tensor) -> torch.Tensor: 255 | """ 256 | Convert f0 to cent. Args: f0 (torch.Tensor): f0, shape = (B, T, 1). return: torch.Tensor: Cent, shape = (B, T, 1). 257 | """ 258 | cent = 1200. * torch.log2(f0 / 10.) 259 | return cent # (B, T, 1) 260 | -------------------------------------------------------------------------------- /torchfcpe/models_infer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | import torch 5 | from einops import rearrange 6 | 7 | from torchfcpe.f02midi.transpose import f02midi 8 | 9 | from .models import CFNaiveMelPE 10 | from .tools import ( 11 | DotDict, 12 | catch_none_args_must, 13 | catch_none_args_opti, 14 | get_config_json_in_same_path, 15 | get_device, 16 | spawn_wav2mel, 17 | ) 18 | from .torch_interp import batch_interp_with_replacement_detach 19 | 20 | 21 | def ensemble_f0(f0s, key_shift_list, tta_uv_penalty): 22 | """_summary_ 23 | 24 | Args: 25 | f0s (torch.Tensor): (B, T, len(key_shift_list)) 26 | key_shift_list (list): list of key shifts 27 | tta_uv_penalty (float,int): uv penalty 28 | 29 | Returns: 30 | f0: (B, T, 1) 31 | """ 32 | device = f0s.device 33 | # convert f0 to note 34 | f0s = f0s / ( 35 | torch.pow( 36 | 2, 37 | torch.tensor(key_shift_list, device=device) 38 | .to(device) 39 | .unsqueeze(0) 40 | .unsqueeze(0) 41 | / 12, 42 | ) 43 | ) 44 | notes = torch.log2(f0s / 440) * 12 + 69 45 | notes[notes < 0] = 0 46 | 47 | # select best note 48 | # 使用动态规划选择最优的音高 49 | # 惩罚1:uv的惩罚固定为超参数uv_penalty ** 2,v转为uv时额外惩罚两次 50 | # 惩罚2:相邻帧音高的L2距离(uv和v互转的过程除外),距离小于0.5时忽略不计 51 | uv_penalty = tta_uv_penalty**2 52 | dp = torch.zeros_like(notes, device=device) 53 | # dp[b,t,c]表示,对于样本b,0到第t帧的所有选择中,选择第c个f0作为第t帧的结尾的最小惩罚 54 | backtrack = torch.zeros_like(notes, device=device).long() 55 | # backtrack[b,t,c]表示,对于样本b,0到第t帧的所有选择中,选择第c个f0作为第t帧的结尾时,t-1帧结尾的选择,值域为0到len(f0_list)-1 56 | # 
init 57 | dp[:, 0, :] = (notes[:, 0, :] <= 0) * uv_penalty 58 | # forward 59 | for t in range(1, notes.size(1)): 60 | penalty = torch.zeros( 61 | [notes.size(0), notes.size(2), notes.size(2)], device=device 62 | ) 63 | # [b,c1,c2]表示第b个样本中,t-1帧选择c1,t帧选择c2的惩罚 64 | 65 | # t帧是uv的情况 66 | t_uv = notes[:, t, :] <= 0 67 | penalty += uv_penalty * t_uv.unsqueeze(1) 68 | 69 | # t帧是v的情况 70 | # t-1帧也是v的情况 71 | t1_uv = notes[:, t - 1, :] <= 0 72 | l2 = torch.pow( 73 | (notes[:, t - 1, :].unsqueeze(-1) - notes[:, t, :].unsqueeze(1)) 74 | * (~t1_uv).unsqueeze(-1) 75 | * (~t_uv).unsqueeze(1), 76 | 2, 77 | ) 78 | l2 = l2 - 0.5 79 | l2 = l2 * (l2 > 0) 80 | penalty += l2 81 | 82 | # t-1帧是uv的情况,uv转v的惩罚 83 | penalty += t1_uv.unsqueeze(-1) * (~t_uv).unsqueeze(1) * uv_penalty * 2 84 | 85 | # 选择最小惩罚 86 | min_value, min_indices = torch.min( 87 | dp[:, t - 1, :].unsqueeze(-1) + penalty, dim=1 88 | ) 89 | dp[:, t, :] = min_value 90 | backtrack[:, t, :] = min_indices 91 | 92 | # backtrack 93 | t = f0s.size(1) - 1 94 | f0_result = torch.zeros_like(f0s[:, :, 0], device=device) 95 | min_indices = torch.argmin(dp[:, t, :], dim=-1) 96 | for i in range(0, t + 1): 97 | f0_result[:, t - i] = f0s[:, t - i, min_indices] 98 | min_indices = backtrack[:, t - i, min_indices] 99 | 100 | return f0_result.unsqueeze(-1) 101 | 102 | 103 | class InferCFNaiveMelPE(torch.nn.Module): 104 | """Infer CFNaiveMelPE 105 | Args: 106 | args (DotDict): Config. 107 | state_dict (dict): Model state dict. 
108 | """ 109 | 110 | def __init__(self, args, state_dict): 111 | super().__init__() 112 | self.wav2mel = spawn_wav2mel(args, device="cpu") 113 | self.model = spawn_model(args) 114 | self.model.load_state_dict(state_dict) 115 | self.model.eval() 116 | self.args_dict = dict(args) 117 | self.register_buffer( 118 | "tensor_device_marker", torch.tensor(1.0).float(), persistent=False 119 | ) 120 | 121 | def forward( 122 | self, 123 | wav: torch.Tensor, 124 | sr: [int, float], 125 | decoder_mode: str = "local_argmax", 126 | threshold: float = 0.006, 127 | key_shifts: list = [0], 128 | ) -> torch.Tensor: 129 | """Infer 130 | Args: 131 | wav (torch.Tensor): Input wav, (B, n_sample, 1). 132 | sr (int, float): Input wav sample rate. 133 | decoder_mode (str): Decoder type. Default: "local_argmax", support "argmax" or "local_argmax". 134 | threshold (float): Threshold to mask. Default: 0.006. 135 | key_shifts (list): Key shifts. Default: [0]. 136 | return: f0 (torch.Tensor): f0 Hz, shape (B, (n_sample//hop_size + 1), 1). 
137 | """ 138 | with torch.no_grad(): 139 | wav = wav.to(self.tensor_device_marker.device) 140 | mels = torch.stack( 141 | [self.wav2mel(wav, sr, keyshift=keyshift) for keyshift in key_shifts], 142 | -1, 143 | ) 144 | mels = rearrange(mels, "B T C K -> (B K) T C") 145 | f0s = self.model.infer(mels, decoder=decoder_mode, threshold=threshold) 146 | f0s = rearrange(f0s, "(B K) T 1 -> B T (K 1)", K=len(key_shifts)) 147 | return f0s # (B, T, len(key_shifts)) 148 | 149 | def infer( 150 | self, 151 | wav: torch.Tensor, 152 | sr: [int, float], 153 | decoder_mode: str = "local_argmax", 154 | threshold: float = 0.006, 155 | f0_min: float = None, 156 | f0_max: float = None, 157 | interp_uv: bool = False, 158 | output_interp_target_length: int = None, 159 | return_uv: bool = False, 160 | test_time_augmentation: bool = False, 161 | tta_uv_penalty: float = 12.0, 162 | tta_key_shifts: list = [0, -12, 12], 163 | tta_use_origin_uv=False, 164 | ) -> torch.Tensor or (torch.Tensor, torch.Tensor): 165 | """Infer 166 | Args: 167 | wav (torch.Tensor): Input wav, (B, n_sample, 1). 168 | sr (int, float): Input wav sample rate. 169 | decoder_mode (str): Decoder type. Default: "local_argmax", support "argmax" or "local_argmax". 170 | threshold (float): Threshold to mask. Default: 0.006. 171 | f0_min (float): Minimum f0. Default: None. Use in post-processing. 172 | f0_max (float): Maximum f0. Default: None. Use in post-processing. 173 | interp_uv (bool): Interpolate unvoiced frames. Default: False. 174 | output_interp_target_length (int): Output interpolation target length. Default: None. 175 | return_uv (bool): Return unvoiced frames. Default: False. 176 | test_time_augmentation (bool): Test time augmentation. If enabled, the output may be better but slower. Default: False. 177 | tta_uv_penalty (float): Test time augmentation unvoiced penalty. Default: 12.0. 178 | tta_key_shifts (list): Test time augmentation key shifts. Default: [0, -12, 12]. 179 | tta_use_origin_uv (bool): Use origin uv. 
Default: False 180 | return: f0 (torch.Tensor): f0 Hz, shape (B, (n_sample//hop_size + 1) or output_interp_target_length, 1). 181 | if return_uv is True, return f0, uv. the shape of uv(torch.Tensor) is like f0. 182 | """ 183 | # infer 184 | if test_time_augmentation: 185 | assert len(tta_key_shifts) > 0 186 | flag = 0 187 | if tta_use_origin_uv: 188 | if 0 not in tta_key_shifts: 189 | flag = 1 190 | tta_key_shifts.append(0) 191 | tta_key_shifts.sort(key=lambda x: (x if x >= 0 else -x / 2)) 192 | f0s = self.__call__(wav, sr, decoder_mode, threshold, tta_key_shifts) 193 | f0 = ensemble_f0( 194 | f0s[:, :, flag:], 195 | tta_key_shifts[flag:], 196 | tta_uv_penalty, 197 | ) 198 | if tta_use_origin_uv: 199 | f0_for_uv = f0s[:, :, [0]] 200 | else: 201 | f0_for_uv = f0 202 | else: 203 | f0 = self.__call__(wav, sr, decoder_mode, threshold) 204 | f0_for_uv = f0 205 | if f0_min is None: 206 | f0_min = self.args_dict["model"]["f0_min"] 207 | uv = (f0_for_uv < f0_min).type(f0_for_uv.dtype) 208 | f0 = f0 * (1 - uv) 209 | # interp 210 | if interp_uv: 211 | f0 = batch_interp_with_replacement_detach( 212 | uv.squeeze(-1).bool(), f0.squeeze(-1) 213 | ).unsqueeze(-1) 214 | if f0_max is not None: 215 | f0[f0 > f0_max] = f0_max 216 | if output_interp_target_length is not None: 217 | f0 = torch.where(f0 == 0, float("nan"), f0) 218 | f0 = torch.nn.functional.interpolate( 219 | f0.transpose(1, 2), 220 | size=int(output_interp_target_length), 221 | mode="linear", 222 | ).transpose(1, 2) 223 | f0 = torch.where(f0.isnan(), float(0.0), f0) 224 | # if return_uv is True, interp and return uv 225 | if return_uv: 226 | uv = torch.nn.functional.interpolate( 227 | uv.transpose(1, 2), 228 | size=int(output_interp_target_length), 229 | mode="nearest", 230 | ).transpose(1, 2) 231 | return f0, uv 232 | else: 233 | return f0 234 | 235 | def extact_midi( 236 | self, 237 | wav: torch.Tensor, 238 | sr: [int, float], 239 | output_path: str, 240 | decoder_mode: str = "local_argmax", 241 | threshold: float = 
0.006, 242 | f0_min: float = None, 243 | f0_max: float = None, 244 | tempo: float = None, 245 | ): 246 | f0 = self.infer( 247 | wav, 248 | sr, 249 | decoder_mode, 250 | threshold, 251 | f0_min, 252 | f0_max, 253 | ) 254 | f0 = f0.squeeze(-1).squeeze(0).cpu().numpy() 255 | wav = wav.squeeze(0).squeeze(-1).cpu().numpy() 256 | return f02midi(f0, tempo=tempo, output_path=output_path, sr=sr, y=wav) 257 | 258 | def get_hop_size(self) -> int: 259 | """Get hop size""" 260 | return DotDict(self.args_dict).mel.hop_size 261 | 262 | def get_hop_size_ms(self) -> float: 263 | """Get hop size in ms""" 264 | return ( 265 | DotDict(self.args_dict).mel.hop_size / DotDict(self.args_dict).mel.sr * 1000 266 | ) 267 | 268 | def get_model_sr(self) -> int: 269 | """Get model sample rate""" 270 | return DotDict(self.args_dict).mel.sr 271 | 272 | def get_mel_config(self) -> dict: 273 | """Get mel config""" 274 | return dict(DotDict(self.args_dict).mel) 275 | 276 | def get_device(self) -> str: 277 | """Get device""" 278 | return self.tensor_device_marker.device 279 | 280 | def get_model_f0_range(self) -> dict: 281 | """Get model f0 range like {'f0_min': 32.70, 'f0_max': 1975.5}""" 282 | return { 283 | "f0_min": DotDict(self.args_dict).model.f0_min, 284 | "f0_max": DotDict(self.args_dict).model.f0_max, 285 | } 286 | 287 | 288 | class InferCFNaiveMelPEONNX: 289 | """Infer CFNaiveMelPE ONNX 290 | Args: 291 | args (DotDict): Config. 292 | onnx_path (str): Path to onnx file. 293 | device (str): Device. must be not None. 294 | """ 295 | 296 | def __init__(self, args, onnx_path, device): 297 | raise NotImplementedError 298 | 299 | 300 | def spawn_bundled_infer_model(device: str = None) -> InferCFNaiveMelPE: 301 | """ 302 | Spawn bundled infer model 303 | This model has been trained on our dataset and comes with the package. 304 | You can use it directly without anything else. 305 | Args: 306 | device (str): Device. Default: None. 
307 | """ 308 | file_path = pathlib.Path(__file__) 309 | model_path = file_path.parent / "assets" / "fcpe_c_v001.pt" 310 | model = spawn_infer_model_from_pt(str(model_path), device, bundled_model=True) 311 | return model 312 | 313 | 314 | def spawn_infer_model_from_onnx( 315 | onnx_path: str, device: str = None 316 | ) -> InferCFNaiveMelPEONNX: 317 | """ 318 | Spawn infer model from onnx file 319 | Args: 320 | onnx_path (str): Path to onnx file. 321 | device (str): Device. Default: None. 322 | """ 323 | device = get_device(device, "torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_onnx") 324 | config_path = get_config_json_in_same_path(onnx_path) 325 | with open(config_path, "r", encoding="utf-8") as f: 326 | config_dict = json.load(f) 327 | args = DotDict(config_dict) 328 | if (args.is_onnx is None) or (args.is_onnx is False): 329 | raise ValueError( 330 | " [ERROR] spawn_infer_model_from_onnx: this model is not onnx model." 331 | ) 332 | 333 | if args.model.type == "CFNaiveMelPEONNX": 334 | infer_model = InferCFNaiveMelPEONNX(args, onnx_path, device) 335 | else: 336 | raise ValueError( 337 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPEONNX" 338 | ) 339 | 340 | return infer_model 341 | 342 | 343 | def spawn_infer_model_from_pt( 344 | pt_path: str, device: str = None, bundled_model: bool = False 345 | ) -> InferCFNaiveMelPE: 346 | """ 347 | Spawn infer model from pt file 348 | Args: 349 | pt_path (str): Path to pt file. 350 | device (str): Device. Default: None. 351 | bundled_model (bool): Whether this model is bundled model, only used in spawn_bundled_infer_model. 
352 | """ 353 | device = get_device(device, "torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt") 354 | ckpt = torch.load(pt_path, map_location=torch.device(device)) 355 | if bundled_model: 356 | ckpt["config_dict"]["model"]["conv_dropout"] = 0.0 357 | ckpt["config_dict"]["model"]["atten_dropout"] = 0.0 358 | args = DotDict(ckpt["config_dict"]) 359 | if (args.is_onnx is not None) and (args.is_onnx is True): 360 | raise ValueError( 361 | " [ERROR] spawn_infer_model_from_pt: this model is an onnx model." 362 | ) 363 | 364 | if args.model.type == "CFNaiveMelPE": 365 | infer_model = InferCFNaiveMelPE(args, ckpt["model"]) 366 | infer_model = infer_model.to(device) 367 | infer_model.eval() 368 | else: 369 | raise ValueError( 370 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPE" 371 | ) 372 | 373 | return infer_model 374 | 375 | 376 | def spawn_model(args: DotDict) -> CFNaiveMelPE: 377 | """Spawn conformer naive model""" 378 | if args.model.type == "CFNaiveMelPE": 379 | pe_model = CFNaiveMelPE( 380 | input_channels=catch_none_args_must( 381 | args.mel.num_mels, 382 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 383 | warning_str="args.mel.num_mels is None", 384 | ), 385 | out_dims=catch_none_args_must( 386 | args.model.out_dims, 387 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 388 | warning_str="args.model.out_dims is None", 389 | ), 390 | hidden_dims=catch_none_args_must( 391 | args.model.hidden_dims, 392 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 393 | warning_str="args.model.hidden_dims is None", 394 | ), 395 | n_layers=catch_none_args_must( 396 | args.model.n_layers, 397 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 398 | warning_str="args.model.n_layers is None", 399 | ), 400 | n_heads=catch_none_args_must( 401 | args.model.n_heads, 402 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 403 | warning_str="args.model.n_heads is None", 404 | ), 405 | f0_max=catch_none_args_must( 406 | args.model.f0_max, 
407 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 408 | warning_str="args.model.f0_max is None", 409 | ), 410 | f0_min=catch_none_args_must( 411 | args.model.f0_min, 412 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 413 | warning_str="args.model.f0_min is None", 414 | ), 415 | use_fa_norm=catch_none_args_must( 416 | args.model.use_fa_norm, 417 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 418 | warning_str="args.model.use_fa_norm is None", 419 | ), 420 | conv_only=catch_none_args_opti( 421 | args.model.conv_only, 422 | default=False, 423 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 424 | warning_str="args.model.conv_only is None", 425 | ), 426 | conv_dropout=catch_none_args_opti( 427 | args.model.conv_dropout, 428 | default=0.0, 429 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 430 | warning_str="args.model.conv_dropout is None", 431 | ), 432 | atten_dropout=catch_none_args_opti( 433 | args.model.atten_dropout, 434 | default=0.0, 435 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 436 | warning_str="args.model.atten_dropout is None", 437 | ), 438 | use_harmonic_emb=catch_none_args_opti( 439 | args.model.use_harmonic_emb, 440 | default=False, 441 | func_name="torchfcpe.tools.spawn_cf_naive_mel_pe", 442 | warning_str="args.model.use_harmonic_emb is None", 443 | ), 444 | ) 445 | else: 446 | raise ValueError( 447 | f" [ERROR] args.model.type is {args.model.type}, but only support CFNaiveMelPE" 448 | ) 449 | return pe_model 450 | 451 | 452 | def bundled_infer_model_unit_test(wav_path): 453 | """Unit test for bundled infer model""" 454 | # wav_path is your wav file path 455 | try: 456 | import librosa 457 | import matplotlib.pyplot as plt 458 | except ImportError: 459 | print( 460 | " [UNIT_TEST] torchfcpe.tools.spawn_infer_model_from_pt: matplotlib or librosa not found, skip test" 461 | ) 462 | exit(1) 463 | 464 | infer_model = spawn_bundled_infer_model(device="cpu") 465 | wav, sr = librosa.load(wav_path, sr=16000) 466 | f0 = 
infer_model.infer(torch.tensor(wav).unsqueeze(0), sr, interp_uv=False) 467 | f0_interp = infer_model.infer(torch.tensor(wav).unsqueeze(0), sr, interp_uv=True) 468 | plt.plot(f0.squeeze(-1).squeeze(0).numpy(), color="r", linestyle="-") 469 | plt.plot(f0_interp.squeeze(-1).squeeze(0).numpy(), color="g", linestyle="-") 470 | # 添加图例 471 | plt.legend(["f0", "f0_interp"]) 472 | plt.xlabel("frame") 473 | plt.ylabel("f0") 474 | plt.title("f0") 475 | plt.show() 476 | -------------------------------------------------------------------------------- /torchfcpe/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .mel_extractor import Wav2Mel, Wav2MelModule 3 | import pathlib 4 | 5 | 6 | class DotDict(dict): 7 | """ 8 | DotDict, used for config 9 | 10 | Example: 11 | # >>> config = DotDict({'a': 1, 'b': {'c': 2}}}) 12 | # >>> config.a 13 | # 1 14 | # >>> config.b.c 15 | # 2 16 | """ 17 | 18 | def __getattr__(*args): 19 | val = dict.get(*args) 20 | return DotDict(val) if type(val) is dict else val 21 | 22 | __setattr__ = dict.__setitem__ 23 | __delattr__ = dict.__delitem__ 24 | 25 | 26 | def spawn_wav2mel(args: DotDict, device: str = None) -> Wav2MelModule: 27 | """Spawn wav2mel""" 28 | _type = args.mel.type 29 | if (str(_type).lower() == 'none') or (str(_type).lower() == 'default'): 30 | _type = 'default' 31 | elif str(_type).lower() == 'stft': 32 | _type = 'stft' 33 | else: 34 | raise ValueError(f' [ERROR] torchfcpe.tools.args.spawn_wav2mel: {_type} is not a supported args.mel.type') 35 | wav2mel = Wav2MelModule( 36 | sr=catch_none_args_opti( 37 | args.mel.sr, 38 | default=16000, 39 | func_name='torchfcpe.tools.spawn_wav2mel', 40 | warning_str='args.mel.sr is None', 41 | ), 42 | n_mels=catch_none_args_opti( 43 | args.mel.num_mels, 44 | default=128, 45 | func_name='torchfcpe.tools.spawn_wav2mel', 46 | warning_str='args.mel.num_mels is None', 47 | ), 48 | n_fft=catch_none_args_opti( 49 | args.mel.n_fft, 50 | 
default=1024, 51 | func_name='torchfcpe.tools.spawn_wav2mel', 52 | warning_str='args.mel.n_fft is None', 53 | ), 54 | win_size=catch_none_args_opti( 55 | args.mel.win_size, 56 | default=1024, 57 | func_name='torchfcpe.tools.spawn_wav2mel', 58 | warning_str='args.mel.win_size is None', 59 | ), 60 | hop_length=catch_none_args_opti( 61 | args.mel.hop_size, 62 | default=160, 63 | func_name='torchfcpe.tools.spawn_wav2mel', 64 | warning_str='args.mel.hop_size is None', 65 | ), 66 | fmin=catch_none_args_opti( 67 | args.mel.fmin, 68 | default=0, 69 | func_name='torchfcpe.tools.spawn_wav2mel', 70 | warning_str='args.mel.fmin is None', 71 | ), 72 | fmax=catch_none_args_opti( 73 | args.mel.fmax, 74 | default=8000, 75 | func_name='torchfcpe.tools.spawn_wav2mel', 76 | warning_str='args.mel.fmax is None', 77 | ), 78 | clip_val=1e-05, 79 | mel_type=_type, 80 | ) 81 | device = catch_none_args_opti( 82 | device, 83 | default='cpu', 84 | func_name='torchfcpe.tools.spawn_wav2mel', 85 | warning_str='.device is None', 86 | ) 87 | return wav2mel.to(torch.device(device)) 88 | 89 | 90 | def catch_none_args_opti(x, default, func_name, warning_str=None, level='WARN'): 91 | """Catch None, optional""" 92 | if x is None: 93 | if warning_str is not None: 94 | print(f' [{level}] {warning_str}; use default {default}') 95 | print(f' [{level}] > call by:{func_name}') 96 | return default 97 | else: 98 | return x 99 | 100 | 101 | def catch_none_args_must(x, func_name, warning_str): 102 | """Catch None, must""" 103 | level = "ERROR" 104 | if x is None: 105 | print(f' [{level}] {warning_str}') 106 | print(f' [{level}] > call by:{func_name}') 107 | raise ValueError(f' [{level}] {warning_str}') 108 | else: 109 | return x 110 | 111 | 112 | def get_device(device: str, func_name: str) -> str: 113 | """Get device""" 114 | 115 | if device is None: 116 | if torch.cuda.is_available(): 117 | device = 'cuda' 118 | elif torch.backends.mps.is_available(): 119 | device = 'mps' 120 | else: 121 | device = 'cpu' 122 | 
123 | print(f' [INFO]: Using {device} automatically.') 124 | print(f' [INFO] > call by: {func_name}') 125 | else: 126 | print(f' [INFO]: device is not None, use {device}') 127 | print(f' [INFO] > call by:{func_name}') 128 | device = device 129 | 130 | # Check if the specified device is available, if not, switch to cpu 131 | if ((device == 'cuda' and not torch.cuda.is_available()) or 132 | (device == 'mps' and not torch.backends.mps.is_available())): 133 | print(f' [WARN]: Specified device ({device}) is not available, switching to cpu.') 134 | device = 'cpu' 135 | 136 | return device 137 | 138 | 139 | def get_config_json_in_same_path(path: str) -> str: 140 | """Get config json in same path""" 141 | path = pathlib.Path(path) 142 | config_json = path.parent / 'config.json' 143 | if config_json.exists(): 144 | return str(config_json) 145 | else: 146 | raise FileNotFoundError(f' [ERROR] {config_json} not found.') 147 | -------------------------------------------------------------------------------- /torchfcpe/torch_interp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | ''' 3 | use from https://github.com/autumn-DL/TorchInterp 4 | it use MIT license 5 | ''' 6 | 7 | 8 | def torch_interp(x, xp, fp): 9 | # if not isinstance(x, torch.Tensor): 10 | # x = torch.tensor(x) 11 | # if not isinstance(xp, torch.Tensor): 12 | # xp = torch.tensor(xp) 13 | # if not isinstance(fp, torch.Tensor): 14 | # fp = torch.tensor(fp) 15 | 16 | sort_idx = torch.argsort(xp) 17 | xp = xp[sort_idx] 18 | fp = fp[sort_idx] 19 | 20 | right_idxs = torch.searchsorted(xp, x) 21 | 22 | right_idxs = right_idxs.clamp(max=len(xp) - 1) 23 | 24 | left_idxs = (right_idxs - 1).clamp(min=0) 25 | 26 | x_left = xp[left_idxs] 27 | x_right = xp[right_idxs] 28 | y_left = fp[left_idxs] 29 | y_right = fp[right_idxs] 30 | 31 | interp_vals = y_left + ((x - x_left) * (y_right - y_left) / (x_right - x_left)) 32 | 33 | interp_vals[x < xp[0]] = fp[0] 34 | interp_vals[x > 
xp[-1]] = fp[-1] 35 | 36 | return interp_vals 37 | 38 | 39 | def batch_interp_with_replacement_detach(uv, f0): 40 | ''' 41 | :param uv: B T 42 | :param f0: B T 43 | :return: f0 B T 44 | ''' 45 | 46 | result = f0.clone() 47 | 48 | for i in range(uv.shape[0]): 49 | x = torch.where(uv[i])[-1] 50 | xp = torch.where(~uv[i])[-1] 51 | fp = f0[i][~uv[i]] 52 | 53 | interp_vals = torch_interp(x, xp, fp).detach() 54 | 55 | result[i][uv[i]] = interp_vals 56 | return result 57 | 58 | 59 | def unit_text(): 60 | try: 61 | import matplotlib.pyplot as plt 62 | except ImportError: 63 | print(' [UNIT_TEST] torchfcpe.torch_interp: matplotlib not found, skip plotting.') 64 | exit(1) 65 | 66 | # f0 67 | f0 = torch.tensor([1, 0, 3, 0, 0, 3, 4, 5, 0, 0]).float() 68 | uv = torch.tensor([0, 1, 0, 1, 1, 0, 0, 0, 1, 1]).bool() 69 | 70 | interp_f0 = batch_interp_with_replacement_detach(uv.unsqueeze(0), f0.unsqueeze(0)).squeeze(0) 71 | 72 | print(interp_f0) 73 | 74 | 75 | if __name__ == '__main__': 76 | unit_text() 77 | -------------------------------------------------------------------------------- /train/configs/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | duration: 0.3483 # Audio duration during training, must be less than the duration of the shortest audio clip 3 | train_path: 'yx2/train' # Create a folder named "audio" under this path and put the audio clip in it 4 | valid_path: 'yx2/val' # Create a folder named "audio" under this path and put the audio clip in it 5 | extensions: # List of extension included in the data collection 6 | - wav 7 | mel: 8 | type: 'stft' # use "none" or "default" to use default 9 | sr: 16000 10 | num_mels: 512 # if stft need to match others 11 | n_fft: 1024 12 | win_size: 1024 13 | hop_size: 160 14 | fmin: 0 15 | fmax: 8000 16 | model: 17 | type: 'CFNaiveMelPE' 18 | out_dims: 360 19 | hidden_dims: 512 20 | n_layers: 6 21 | n_heads: 8 22 | f0_min: 32.70 23 | f0_max: 1975.5 24 | use_fa_norm: true 25 | 
conv_only: true 26 | conv_dropout: 0.0 27 | atten_dropout: 0.0 28 | use_harmonic_emb: false 29 | loss: 30 | loss_scale: 10 31 | device: cuda 32 | env: 33 | expdir: exp/yx2_001ac_stft 34 | gpu_id: 0 35 | train: 36 | aug_add_music: true 37 | aug_keyshift: true 38 | f0_shift_mode: 'keyshift' 39 | keyshift_min: -6 40 | keyshift_max: 6 41 | aug_noise: true 42 | noise_ratio: 0.7 43 | brown_noise_ratio: 1 44 | aug_mask: true 45 | aug_mask_v_o: true 46 | aug_mask_vertical_factor: 0.05 47 | aug_mask_vertical_factor_v_o: 0.3 48 | aug_mask_iszeropad_mode: 'noise' # randon zero or noise 49 | aug_mask_block_num: 1 50 | aug_mask_block_num_v_o: 1 51 | num_workers: 14 # If your cpu and gpu are both very strong, set to 0 may be faster! 52 | amp_dtype: fp32 # only can ues fp32, else nan 53 | batch_size: 128 54 | use_redis: true 55 | cache_all_data: true # Save Internal-Memory or Graphics-Memory if it is false, but may be slow 56 | cache_device: 'cpu' # Set to 'cuda' to cache the data into the Graphics-Memory, fastest speed for strong gpu 57 | epochs: 100000 58 | interval_log: 100 59 | interval_val: 5000 60 | interval_force_save: 5000 61 | lr: 0.0005 62 | decay_step: 100000 63 | gamma: 0.7071 64 | weight_decay: 0.0001 65 | save_opt: false 66 | -------------------------------------------------------------------------------- /train/data_loaders_wav.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | import librosa 5 | import torch 6 | import random 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | from concurrent.futures import ProcessPoolExecutor 10 | import torch.multiprocessing as mp 11 | import utils_all as ut 12 | import pandas as pd 13 | import torchfcpe 14 | from redis_coder import RedisService, encode_wb, decode_wb 15 | DATABUFFER = RedisService( 16 | host='localhost', 17 | password='', 18 | port=6379, 19 | max_connections=16 20 | ) 21 | def traverse_dir( 22 | root_dir, 
23 | extensions, 24 | amount=None, 25 | str_include=None, 26 | str_exclude=None, 27 | is_pure=False, 28 | is_sort=False, 29 | is_ext=True): 30 | file_list = [] 31 | cnt = 0 32 | for root, _, files in os.walk(root_dir): 33 | for file in files: 34 | if any([file.endswith(f".{ext}") for ext in extensions]): 35 | # path 36 | mix_path = os.path.join(root, file) 37 | pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path 38 | 39 | # amount 40 | if (amount is not None) and (cnt == amount): 41 | if is_sort: 42 | file_list.sort() 43 | return file_list 44 | 45 | # check string 46 | if (str_include is not None) and (str_include not in pure_path): 47 | continue 48 | if (str_exclude is not None) and (str_exclude in pure_path): 49 | continue 50 | 51 | if not is_ext: 52 | ext = pure_path.split('.')[-1] 53 | pure_path = pure_path[:-(len(ext) + 1)] 54 | file_list.append(pure_path) 55 | cnt += 1 56 | if is_sort: 57 | file_list.sort() 58 | return file_list 59 | 60 | 61 | def get_data_loaders(args, jump=False): 62 | wav2mel = torchfcpe.spawn_wav2mel(args, device='cpu') 63 | data_train = F0Dataset( 64 | path_root=args.data.train_path, 65 | waveform_sec=args.data.duration, 66 | hop_size=args.mel.hop_size, 67 | sample_rate=args.mel.sr, 68 | duration=args.data.duration, 69 | load_all_data=args.train.cache_all_data, 70 | whole_audio=False, 71 | extensions=args.data.extensions, 72 | device=args.train.cache_device, 73 | wav2mel=wav2mel, 74 | aug_noise=args.train.aug_noise, 75 | noise_ratio=args.train.noise_ratio, 76 | brown_noise_ratio=args.train.brown_noise_ratio, 77 | aug_mask=args.train.aug_mask, 78 | aug_mask_v_o=args.train.aug_mask_v_o, 79 | aug_mask_vertical_factor=args.train.aug_mask_vertical_factor, 80 | aug_mask_vertical_factor_v_o=args.train.aug_mask_vertical_factor_v_o, 81 | aug_mask_iszeropad_mode=args.train.aug_mask_iszeropad_mode, 82 | aug_mask_block_num=args.train.aug_mask_block_num, 83 | aug_mask_block_num_v_o=args.train.aug_mask_block_num_v_o, 84 | 
aug_keyshift=args.train.aug_keyshift, 85 | keyshift_min=args.train.keyshift_min, 86 | keyshift_max=args.train.keyshift_max, 87 | f0_min=args.model.f0_min, 88 | f0_max=args.model.f0_max, 89 | f0_shift_mode='keyshift', 90 | load_data_num_processes=8, 91 | use_redis=args.train.use_redis, 92 | jump=jump 93 | ) 94 | loader_train = torch.utils.data.DataLoader( 95 | data_train, 96 | batch_size=args.train.batch_size, 97 | shuffle=True, 98 | num_workers=args.train.num_workers if args.train.cache_device == 'cpu' else 0, 99 | persistent_workers=(args.train.num_workers > 0) if args.train.cache_device == 'cpu' else False, 100 | pin_memory=True if args.train.cache_device == 'cpu' else False 101 | ) 102 | data_valid = F0Dataset( 103 | path_root=args.data.valid_path, 104 | waveform_sec=args.data.duration, 105 | hop_size=args.mel.hop_size, 106 | sample_rate=args.mel.sr, 107 | duration=args.data.duration, 108 | load_all_data=args.train.cache_all_data, 109 | whole_audio=True, 110 | extensions=args.data.extensions, 111 | wav2mel=wav2mel, 112 | aug_noise=args.train.aug_noise, 113 | noise_ratio=args.train.noise_ratio, 114 | brown_noise_ratio=args.train.brown_noise_ratio, 115 | aug_mask=args.train.aug_mask, 116 | aug_mask_v_o=args.train.aug_mask_v_o, 117 | aug_mask_vertical_factor=args.train.aug_mask_vertical_factor, 118 | aug_mask_vertical_factor_v_o=args.train.aug_mask_vertical_factor_v_o, 119 | aug_mask_iszeropad_mode=args.train.aug_mask_iszeropad_mode, 120 | aug_mask_block_num=args.train.aug_mask_block_num, 121 | aug_mask_block_num_v_o=args.train.aug_mask_block_num_v_o, 122 | aug_keyshift=False, 123 | ) 124 | loader_valid = torch.utils.data.DataLoader( 125 | data_valid, 126 | batch_size=1, 127 | shuffle=False, 128 | num_workers=0, 129 | pin_memory=True 130 | ) 131 | return loader_train, loader_valid 132 | 133 | 134 | class F0Dataset(Dataset): 135 | def __init__( 136 | self, 137 | path_root, 138 | waveform_sec, 139 | hop_size, 140 | sample_rate, 141 | duration, 142 | 
load_all_data=True, 143 | whole_audio=False, 144 | extensions=['wav'], 145 | device='cpu', 146 | wav2mel=None, 147 | aug_noise=False, 148 | noise_ratio=0.7, 149 | brown_noise_ratio=1., 150 | aug_mask=False, 151 | aug_mask_v_o=False, 152 | aug_mask_vertical_factor=0.05, 153 | aug_mask_vertical_factor_v_o=0.3, 154 | aug_mask_iszeropad_mode='randon', # randon zero or noise 155 | aug_mask_block_num=1, 156 | aug_mask_block_num_v_o=4, 157 | aug_keyshift=True, 158 | keyshift_min=-5, 159 | keyshift_max=12, 160 | f0_min=32.70, 161 | f0_max=1975.5, 162 | f0_shift_mode='keyshift', 163 | load_data_num_processes=1, 164 | snb_noise=None, 165 | noise_beta=0, 166 | use_redis=False, 167 | jump=False 168 | ): 169 | super().__init__() 170 | self.music_spk_id = 1 171 | self.wav2mel = wav2mel 172 | self.waveform_sec = waveform_sec 173 | self.sample_rate = sample_rate 174 | self.hop_size = hop_size 175 | self.path_root = path_root 176 | self.duration = duration 177 | self.aug_noise = aug_noise 178 | self.noise_ratio = noise_ratio 179 | self.brown_noise_ratio = brown_noise_ratio 180 | self.aug_mask = aug_mask 181 | self.aug_mask_v_o = aug_mask_v_o 182 | self.aug_mask_vertical_factor = aug_mask_vertical_factor 183 | self.aug_mask_vertical_factor_v_o = aug_mask_vertical_factor_v_o 184 | self.aug_mask_iszeropad_mode = aug_mask_iszeropad_mode 185 | self.aug_mask_block_num = aug_mask_block_num 186 | self.aug_mask_block_num_v_o = aug_mask_block_num_v_o 187 | self.aug_keyshift = aug_keyshift 188 | self.keyshift_min = keyshift_min 189 | self.keyshift_max = keyshift_max 190 | self.f0_min = f0_min 191 | self.f0_max = f0_max 192 | self.f0_shift_mode = f0_shift_mode 193 | self.n_spk = 4 194 | self.device = device 195 | self.load_all_data = load_all_data 196 | self.snb_noise = snb_noise 197 | self.noise_beta = noise_beta 198 | self.use_redis = use_redis 199 | self.jump = jump 200 | 201 | self.paths = traverse_dir( 202 | os.path.join(path_root, 'audio'), 203 | extensions=extensions, 204 | 
is_pure=True, 205 | is_sort=True, 206 | is_ext=True 207 | ) 208 | 209 | self.whole_audio = whole_audio 210 | if self.use_redis: 211 | self.data_buffer = None 212 | else: 213 | self.data_buffer = {} 214 | self.device = device 215 | if load_all_data: 216 | print('Load all the data from :', path_root) 217 | else: 218 | print('Load the f0, volume data from :', path_root) 219 | 220 | if self.use_redis: 221 | _ = self.load_data(self.paths) 222 | else: 223 | with torch.no_grad(): 224 | with ProcessPoolExecutor(max_workers=load_data_num_processes) as executor: 225 | tasks = [] 226 | for i in range(load_data_num_processes): 227 | start = int(i * len(self.paths) / load_data_num_processes) 228 | end = int((i + 1) * len(self.paths) / load_data_num_processes) 229 | file_chunk = self.paths[start:end] 230 | tasks.append(file_chunk) 231 | for data_buffer in executor.map(self.load_data, tasks): 232 | self.data_buffer.update(data_buffer) 233 | 234 | self.paths = np.array(self.paths, dtype=object) 235 | self.data_buffer = pd.DataFrame(self.data_buffer) 236 | 237 | def load_data(self, paths): 238 | with torch.no_grad(): 239 | data_buffer = {} 240 | rank = mp.current_process()._identity 241 | rank = rank[0] if len(rank) > 0 else 0 242 | for name_ext in tqdm(paths): 243 | path_audio = os.path.join(self.path_root, 'audio', name_ext) 244 | duration = librosa.get_duration(filename=path_audio, sr=self.sample_rate) 245 | 246 | path_f0 = os.path.join(self.path_root, 'f0', name_ext) + '.npy' 247 | f0 = np.load(path_f0)[:, None] 248 | # f0 = torch.from_numpy(f0).float().unsqueeze(-1).to(self.device) 249 | 250 | if self.n_spk is not None and self.n_spk > 1: 251 | dirname_split = re.split(r"_|\-", os.path.dirname(name_ext), 2)[0] 252 | t_spk_id = spk_id = int(dirname_split) if str.isdigit(dirname_split) else 0 253 | if spk_id < 1 or spk_id > self.n_spk: 254 | raise ValueError( 255 | ' [x] Muiti-speaker traing error : spk_id must be a positive integer from 1 to n_spk ') 256 | else: 257 | pass 258 
| # spk_id = 1 259 | # t_spk_id = spk_id 260 | # spk_id = torch.LongTensor(np.array([spk_id])).to(self.device) 261 | # spk_id = np.array([spk_id]) 262 | 263 | if self.load_all_data: 264 | audio, sr = librosa.load(path_audio, sr=self.sample_rate) 265 | if len(audio.shape) > 1: 266 | audio = librosa.to_mono(audio) 267 | # audio = torch.from_numpy(audio).to(device) 268 | 269 | # path_audio = os.path.join(self.path_root, 'npaudiodir', name_ext) + '.npy' 270 | # audio = np.load(path_audio) 271 | 272 | if spk_id == self.music_spk_id: 273 | path_music = os.path.join(self.path_root, 'music', name_ext)# + '.npy' 274 | # audio_music = np.load(path_music) 275 | audio_music, _ = librosa.load(path_music, sr=self.sample_rate) 276 | if len(audio_music.shape) > 1: 277 | audio_music = librosa.to_mono(audio_music) 278 | else: 279 | audio_music = None 280 | 281 | """ 282 | data_buffer[name_ext] = { 283 | 'duration': duration, 284 | 'audio': audio, 285 | 'f0': f0, 286 | 'spk_id': spk_id, 287 | 't_spk_id': t_spk_id, 288 | } 289 | """ 290 | if self.use_redis: 291 | f0 = encode_wb(f0, f0.dtype, f0.shape) 292 | audio = encode_wb(audio, audio.dtype, audio.shape) 293 | if audio_music is not None: 294 | audio_music = encode_wb(audio_music, audio_music.dtype, audio_music.shape) 295 | else: 296 | audio_music = int(0) 297 | if self.use_redis: 298 | if not self.jump: 299 | DATABUFFER[name_ext] = list((duration, f0, audio, audio_music)) 300 | data_buffer = None 301 | else: 302 | data_buffer[name_ext] = (duration, f0, audio, audio_music) 303 | else: 304 | if spk_id == self.music_spk_id: 305 | use_music = True 306 | else: 307 | use_music = None 308 | """ 309 | data_buffer[name_ext] = { 310 | 'duration': duration, 311 | 'f0': f0, 312 | 'spk_id': spk_id, 313 | 't_spk_id': t_spk_id 314 | } 315 | """ 316 | data_buffer[name_ext] = (duration, f0, use_music) 317 | return data_buffer 318 | 319 | def __getitem__(self, file_idx): 320 | with torch.no_grad(): 321 | name_ext = self.paths[file_idx] 322 | if 
self.use_redis: 323 | data_buffer = DATABUFFER[name_ext] 324 | else: 325 | data_buffer = self.data_buffer[name_ext] 326 | # check duration. if too short, then skip 327 | if float(data_buffer[0]) < (self.waveform_sec + 0.1): 328 | return self.__getitem__((file_idx + 1) % len(self.paths)) 329 | 330 | # get item 331 | return self.get_data(name_ext, tuple(data_buffer)) 332 | 333 | def get_data(self, name_ext, data_buffer): 334 | with torch.no_grad(): 335 | name = os.path.splitext(name_ext)[0] 336 | frame_resolution = self.hop_size / self.sample_rate 337 | duration = float(data_buffer[0]) 338 | waveform_sec = duration if self.whole_audio else self.waveform_sec 339 | 340 | # load audio 341 | idx_from = 0 if self.whole_audio else random.uniform(0, duration - waveform_sec - 0.1) 342 | start_frame = int(idx_from / frame_resolution) 343 | units_frame_len = int(waveform_sec / frame_resolution) 344 | 345 | # load f0 346 | if self.use_redis and self.load_all_data: 347 | f0 = decode_wb(data_buffer[1]) 348 | f0 = f0.copy() 349 | else: 350 | f0 = data_buffer[1].copy() 351 | f0 = torch.from_numpy(f0).float().cpu() 352 | 353 | # load mel 354 | # audio = data_buffer.get('audio') 355 | if len(data_buffer) == 3: 356 | #path_audio = os.path.join(self.path_root, 'npaudiodir', name_ext) + '.npy' 357 | #audio = np.load(path_audio) 358 | path_audio = os.path.join(self.path_root, 'audio', name_ext) 359 | audio, _ = librosa.load(path_audio, sr=self.sample_rate) 360 | if len(audio.shape) > 1: 361 | audio = librosa.to_mono(audio) 362 | if random.choice((False, True)) and (data_buffer[2] is not None): 363 | path_music = os.path.join(self.path_root, 'music', name_ext) 364 | audio_music, _ = librosa.load(path_music, sr=self.sample_rate) 365 | if len(audio_music.shape) > 1: 366 | audio_music = librosa.to_mono(audio_music) 367 | audio = audio + audio_music 368 | del audio_music 369 | audio = 0.98 * audio / (np.abs(audio).max()) 370 | 371 | else: 372 | if self.use_redis: 373 | audio = 
decode_wb(data_buffer[2]) 374 | audio = audio.copy() 375 | else: 376 | audio = data_buffer[2].copy() 377 | if len(data_buffer) == 4: 378 | if self.use_redis: 379 | if len(data_buffer[3]) == 1: 380 | pass 381 | else: 382 | audio_music = decode_wb(data_buffer[3]) 383 | audio_music = audio_music.copy() 384 | audio = audio + audio_music 385 | del audio_music 386 | audio = 0.98 * audio / (np.abs(audio).max()) 387 | else: 388 | if data_buffer[3] is not None: 389 | if random.choice((False, True)): 390 | audio_music = data_buffer[3].copy() 391 | audio = audio + audio_music 392 | del audio_music 393 | audio = 0.98 * audio / (np.abs(audio).max()) 394 | 395 | if random.choice((False, True)) and self.aug_keyshift: 396 | if self.f0_shift_mode == 'keyshift': 397 | _f0_shift_mode = 'keyshift' 398 | elif self.f0_shift_mode == 'automax': 399 | _f0_shift_mode = 'automax' 400 | elif self.f0_shift_mode == 'random': 401 | _f0_shift_mode = random.choice(('keyshift', 'automax')) 402 | else: 403 | raise ValueError('f0_shift_mode must be keyshift, automax or random') 404 | 405 | if _f0_shift_mode == 'keyshift': 406 | keyshift = random.uniform(self.keyshift_min, self.keyshift_max) 407 | elif _f0_shift_mode == 'automax': 408 | keyshift_max = 12 * np.log2(self.f0_max / f0.max) 409 | keyshift_min = 12 * np.log2(self.f0_min / f0.min) 410 | keyshift = random.uniform(keyshift_min, keyshift_max) 411 | with torch.no_grad(): 412 | f0 = 2 ** (keyshift / 12) * f0 413 | else: 414 | keyshift = 0 415 | 416 | is_aug_noise = bool(random.randint(0, 1)) 417 | 418 | if self.snb_noise is not None: 419 | audio = ut.add_noise_snb(audio, self.snb_noise, self.noise_beta) 420 | 421 | if self.aug_noise and is_aug_noise: 422 | if bool(random.randint(0, 1)): 423 | audio = ut.add_noise(audio, noise_ratio=self.noise_ratio) 424 | else: 425 | audio = ut.add_noise_slice(audio, self.sample_rate, self.duration, noise_ratio=self.noise_ratio, 426 | brown_noise_ratio=self.brown_noise_ratio) 427 | 428 | peak = 
np.abs(audio).max() 429 | audio = 0.98 * audio / peak 430 | audio = torch.from_numpy(audio).float().unsqueeze(0).cpu() 431 | with torch.no_grad(): 432 | mel = self.wav2mel(audio, sample_rate=self.sample_rate, keyshift=keyshift, no_cache_window=True).squeeze(0).cpu() 433 | 434 | if self.aug_mask and bool(random.randint(0, 1)) and not is_aug_noise: 435 | v_o = bool(random.randint(0, 1)) and self.aug_mask_v_o 436 | mel = mel.transpose(-1, -2) 437 | if self.aug_mask_iszeropad_mode == 'zero': 438 | iszeropad = True 439 | elif self.aug_mask_iszeropad_mode == 'noise': 440 | iszeropad = False 441 | else: 442 | iszeropad = bool(random.randint(0, 1)) 443 | mel = ut.add_mel_mask_slice(mel, self.sample_rate, self.duration, hop_size=self.hop_size, 444 | vertical_factor=self.aug_mask_vertical_factor_v_o if v_o else self.aug_mask_vertical_factor, 445 | vertical_offset=v_o, iszeropad=iszeropad, 446 | block_num=self.aug_mask_block_num_v_o if v_o else self.aug_mask_block_num) 447 | mel = mel.transpose(-1, -2) 448 | 449 | mel = mel[start_frame: start_frame + units_frame_len].detach() 450 | 451 | f0_frames = f0[start_frame: start_frame + units_frame_len].detach() 452 | 453 | # load spk_id 454 | # spk_id = data_buffer.get('spk_id') 455 | # spk_id = torch.LongTensor(spk_id).to(self.device) 456 | 457 | 458 | del audio 459 | # return dict(mel=mel, f0=f0_frames, spk_id=spk_id, name=name, name_ext=name_ext) 460 | output = (mel, f0_frames, name, name_ext) 461 | return output 462 | 463 | def __len__(self): 464 | return len(self.paths) 465 | -------------------------------------------------------------------------------- /train/draw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | from tqdm import tqdm 5 | import shutil 6 | 7 | train_dir = r'yx1\train' 8 | val_dir = r'yx1\val' 9 | """ 10 | for spker in os.listdir(train_dir): 11 | train_list = os.listdir(os.path.join(train_dir, spker, "audio")) 12 | 
val_list = [] 13 | for i in range(10): 14 | # 生成随机数,作为索引,随机选择一个文件 15 | _random = random.randint(0, len(train_list) - 1) 16 | val_list.append(train_list[_random]) 17 | train_list.pop(_random) 18 | print(_random, len(train_list)) 19 | for i in val_list: 20 | print(i) 21 | os.makedirs(os.path.join(val_dir, spker, "audio"), exist_ok=True) 22 | shutil.move(os.path.join(train_dir, spker, "audio", i), os.path.join(val_dir, spker, "audio", i)) 23 | os.makedirs(os.path.join(val_dir, spker, "f0"), exist_ok=True) 24 | shutil.move(os.path.join(train_dir, spker, "f0", i + '.npy'), os.path.join(val_dir, spker, "f0", i + '.npy')) 25 | if int(spker) == 1: 26 | os.makedirs(os.path.join(val_dir, spker, "music"), exist_ok=True) 27 | shutil.move(os.path.join(train_dir, spker, "music", i), os.path.join(val_dir, spker, "music", i)) 28 | """ 29 | -------------------------------------------------------------------------------- /train/pre_data.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import librosa 4 | import soundfile as sf 5 | root_dir = r'XXX' 6 | raw_dir = os.path.join(root_dir, 'raw') 7 | music_dir = os.path.join(root_dir, 'music') 8 | f0_mir1k_dir = os.path.join(root_dir, 'f0_mir1k') 9 | f0_ptdb = os.path.join(root_dir, 'f0') 10 | audio_mir1k = os.path.join(root_dir, 'audio_mir1k') 11 | audio_ptdb = os.path.join(root_dir, 'audio') 12 | 13 | 14 | file_list = os.listdir(raw_dir) 15 | 16 | for file_name in file_list: 17 | print(file_name) 18 | if file_name[-3:] == '.pv': 19 | pipe_mode = 'PV' 20 | else: 21 | pipe_mode = 'WAV' 22 | if file_name[:4] == 'mic_': 23 | dataset_mode = 'ptdb' 24 | else: 25 | dataset_mode = 'mir1k' 26 | 27 | dataset_mode = 'ptdb' 28 | 29 | if pipe_mode == 'PV': 30 | pv_np = numpy.loadtxt(os.path.join(raw_dir, file_name)) 31 | pv_np = numpy.insert(pv_np, 0, 0) 32 | mask = pv_np == 0 33 | f0_np = 440 * (2.0 ** ((pv_np - 69.0) / 12.0)) 34 | f0_np[mask] = 0 35 | if dataset_mode == 'ptdb': 
36 | numpy.save(os.path.join(f0_ptdb, file_name[:-3] + '.wav.npy'), f0_np) 37 | else: 38 | numpy.save(os.path.join(f0_mir1k_dir, file_name[:-3] + '.wav.npy'), f0_np) 39 | else: 40 | if dataset_mode == 'ptdb': 41 | audio, sr = librosa.load(os.path.join(raw_dir, file_name), mono=False, sr=None) 42 | assert sr == 16000 43 | assert len(audio.shape) == 1 44 | sf.write(os.path.join(audio_ptdb, file_name), audio, sr) 45 | else: 46 | audio, sr = librosa.load(os.path.join(raw_dir, file_name), mono=False, sr=None) 47 | assert sr == 16000 48 | assert len(audio.shape) == 2 49 | sf.write(os.path.join(audio_mir1k, file_name), audio[1], sr) 50 | sf.write(os.path.join(music_dir, file_name), audio[0], sr) 51 | 52 | 53 | """ 54 | in_wav, in_sr = librosa.load(os.path.join(raw_dir, "abjones_1_02.wav"), mono=False, sr=None) 55 | sf.write("test_0.wav", in_wav[0], in_sr) # music 56 | sf.write("test_1.wav", in_wav[1], in_sr) # voice 57 | print(in_wav.shape) 58 | print(in_sr) 59 | print(os.path.join(raw_dir, "abjones_1_02.wav")) 60 | """ 61 | """ 62 | """ 63 | """ 64 | a = numpy.array([1, 2, 0, 0, 0, 90, 58.03715, 66, 66, 66, 66]) 65 | # 给a的前面加一个0 66 | b = numpy.insert(a, 0, 0) 67 | # 为0的地方为False,其他为True 68 | mask = b == 0 69 | # f0 (Hz)= 440 * (2.0 ** ((b - 69.0) / 12.0)) 70 | c = 440 * (2.0 ** ((b - 69.0) / 12.0)) 71 | # 将c中为0的地方置为0 72 | c[mask] = 0 73 | print(c) 74 | print(mask) 75 | print(b) 76 | """ 77 | -------------------------------------------------------------------------------- /train/redis_coder.py: -------------------------------------------------------------------------------- 1 | import redis 2 | import numpy as np 3 | import struct 4 | 5 | 6 | class RedisPool: 7 | def __init__(self, host, password, port, max_connections=10): 8 | # 创建 Redis 连接池 9 | self.pool = redis.ConnectionPool(host=host, port=port, password=password, max_connections=max_connections) 10 | 11 | def get_redis_conn(self): 12 | # 获取 Redis 连接 13 | redis_conn = redis.StrictRedis(connection_pool=self.pool) 14 | 
try: 15 | # 检查连接是否可用 16 | redis_conn.ping() 17 | except Exception: 18 | # 如果连接不可用,则重建连接 19 | redis_conn.connection_pool.disconnect() 20 | redis_conn = redis.StrictRedis(connection_pool=self.pool) 21 | return redis_conn 22 | 23 | 24 | class RedisService: 25 | def __init__(self, host, password, port, max_connections=10): 26 | # 创建 Redis 连接池对象 27 | self.pool = RedisPool(host=host, port=port, password=password, max_connections=max_connections) 28 | # 获取 Redis 连接 29 | self.redis_conn = self.pool.get_redis_conn() 30 | 31 | def set(self, **kwargs): 32 | for key, value in kwargs.items(): 33 | self.redis_conn.__setitem__(key, value) 34 | 35 | def push(self, key, value): 36 | self.redis_conn.lpush(key, value) 37 | 38 | def pop(self, key): 39 | value = self.redis_conn.rpop(key) 40 | self.redis_conn.lrem(key, 0, value) 41 | return value 42 | 43 | def list_get_index(self, key, index): 44 | return self.redis_conn.lindex(key, index) 45 | 46 | def llen(self, key): 47 | return self.redis_conn.llen(key) 48 | 49 | def set_add(self, key, *value): 50 | self.redis_conn.sadd(key, *value) 51 | 52 | def set_member_exists(self, key, value): 53 | return self.redis_conn.sismember(key, value) 54 | 55 | def exitst(self, key): 56 | return self.redis_conn.exists(key) 57 | 58 | def __setitem__(self, key, value): 59 | if self.redis_conn.exists(key): 60 | self.redis_conn.delete(key) 61 | if isinstance(value, str): 62 | self.redis_conn.set(key, value) 63 | elif isinstance(value, dict): 64 | self.redis_conn.hmset(key, value) 65 | elif isinstance(value, list): 66 | self.redis_conn.rpush(key, *value) 67 | else: 68 | self.redis_conn.set(key, value) 69 | 70 | def __getitem__(self, key): 71 | key_type = self.redis_conn.type(key) 72 | if key_type == b'none': 73 | return None 74 | elif key_type == b'list': 75 | return self.redis_conn.lrange(key, 0, -1) 76 | elif key_type == b'string': 77 | return self.redis_conn.get(key) 78 | elif key_type == b'hash': 79 | return self.redis_conn.hgetall(key) 80 | else: 81 | 
print(f"Key {key} type {key_type} not support") 82 | return None 83 | 84 | 85 | def encode_wb(wb_np: np.ndarray, 86 | dtype: np.dtype, 87 | shape: tuple, 88 | ) -> bytes: 89 | # 将numpy数组转换为bytes,头部为shape和dtype,变长编码 90 | wb_bytes = struct.pack('i', len(shape)) 91 | for i in shape: 92 | wb_bytes += struct.pack('i', i) 93 | 94 | wb_bytes += struct.pack('i', len(dtype.name)) 95 | wb_bytes += dtype.name.encode('utf-8') 96 | wb_bytes += wb_np.tobytes() 97 | return wb_bytes 98 | 99 | 100 | def decode_wb(wb_bytes: bytes) -> np.ndarray: 101 | # 将上文编码的bytes转换回numpy数组 102 | shape_len = struct.unpack('i', wb_bytes[:4])[0] 103 | shape = [] 104 | for i in range(shape_len): 105 | shape.append(struct.unpack('i', wb_bytes[4 + i * 4: 8 + i * 4])[0]) 106 | dtype_len = struct.unpack('i', wb_bytes[4 + shape_len * 4: 8 + shape_len * 4])[0] 107 | dtype = np.dtype(struct.unpack(f'{dtype_len}s', wb_bytes[8 + shape_len * 4: 8 + shape_len * 4 + dtype_len])[0]) 108 | wb_np = np.frombuffer(wb_bytes[8 + shape_len * 4 + dtype_len:], dtype=dtype) 109 | wb_np = wb_np.reshape(shape) 110 | return wb_np 111 | -------------------------------------------------------------------------------- /train/savertools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CNChTu/FCPE/d55c1b636dc3564dc35be19675998688931e870c/train/savertools/__init__.py -------------------------------------------------------------------------------- /train/savertools/saver.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: wayn391@mastertones 3 | ''' 4 | 5 | import os 6 | import time 7 | import yaml 8 | import datetime 9 | import torch 10 | import matplotlib.pyplot as plt 11 | plt.switch_backend('agg') 12 | from . 
import utils 13 | import numpy as np 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | 17 | class Saver(object): 18 | def __init__( 19 | self, 20 | args, 21 | initial_global_step=-1): 22 | 23 | self.expdir = args.env.expdir 24 | self.sample_rate = args.mel.sampling_rate 25 | 26 | # cold start 27 | self.global_step = initial_global_step 28 | self.init_time = time.time() 29 | self.last_time = time.time() 30 | 31 | # makedirs 32 | os.makedirs(self.expdir, exist_ok=True) 33 | 34 | # path 35 | self.path_log_info = os.path.join(self.expdir, 'log_info.txt') 36 | 37 | # ckpt 38 | os.makedirs(self.expdir, exist_ok=True) 39 | 40 | # writer 41 | self.writer = SummaryWriter(os.path.join(self.expdir, 'logs')) 42 | 43 | # save config 44 | path_config = os.path.join(self.expdir, 'config.yaml') 45 | with open(path_config, "w") as out_config: 46 | yaml.dump(dict(args), out_config) 47 | 48 | # save spk_emb_dict 49 | if args.model.use_speaker_encoder: 50 | import numpy as np 51 | path_from_spk_emb_dict = os.path.join(args.data.train_path, 'spk_emb_dict.npy') 52 | path_save_spk_emb_dict = os.path.join(self.expdir, 'spk_emb_dict.npy') 53 | temp_spk_emb_dict = np.load(path_from_spk_emb_dict, allow_pickle=True).item() 54 | np.save(path_save_spk_emb_dict, temp_spk_emb_dict) 55 | 56 | def log_info(self, msg): 57 | '''log method''' 58 | if isinstance(msg, dict): 59 | msg_list = [] 60 | for k, v in msg.items(): 61 | tmp_str = '' 62 | if isinstance(v, int): 63 | tmp_str = '{}: {:,}'.format(k, v) 64 | else: 65 | tmp_str = '{}: {}'.format(k, v) 66 | 67 | msg_list.append(tmp_str) 68 | msg_str = '\n'.join(msg_list) 69 | else: 70 | msg_str = msg 71 | 72 | # dsplay 73 | print(msg_str) 74 | 75 | # save 76 | with open(self.path_log_info, 'a') as fp: 77 | fp.write(msg_str + '\n') 78 | 79 | def log_value(self, dict): 80 | for k, v in dict.items(): 81 | self.writer.add_scalar(k, v, self.global_step) 82 | 83 | def log_spec(self, name, spec, spec_out, vmin=-14, vmax=3.5): 84 | spec_cat = 
torch.cat([(spec_out - spec).abs() + vmin, spec, spec_out], -1) 85 | spec = spec_cat[0] 86 | if isinstance(spec, torch.Tensor): 87 | spec = spec.cpu().numpy() 88 | fig = plt.figure(figsize=(12, 9)) 89 | plt.pcolor(spec.T, vmin=vmin, vmax=vmax) 90 | plt.tight_layout() 91 | self.writer.add_figure(name, fig, self.global_step) 92 | 93 | def log_audio(self, dict): 94 | for k, v in dict.items(): 95 | self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate) 96 | 97 | def log_f0(self, name, f0_pr, f0_gt, inuv=False): 98 | #f0_gt = (1 + f0_gt / 700).log() 99 | name = (name + '_f0_inuv') if inuv else (name + '_f0') 100 | f0_pr = f0_pr.squeeze().cpu().numpy() 101 | f0_gt = f0_gt.squeeze().cpu().numpy() 102 | if inuv: 103 | uv = f0_pr == 0 104 | if len(f0_pr[~uv]) > 0: 105 | f0_pr[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0_pr[~uv]) 106 | uv = f0_gt == 0 107 | if len(f0_gt[~uv]) > 0: 108 | f0_gt[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0_gt[~uv]) 109 | fig = plt.figure() 110 | plt.plot(f0_gt, color='b', linestyle='-') 111 | plt.plot(f0_pr, color='r', linestyle='-') 112 | self.writer.add_figure(name, fig, self.global_step) 113 | 114 | def get_interval_time(self, update=True): 115 | cur_time = time.time() 116 | time_interval = cur_time - self.last_time 117 | if update: 118 | self.last_time = cur_time 119 | return time_interval 120 | 121 | def get_total_time(self, to_str=True): 122 | total_time = time.time() - self.init_time 123 | if to_str: 124 | total_time = str(datetime.timedelta( 125 | seconds=total_time))[:-5] 126 | return total_time 127 | 128 | def save_model( 129 | self, 130 | model, 131 | optimizer, 132 | name='model', 133 | postfix='', 134 | to_json=False, 135 | config_dict=None): 136 | # path 137 | if postfix: 138 | postfix = '_' + postfix 139 | path_pt = os.path.join( 140 | self.expdir, name + postfix + '.pt') 141 | 142 | # check 143 | print(' [*] model checkpoint saved: {}'.format(path_pt)) 144 | 145 | # save 146 | 
if optimizer is not None: 147 | torch.save({ 148 | 'global_step': self.global_step, 149 | 'model': model.state_dict(), 150 | 'optimizer': optimizer.state_dict(), 151 | 'config_dict': config_dict 152 | }, path_pt) 153 | else: 154 | torch.save({ 155 | 'global_step': self.global_step, 156 | 'model': model.state_dict(), 157 | 'config_dict': config_dict 158 | }, path_pt) 159 | 160 | # to json 161 | if to_json: 162 | path_json = os.path.join( 163 | self.expdir, name + '.json') 164 | utils.to_json(path_pt, path_json) 165 | 166 | def delete_model(self, name='model', postfix=''): 167 | # path 168 | if postfix: 169 | postfix = '_' + postfix 170 | path_pt = os.path.join( 171 | self.expdir, name + postfix + '.pt') 172 | 173 | # delete 174 | if os.path.exists(path_pt): 175 | os.remove(path_pt) 176 | print(' [*] model checkpoint deleted: {}'.format(path_pt)) 177 | 178 | def global_step_increment(self): 179 | self.global_step += 1 180 | -------------------------------------------------------------------------------- /train/savertools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import torch 5 | 6 | 7 | def traverse_dir( 8 | root_dir, 9 | extensions, 10 | amount=None, 11 | str_include=None, 12 | str_exclude=None, 13 | is_pure=False, 14 | is_sort=False, 15 | is_ext=True): 16 | file_list = [] 17 | cnt = 0 18 | for root, _, files in os.walk(root_dir): 19 | for file in files: 20 | if any([file.endswith(f".{ext}") for ext in extensions]): 21 | # path 22 | mix_path = os.path.join(root, file) 23 | pure_path = mix_path[len(root_dir) + 1:] if is_pure else mix_path 24 | 25 | # amount 26 | if (amount is not None) and (cnt == amount): 27 | if is_sort: 28 | file_list.sort() 29 | return file_list 30 | 31 | # check string 32 | if (str_include is not None) and (str_include not in pure_path): 33 | continue 34 | if (str_exclude is not None) and (str_exclude in pure_path): 35 | continue 36 | 37 | if not 
is_ext: 38 | ext = pure_path.split('.')[-1] 39 | pure_path = pure_path[:-(len(ext) + 1)] 40 | file_list.append(pure_path) 41 | cnt += 1 42 | if is_sort: 43 | file_list.sort() 44 | return file_list 45 | 46 | 47 | class DotDict(dict): 48 | def __getattr__(*args): 49 | val = dict.get(*args) 50 | return DotDict(val) if type(val) is dict else val 51 | 52 | __setattr__ = dict.__setitem__ 53 | __delattr__ = dict.__delitem__ 54 | 55 | 56 | def get_network_paras_amount(model_dict): 57 | info = dict() 58 | for model_name, model in model_dict.items(): 59 | # all_params = sum(p.numel() for p in model.parameters()) 60 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 61 | 62 | info[model_name] = trainable_params 63 | return info 64 | 65 | 66 | def load_config(path_config): 67 | with open(path_config, "r") as config: 68 | args = yaml.safe_load(config) 69 | args = DotDict(args) 70 | # print(args) 71 | return args 72 | 73 | 74 | def to_json(path_params, path_json): 75 | params = torch.load(path_params, map_location=torch.device('cpu')) 76 | raw_state_dict = {} 77 | for k, v in params.items(): 78 | val = v.flatten().numpy().tolist() 79 | raw_state_dict[k] = val 80 | 81 | with open(path_json, 'w') as outfile: 82 | json.dump(raw_state_dict, outfile, indent="\t") 83 | 84 | 85 | def convert_tensor_to_numpy(tensor, is_squeeze=True): 86 | if is_squeeze: 87 | tensor = tensor.squeeze() 88 | if tensor.requires_grad: 89 | tensor = tensor.detach() 90 | if tensor.is_cuda: 91 | tensor = tensor.cpu() 92 | return tensor.numpy() 93 | 94 | 95 | def load_model( 96 | expdir, 97 | model, 98 | optimizer, 99 | name='model', 100 | postfix='', 101 | device='cpu'): 102 | if postfix == '': 103 | postfix = '_' + postfix 104 | path = os.path.join(expdir, name + postfix) 105 | path_pt = traverse_dir(expdir, ['pt'], is_ext=False) 106 | global_step = 0 107 | if len(path_pt) > 0: 108 | steps = [s[len(path):] for s in path_pt] 109 | maxstep = max([int(s) if s.isdigit() else 0 for s 
in steps]) 110 | if maxstep >= 0: 111 | path_pt = path + str(maxstep) + '.pt' 112 | else: 113 | path_pt = path + 'best.pt' 114 | print(' [*] restoring model from', path_pt) 115 | ckpt = torch.load(path_pt, map_location=torch.device(device)) 116 | global_step = ckpt['global_step'] 117 | model.load_state_dict(ckpt['model'], strict=False) 118 | if ckpt.get('optimizer') != None and optimizer != None: 119 | optimizer.load_state_dict(ckpt['optimizer']) 120 | return global_step, model, optimizer 121 | -------------------------------------------------------------------------------- /train/solver_wav.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | from savertools.saver import Saver 5 | from savertools.saver import utils 6 | from torch import autocast 7 | from torch.cuda.amp import GradScaler 8 | 9 | from mir_eval.melody import raw_pitch_accuracy, to_cent_voicing, raw_chroma_accuracy, overall_accuracy 10 | from mir_eval.melody import voicing_recall, voicing_false_alarm 11 | import gc 12 | 13 | USE_MIR = True 14 | 15 | 16 | def test(args, model, loader_test, saver): 17 | print(' [*] testing...') 18 | model.eval() 19 | 20 | # losses 21 | _rpa = _rca = _oa = _vfa = _vr = test_loss = 0. 
22 | _num_a = 0 23 | 24 | # intialization 25 | num_batches = len(loader_test) 26 | rtf_all = [] 27 | 28 | # run 29 | with torch.no_grad(): 30 | for bidx, data in enumerate(loader_test): 31 | fn = data[2][0] 32 | print('--------') 33 | print('{}/{} - {}'.format(bidx, num_batches, fn)) 34 | 35 | # unpack data 36 | # for k in data.keys(): 37 | # if not k.startswith('name'): 38 | # data[k] = data[k].to(args.device) 39 | for k in range(len(data)): 40 | if k < 2: 41 | data[k] = data[k].to(args.device) 42 | # print('>>', data[2][0]) 43 | 44 | # forward 45 | st_time = time.time() 46 | f0 = model.infer(mel=data[0]) 47 | ed_time = time.time() 48 | 49 | if USE_MIR: 50 | _f0 = f0.squeeze().cpu().numpy() 51 | _df0 = data[1].squeeze().cpu().numpy() 52 | 53 | time_slice = np.array([i * args.mel.hop_size * 1000 / args.mel.sr for i in range(len(_df0))]) 54 | ref_v, ref_c, est_v, est_c = to_cent_voicing(time_slice, _df0, time_slice, _f0) 55 | 56 | rpa = raw_pitch_accuracy(ref_v, ref_c, est_v, est_c) 57 | rca = raw_chroma_accuracy(ref_v, ref_c, est_v, est_c) 58 | oa = overall_accuracy(ref_v, ref_c, est_v, est_c) 59 | vfa = voicing_false_alarm(ref_v, est_v) 60 | vr = voicing_recall(ref_v, est_v) 61 | 62 | # RTF 63 | run_time = ed_time - st_time 64 | song_time = f0.shape[1] * args.mel.hop_size / args.mel.sr 65 | rtf = run_time / song_time 66 | print('RTF: {} | {} / {}'.format(rtf, run_time, song_time)) 67 | if USE_MIR: 68 | print('RPA: {} | RCA: {} | OA: {} | VFA: {} | VR: {} |'.format(rpa, rca, oa, vfa, vr)) 69 | rtf_all.append(rtf) 70 | 71 | # loss 72 | for i in range(args.train.batch_size): 73 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale) 74 | test_loss += loss.item() 75 | 76 | if USE_MIR: 77 | _rpa = _rpa + rpa 78 | _rca = _rca + rca 79 | _oa = _oa + oa 80 | _vfa = _vfa + vfa 81 | _vr = _vr + vr 82 | _num_a = _num_a + 1 83 | 84 | # log mel 85 | saver.log_spec(data[3][0], data[0], data[0]) 86 | 87 | saver.log_f0(data[3][0], f0, data[1]) 
88 | saver.log_f0(data[3][0], f0, data[1], inuv=True) 89 | 90 | # report 91 | test_loss /= args.train.batch_size 92 | test_loss /= num_batches 93 | 94 | if USE_MIR: 95 | _rpa /= _num_a 96 | 97 | _rca /= _num_a 98 | 99 | _oa /= _num_a 100 | 101 | _vfa /= _num_a 102 | 103 | _vr /= _num_a 104 | 105 | # check 106 | print(' [test_loss] test_loss:', test_loss) 107 | print(' Real Time Factor', np.mean(rtf_all)) 108 | return test_loss, _rpa, _rca, _oa, _vfa, _vr 109 | 110 | 111 | def train(args, initial_global_step, model, optimizer, scheduler, loader_train, loader_test): 112 | # saver 113 | saver = Saver(args, initial_global_step=initial_global_step) 114 | 115 | # model size 116 | params_count = utils.get_network_paras_amount({'model': model}) 117 | saver.log_info('--- model size ---') 118 | saver.log_info(params_count) 119 | 120 | # run 121 | num_batches = len(loader_train) 122 | model.train() 123 | saver.log_info('======= start training =======') 124 | scaler = GradScaler() 125 | if args.train.amp_dtype == 'fp32': 126 | dtype = torch.float32 127 | elif args.train.amp_dtype == 'fp16': 128 | dtype = torch.float16 129 | elif args.train.amp_dtype == 'bf16': 130 | dtype = torch.bfloat16 131 | else: 132 | raise ValueError(' [x] Unknown amp_dtype: ' + args.train.amp_dtype) 133 | for epoch in range(args.train.epochs): 134 | train_one_epoch(loader_train, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches, 135 | loader_test) 136 | # 手动gc,防止内存泄漏 137 | gc.collect() 138 | 139 | 140 | def train_one_epoch(loader_train, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches, 141 | loader_test): 142 | for batch_idx, data in enumerate(loader_train): 143 | train_one_step(batch_idx, data, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches, 144 | loader_test) 145 | 146 | 147 | def train_one_step(batch_idx, data, saver, optimizer, model, dtype, scaler, epoch, args, scheduler, num_batches, 148 | loader_test): 149 | 
saver.global_step_increment() 150 | optimizer.zero_grad() 151 | 152 | # unpack data 153 | for k in range(len(data)): 154 | if k < 2: 155 | data[k] = data[k].to(args.device) 156 | # print('>>', data[2][0]) 157 | 158 | # forward 159 | if dtype == torch.float32: 160 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale) 161 | else: 162 | with autocast(device_type=args.device, dtype=dtype): 163 | loss = model.train_and_loss(mel=data[0], gt_f0=data[1], loss_scale=args.loss.loss_scale) 164 | 165 | # handle nan loss 166 | if torch.isnan(loss): 167 | # raise ValueError(' [x] nan loss ') 168 | print(' [x] nan loss ') 169 | loss = None 170 | return 171 | else: 172 | # backpropagate 173 | if dtype == torch.float32: 174 | loss.backward() 175 | optimizer.step() 176 | else: 177 | scaler.scale(loss).backward() 178 | scaler.step(optimizer) 179 | scaler.update() 180 | scheduler.step() 181 | 182 | # log loss 183 | if saver.global_step % args.train.interval_log == 0: 184 | current_lr = optimizer.param_groups[0]['lr'] 185 | saver.log_info( 186 | 'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | lr: {:.6} | loss: {:.3f} | time: {} | step: {}'.format( 187 | epoch, 188 | batch_idx, 189 | num_batches, 190 | args.env.expdir, 191 | args.train.interval_log / saver.get_interval_time(), 192 | current_lr, 193 | loss.item(), 194 | saver.get_total_time(), 195 | saver.global_step 196 | ) 197 | ) 198 | 199 | saver.log_value({ 200 | 'train/loss': loss.item() 201 | }) 202 | 203 | saver.log_value({ 204 | 'train/lr': current_lr 205 | }) 206 | 207 | # validation 208 | if saver.global_step % args.train.interval_val == 0: 209 | optimizer_save = optimizer if args.train.save_opt else None 210 | 211 | # save latest 212 | saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}', config_dict=dict(args)) 213 | last_val_step = saver.global_step - args.train.interval_val 214 | if last_val_step % args.train.interval_force_save != 0: 215 | 
saver.delete_model(postfix=f'{last_val_step}') 216 | 217 | # run testing set 218 | test_loss, rpa, rca, oa, vfa, vr = test(args, model, loader_test, saver) 219 | 220 | # log loss 221 | saver.log_info( 222 | ' --- --- \nloss: {:.3f}. '.format( 223 | test_loss, 224 | ) 225 | ) 226 | 227 | saver.log_value({ 228 | 'validation/loss': test_loss 229 | }) 230 | if USE_MIR: 231 | saver.log_value({ 232 | 'validation/rpa': rpa 233 | }) 234 | saver.log_value({ 235 | 'validation/rca': rca 236 | }) 237 | saver.log_value({ 238 | 'validation/oa': oa 239 | }) 240 | saver.log_value({ 241 | 'validation/vfa': vfa 242 | }) 243 | saver.log_value({ 244 | 'validation/vr': vr 245 | }) 246 | 247 | model.train() 248 | -------------------------------------------------------------------------------- /train/train_wav.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torch.optim import lr_scheduler 4 | from savertools import utils 5 | from data_loaders_wav import get_data_loaders 6 | from solver_wav import train 7 | import torchfcpe 8 | 9 | 10 | def parse_args(args=None, namespace=None): 11 | """Parse command-line arguments.""" 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "-c", 15 | "--config", 16 | type=str, 17 | required=True, 18 | help="path to the config file") 19 | parser.add_argument( 20 | "-j", 21 | "--jump", 22 | type=str, 23 | required=False, 24 | default=None, 25 | help="jump cache for redis") 26 | return parser.parse_args(args=args, namespace=namespace) 27 | 28 | 29 | if __name__ == '__main__': 30 | # parse commands 31 | cmd = parse_args() 32 | 33 | # load config 34 | args = utils.load_config(cmd.config) 35 | print(' > config:', cmd.config) 36 | print(' > exp:', args.env.expdir) 37 | 38 | # load model 39 | model = torchfcpe.spawn_model(args) 40 | 41 | # load parameters 42 | optimizer = torch.optim.AdamW(model.parameters()) 43 | initial_global_step, model, optimizer = 
utils.load_model(args.env.expdir, model, optimizer, device=args.device) 44 | for param_group in optimizer.param_groups: 45 | param_group['initial_lr'] = args.train.lr 46 | param_group['lr'] = args.train.lr * args.train.gamma ** max((initial_global_step - 2) // args.train.decay_step, 47 | 0) 48 | param_group['weight_decay'] = args.train.weight_decay 49 | scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma, 50 | last_epoch=initial_global_step - 2) 51 | 52 | # device 53 | if args.device == 'cuda': 54 | torch.cuda.set_device(args.env.gpu_id) 55 | model.to(args.device) 56 | 57 | for state in optimizer.state.values(): 58 | for k, v in state.items(): 59 | if torch.is_tensor(v): 60 | state[k] = v.to(args.device) 61 | 62 | # datas 63 | if (str(cmd.jump) == 'True') or (str(cmd.jump) == 'true'): 64 | jump = True 65 | else: 66 | jump = False 67 | loader_train, loader_valid = get_data_loaders(args) 68 | 69 | # run 70 | train(args, initial_global_step, model, optimizer, scheduler, loader_train, loader_valid) 71 | -------------------------------------------------------------------------------- /train/utils_1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from scipy import signal as sg 4 | import pdb 5 | 6 | 7 | def make_lowshelf(g, fc, Q, fs=44100): 8 | """Generate filter coefficients for 2nd order Lowshelf filter. 9 | This function follows the code from the JUCE DSP library 10 | which can be found in `juce_IIRFilter.cpp`. 11 | 12 | The design equations are based upon those found in the Cookbook 13 | formulae for audio equalizer biquad filter coefficients 14 | by Robert Bristow-Johnson. 15 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 16 | Args: 17 | g (float): Gain factor in dB. 18 | fc (float): Cutoff frequency in Hz. 19 | Q (float): Q factor. 20 | fs (float): Sampling frequency in Hz. 
21 | Returns: 22 | tuple: (b, a) filter coefficients 23 | """ 24 | # convert gain from dB to linear 25 | g = np.power(10,(g/20)) 26 | 27 | # initial values 28 | A = np.max([0.0, np.sqrt(g)]) 29 | aminus1 = A - 1 30 | aplus1 = A + 1 31 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 32 | coso = np.cos(omega) 33 | beta = np.sin(omega) * np.sqrt(A) / Q 34 | aminus1TimesCoso = aminus1 * coso 35 | 36 | # coefs calculation 37 | b0 = A * (aplus1 - aminus1TimesCoso + beta) 38 | b1 = A * 2 * (aminus1 - aplus1 * coso) 39 | b2 = A * (aplus1 - aminus1TimesCoso - beta) 40 | a0 = aplus1 + aminus1TimesCoso + beta 41 | a1 = -2 * (aminus1 + aplus1 * coso) 42 | a2 = aplus1 + aminus1TimesCoso - beta 43 | 44 | # output coefs 45 | #b = np.array([b0/a0, b1/a0, b2/a0]) 46 | #a = np.array([a0/a0, a1/a0, a2/a0]) 47 | 48 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]]) 49 | 50 | 51 | 52 | def make_highself(g, fc, Q, fs=44100): 53 | """Generate filter coefficients for 2nd order Highshelf filter. 54 | This function follows the code from the JUCE DSP library 55 | which can be found in `juce_IIRFilter.cpp`. 56 | 57 | The design equations are based upon those found in the Cookbook 58 | formulae for audio equalizer biquad filter coefficients 59 | by Robert Bristow-Johnson. 60 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 61 | Args: 62 | g (float): Gain factor in dB. 63 | fc (float): Cutoff frequency in Hz. 64 | Q (float): Q factor. 65 | fs (float): Sampling frequency in Hz. 
66 | Returns: 67 | tuple: (b, a) filter coefficients 68 | """ 69 | # convert gain from dB to linear 70 | g = np.power(10,(g/20)) 71 | 72 | # initial values 73 | A = np.max([0.0, np.sqrt(g)]) 74 | aminus1 = A - 1 75 | aplus1 = A + 1 76 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 77 | coso = np.cos(omega) 78 | beta = np.sin(omega) * np.sqrt(A) / Q 79 | aminus1TimesCoso = aminus1 * coso 80 | 81 | # coefs calculation 82 | b0 = A * (aplus1 + aminus1TimesCoso + beta) 83 | b1 = A * -2 * (aminus1 + aplus1 * coso) 84 | b2 = A * (aplus1 + aminus1TimesCoso - beta) 85 | a0 = aplus1 - aminus1TimesCoso + beta 86 | a1 = 2 * (aminus1 - aplus1 * coso) 87 | a2 = aplus1 - aminus1TimesCoso - beta 88 | 89 | # output coefs 90 | #b = np.array([b0/a0, b1/a0, b2/a0]) 91 | #a = np.array([a0/a0, a1/a0, a2/a0]) 92 | 93 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]]) 94 | 95 | 96 | 97 | def make_peaking(g, fc, Q, fs=44100): 98 | """Generate filter coefficients for 2nd order Peaking EQ. 99 | This function follows the code from the JUCE DSP library 100 | which can be found in `juce_IIRFilter.cpp`. 101 | 102 | The design equations are based upon those found in the Cookbook 103 | formulae for audio equalizer biquad filter coefficients 104 | by Robert Bristow-Johnson. 105 | https://www.w3.org/2011/audio/audio-eq-cookbook.html 106 | Args: 107 | g (float): Gain factor in dB. 108 | fc (float): Cutoff frequency in Hz. 109 | Q (float): Q factor. 110 | fs (float): Sampling frequency in Hz. 
111 | Returns: 112 | tuple: (b, a) filter coefficients 113 | """ 114 | # convert gain from dB to linear 115 | g = np.power(10,(g/20)) 116 | 117 | # initial values 118 | A = np.max([0.0, np.sqrt(g)]) 119 | omega = (2 * np.pi * np.max([fc, 2.0])) / fs 120 | alpha = np.sin(omega) / (Q * 2) 121 | c2 = -2 * np.cos(omega) 122 | alphaTimesA = alpha * A 123 | alphaOverA = alpha / A 124 | 125 | # coefs calculation 126 | b0 = 1 + alphaTimesA 127 | b1 = c2 128 | b2 = 1 - alphaTimesA 129 | a0 = 1 + alphaOverA 130 | a1 = c2 131 | a2 = 1 - alphaOverA 132 | 133 | # output coefs 134 | #b = np.array([b0/a0, b1/a0, b2/a0]) 135 | #a = np.array([a0/a0, a1/a0, a2/a0]) 136 | 137 | return np.array([[b0/a0, b1/a0, b2/a0, 1.0, a1/a0, a2/a0]]) 138 | 139 | 140 | 141 | def params2sos(G, Fc, Q, fs): 142 | """Convert 5 band EQ paramaters to 2nd order sections. 143 | Takes a vector with shape (13,) of denormalized EQ parameters 144 | and calculates filter coefficients for each of the 5 filters. 145 | These coefficients (2nd order sections) are then stored into a 146 | single (5,6) matrix. This matrix can be fed to `scipy.signal.sosfreqz()` 147 | in order to determine the frequency response of the cascasd of 148 | all five biquad filters. 149 | Args: 150 | x (float): Gain factor in dB. 151 | fs (float): Sampling frequency in Hz. 152 | Returns: 153 | ndarray: filter coefficients for 5 band EQ stored in (5,6) matrix. 
154 | [[b1_0, b1_1, b1_2, a1_0, a1_1, a1_2], # lowshelf coefficients 155 | [b2_0, b2_1, b2_2, a2_0, a2_1, a2_2], # first band coefficients 156 | [b3_0, b3_1, b3_2, a3_0, a3_1, a3_2], # second band coefficients 157 | [b4_0, b4_1, b4_2, a4_0, a4_1, a4_2], # third band coefficients 158 | [b5_0, b5_1, b5_2, a5_0, a5_1, a5_2]] # highshelf coefficients 159 | """ 160 | # generate filter coefficients from eq params 161 | c0 = make_lowshelf(G[0], Fc[0], Q[0], fs=fs) 162 | c1 = make_peaking (G[1], Fc[1], Q[1], fs=fs) 163 | c2 = make_peaking (G[2], Fc[2], Q[2], fs=fs) 164 | c3 = make_peaking (G[3], Fc[3], Q[3], fs=fs) 165 | c4 = make_peaking (G[4], Fc[4], Q[4], fs=fs) 166 | c5 = make_peaking (G[5], Fc[5], Q[5], fs=fs) 167 | c6 = make_peaking (G[6], Fc[6], Q[6], fs=fs) 168 | c7 = make_peaking (G[7], Fc[7], Q[7], fs=fs) 169 | c8 = make_peaking (G[8], Fc[8], Q[8], fs=fs) 170 | c9 = make_highself(G[9], Fc[9], Q[9], fs=fs) 171 | 172 | # stuff coefficients into second order sections structure 173 | sos = np.concatenate([c0,c1,c2,c3,c4,c5,c6,c7,c8,c9], axis=0) 174 | 175 | return sos 176 | 177 | 178 | import parselmouth 179 | def change_gender(x, fs, lo, hi, ratio_fs, ratio_ps, ratio_pr): 180 | s = parselmouth.Sound(x, sampling_frequency=fs) 181 | f0 = s.to_pitch_ac(pitch_floor=lo, pitch_ceiling=hi, time_step=0.8/lo) 182 | f0_np = f0.selected_array['frequency'] 183 | f0_med = np.median(f0_np[f0_np!=0]).item() 184 | ss = parselmouth.praat.call([s, f0], "Change gender", ratio_fs, f0_med*ratio_ps, ratio_pr, 1.0) 185 | return ss.values.squeeze(0) 186 | 187 | def change_gender_f0(x, fs, lo, hi, ratio_fs, new_f0_med, ratio_pr): 188 | s = parselmouth.Sound(x, sampling_frequency=fs) 189 | ss = parselmouth.praat.call(s, "Change gender", lo, hi, ratio_fs, new_f0_med, ratio_pr, 1.0) 190 | return ss.values.squeeze(0) 191 | -------------------------------------------------------------------------------- /train/utils_all.py: 
-------------------------------------------------------------------------------- 1 | import colorednoise as cn 2 | import random 3 | import numpy as np 4 | from sklearn.metrics import mean_squared_error 5 | import torch 6 | import librosa 7 | import pyworld 8 | import soundfile 9 | 10 | from utils_1 import params2sos 11 | from scipy.signal import sosfilt 12 | 13 | Qmin, Qmax = 2, 5 14 | 15 | 16 | def add_noise(wav, noise_ratio=0.7, brown_noise_ratio=1.): 17 | beta = random.random() * 2 # the exponent 18 | y = cn.powerlaw_psd_gaussian(beta, wav.shape[0]) 19 | m = np.sqrt(mean_squared_error(wav, np.zeros_like(y))) 20 | 21 | if beta >= 0 and beta <= 1.5: 22 | wav += (noise_ratio * random.random()) * m * y 23 | else: 24 | wav += (brown_noise_ratio * random.random()) * m * y 25 | return wav 26 | 27 | 28 | def add_noise_slice(wav, sr, duration, add_factor=0.50, noise_ratio=0.7, brown_noise_ratio=1.): 29 | slice_length = int(duration * sr) 30 | n_frames = int(wav.shape[-1] // slice_length) 31 | slice_length_noise = int(slice_length * add_factor) 32 | for n in range(n_frames): 33 | left, right = int(n * slice_length), int((n + 1) * slice_length) 34 | offset = random.randint(left, right - slice_length_noise) 35 | if wav[offset:offset + slice_length_noise].shape[0] != 0: 36 | wav[offset:offset + slice_length_noise] = add_noise(wav[offset:offset + slice_length_noise], 37 | noise_ratio=noise_ratio, 38 | brown_noise_ratio=brown_noise_ratio) 39 | return wav 40 | 41 | 42 | def add_mel_mask(mel, iszeropad=False, esp=1e-5): 43 | if iszeropad: 44 | return torch.ones_like(mel) * esp 45 | else: 46 | return (random.random() * 0.9 + 0.1) * torch.randn_like(mel) 47 | 48 | 49 | def add_mel_mask_slice(mel, sr, duration, hop_size=512, add_factor=0.3, vertical_offset=True, vertical_factor=0.05, 50 | iszeropad=True, islog=True, block_num=5, esp=1e-5): 51 | if islog: 52 | mel = torch.exp(mel) 53 | slice_length = int(duration * sr) // hop_size 54 | n_frames = int(mel.shape[-1] // slice_length) 
55 | n_mels = mel.shape[-2] 56 | for n in range(n_frames): 57 | line_num = n_mels // block_num 58 | for i in range(block_num): 59 | now_vertical_factor = vertical_factor + random.random() * 0.1 60 | now_add_factor = add_factor + random.random() * 0.1 61 | slice_length_noise = int(slice_length * now_add_factor) 62 | if vertical_offset: 63 | v_offset = int(random.uniform(line_num * i, line_num * (i + 1) - now_vertical_factor)) 64 | n_v_down = v_offset 65 | n_v_up = int(v_offset + now_vertical_factor * n_mels) 66 | else: 67 | n_v_down = 0 68 | n_v_up = n_mels 69 | left, right = int(n * slice_length), int((n + 1) * slice_length) 70 | offset = int(random.uniform(left, right - slice_length_noise)) 71 | if mel[n_v_down:n_v_up, offset:offset + slice_length_noise].shape[-1] != 0: 72 | mel[n_v_down:n_v_up, offset:offset + slice_length_noise] = add_mel_mask( 73 | mel[n_v_down:n_v_up, offset:offset + slice_length_noise], iszeropad, esp) 74 | if islog: 75 | mel = torch.log(torch.clamp(mel, min=esp)) 76 | return mel 77 | 78 | 79 | def random_eq(wav, sr): 80 | rng = np.random.default_rng() 81 | z = rng.uniform(0, 1, size=(10,)) 82 | Q = Qmin * (Qmax / Qmin) ** z 83 | G = rng.uniform(-12, 12, size=(10,)) 84 | Fc = np.exp(np.linspace(np.log(60), np.log(7600), 10)) 85 | sos = params2sos(G, Fc, Q, sr) 86 | wav = sosfilt(sos, wav) 87 | peak = np.abs(wav).max() 88 | if peak > 0.98: 89 | wav = 0.98 * wav / peak 90 | return wav 91 | 92 | 93 | def worldSynthesize(wav, target_sr=44100, hop_length=512, fft_size=2048, f0_in=None): 94 | f0, t = pyworld.dio(wav.astype(np.double), fs=target_sr, frame_period=1000 * hop_length / target_sr) 95 | f0 = pyworld.stonemask(wav.astype(np.double), f0, t, target_sr) 96 | if f0 is not None: 97 | f0 = f0_in.astype(np.double) 98 | ap = pyworld.d4c(wav.astype(np.double), f0, t, target_sr, fft_size=fft_size) 99 | sp = pyworld.cheaptrick(wav.astype(np.double), f0, t, target_sr, fft_size=fft_size) 100 | synthesized = pyworld.synthesize(f0, sp, ap, target_sr, 
frame_period=1000 * hop_length / target_sr) 101 | 102 | peak = np.abs(synthesized).max() 103 | synthesized = 0.98 * synthesized / peak 104 | 105 | return synthesized, f0 106 | 107 | 108 | # soundfile.write(f'world_{wav_name}.wav', synthesized, target_sr) 109 | # np.save(f"f0_{wav_name}.npy",f0) 110 | 111 | 112 | def add_noise_snb(wav, snb, beta): 113 | # 将信噪比转换为信号与噪声的能量比例 114 | snb = 10 ** (snb / 10) 115 | noise = cn.powerlaw_psd_gaussian(beta, wav.shape[0]) 116 | rms_signal = np.sqrt(np.mean(wav ** 2)) 117 | rms_noise = np.sqrt(np.mean(noise ** 2)) 118 | wav = wav + noise * (rms_signal / rms_noise) / snb 119 | return wav 120 | 121 | 122 | def add_noise_slice_snb(wav, sr, duration, add_factor=0.50, snb=0.7, beta=1.0): 123 | slice_length = int(duration * sr) 124 | n_frames = int(wav.shape[-1] // slice_length) 125 | slice_length_noise = int(slice_length * add_factor) 126 | for n in range(n_frames): 127 | left, right = int(n * slice_length), int((n + 1) * slice_length) 128 | offset = random.randint(left, right - slice_length_noise) 129 | if wav[offset:offset + slice_length_noise].shape[0] != 0: 130 | wav[offset:offset + slice_length_noise] = add_noise_snb(wav[offset:offset + slice_length_noise], snb, beta) 131 | return wav 132 | 133 | 134 | def add_pub_noise_snb(wav, snb): 135 | # 将信噪比转换为信号与噪声的能量比例 136 | import os 137 | noise_path = r'path/to/noise/data/dir' 138 | noise_list = os.listdir(noise_path) 139 | noise_path = os.path.join(noise_path, random.choice(noise_list)) 140 | snb = 10 ** (snb / 10) 141 | noise, sr = librosa.load(noise_path, sr=16000) 142 | if len(wav) > len(noise): 143 | noise = np.tile(noise, len(wav) // len(noise) + 1) 144 | if len(wav) < len(noise): 145 | noise = noise[:len(wav)] 146 | rms_signal = np.sqrt(np.mean(wav ** 2)) 147 | rms_noise = np.sqrt(np.mean(noise ** 2)) 148 | wav = wav + noise * (rms_signal / rms_noise) / snb 149 | return wav 150 | 151 | 152 | def add_pub_noise_snb2(wav, snb): 153 | # 将信噪比转换为信号与噪声的能量比例 154 | import os 155 | 
noise_path = r'path/to/noise/data/dir' 156 | noise_list = os.listdir(noise_path) 157 | noise_path = os.path.join(noise_path, random.choice(noise_list)) 158 | snb = 10 ** (snb / 10) 159 | noise, sr = librosa.load(noise_path, sr=16000) 160 | if len(wav) > len(noise): 161 | noise = np.tile(noise, len(wav) // len(noise) + 1) 162 | if len(wav) < len(noise): 163 | offset = int(random.choice(range(len(noise) - len(wav)))) 164 | noise = noise[offset:offset + len(wav)] 165 | rms_signal = np.sqrt(np.mean(wav ** 2)) 166 | rms_noise = np.sqrt(np.mean(noise ** 2)) 167 | wav = wav + noise * (rms_signal / rms_noise) / snb 168 | return wav 169 | --------------------------------------------------------------------------------